import csv
from dataclasses import dataclass
from typing import Any, Callable

import similarity as sim

# Keyword arguments cannot contain characters such as brackets or slashes,
# so some query field names are mapped to the real column names used in the
# stored cases:
#   radius_curve -> radius_curve(m)
#   speed_limit  -> speed_limit(km/h)
_COLUMN_ALIASES = {
    "radius_curve": "radius_curve(m)",
    "speed_limit": "speed_limit(km/h)",
}


def _coerce_number(value):
    """
    Try to turn a raw cell into a float, accepting "," as decimal separator.

    Non-string values (e.g. ``None`` for cells missing from a short row) and
    strings that are not numeric are returned unchanged.
    """
    if not isinstance(value, str):
        return value
    for candidate in (value, value.replace(",", ".")):
        try:
            return float(candidate)
        except ValueError:
            continue
    return value


@dataclass
class Entity:
    """
    Base record that pairs a problem description with a solution
    """

    problem: dict[str, Any]
    solution: dict[str, Any]

    @classmethod
    def from_dict(cls, problem: dict, solution: dict = None):
        """
        Build an entity from dictionaries.

        A missing or empty `solution` is replaced by a fresh empty dict.
        """
        return cls(problem, solution) if solution else cls(problem, dict())


class Query(Entity):
    """
    An entity that only carries a problem; its solution is always empty
    """

    @classmethod
    def from_problems(cls, **problem):
        """
        Returns a query object by providing problems as keyword arguments
        """
        return cls(problem, solution=dict())


class Case(Entity):
    """
    A known case consisting of a problem and its solution
    """


@dataclass
class RetrievedCase(Case):
    """
    A case returned by `CaseBase.retrieve` together with its similarity

    `similarity` is the sum of all per-field similarities and
    `sim_per_field` maps each compared column name to its similarity.
    """

    similarity: float
    sim_per_field: dict[str, float]

    def __str__(self) -> str:
        # Solution values may be floats when the case base was loaded with
        # set_int=True, so convert explicitly before joining.
        return " ".join(str(value) for value in self.solution.values()).capitalize()


@dataclass
class CaseBase:
    """
    A CaseBase object represents a collection of all known cases
    """

    cases: list[Case]
    config: dict[str, Any]
    # Mapping with the keys "problem" and "solution" (see `from_csv`).
    fields: dict[str, list]
    # Per-field metadata, filled lazily by `add_symbolic_sim`.
    __field_infos: dict[str, Any] = None

    def __repr__(self) -> str:
        key_vals = f"cases={len(self.cases)}, fields={list(self.fields.values())}"
        return f"{self.__class__.__name__}({key_vals})"

    @staticmethod
    def _default_cfg(new_cfg: dict) -> dict:
        """
        apply default configuration if not overwritten

        Only the known keys ("encoding", "delimiter", "set_int") are kept;
        unknown keys in `new_cfg` are ignored.
        """
        default_config = {
            "encoding": "utf-8",
            "delimiter": ",",
            "set_int": False,
        }
        return {key: new_cfg.get(key, default) for key, default in default_config.items()}

    @staticmethod
    def _loader(reader_obj: list[dict], set_int: bool, **kwargs):
        """
        helper function for reading. Only use internally!

        Args:
            reader_obj: iterable of row dictionaries (e.g. a csv.DictReader)
            set_int: if true, numeric looking values are converted to float
                ("," is accepted as decimal separator); otherwise values are
                stored as read
            **kwargs: must contain `problem_fields` and `solution_fields`

        Returns:
            A list with one Case per row
        """
        problem_fields = kwargs["problem_fields"]
        solution_fields = kwargs["solution_fields"]

        _cases = list()
        for elem in reader_obj:
            _problem, _solution = dict(), dict()
            for key, value in elem.items():
                # A column listed in both groups is treated as a problem field.
                if key in problem_fields:
                    target = _problem
                elif key in solution_fields:
                    target = _solution
                else:
                    continue
                # Previously nothing was stored at all when set_int was
                # false, which left every case empty with the default config.
                target[key] = _coerce_number(value) if set_int else value
            _cases.append(Case.from_dict(_problem, _solution))
        return _cases

    @classmethod
    def from_csv(cls, path: str, problem_fields: list, solution_fields: list, **cfg):
        """
        read a csv file and load every column by `problem_fields` and `solution_fields`

        Args:
            path: path to a valid .csv file
            problem_fields: list with columns to be considered as problem
            solution_fields: list with columns to be considered as solution
            **cfg: overwrite default configuration (see CaseBase._default_cfg())

        Returns:
            A CaseBase object with:
                - a list of cases
                - the used configuration
                - a dict with
                    "problem" -> problem_fields
                    "solution" -> solution_fields

        Raises:
            ValueError: if the passed path isnt a .csv file
        """
        if not path.endswith(".csv"):
            raise ValueError(f"invalid file format: {path}")

        cfg = cls._default_cfg(cfg)
        # newline="" lets the csv module handle line endings inside quoted
        # fields itself, as recommended by the csv documentation.
        with open(path, encoding=cfg["encoding"], newline="") as file:
            cases = cls._loader(
                reader_obj=csv.DictReader(file, delimiter=cfg["delimiter"]),
                set_int=cfg["set_int"],
                problem_fields=problem_fields,
                solution_fields=solution_fields,
            )

        return cls(
            cases=cases,
            config=cfg,
            fields={
                "problem": problem_fields,
                "solution": solution_fields,
            },
        )

    def retrieve(self, query: Query, **fields_and_sim_funcs: Callable) -> RetrievedCase:
        """
        Search for case most similar to query

        Args:
            query: the query whose problem values are compared
            **fields_and_sim_funcs: query field name -> similarity function.
                `sim.symbolic_sim` is additionally given the similarity
                matrix registered for the field; every other function is
                called with the query value and the case value.

        Returns:
            The case with the highest summed similarity. On ties the case
            that comes first in the case base wins.

        Raises:
            ValueError: if the case base contains no cases
        """
        if not self.cases:
            raise ValueError("cannot retrieve from an empty case base")

        best_case = None
        best_sim = float("-inf")
        best_per_field = dict()

        for case in self.cases:
            total = 0.0
            per_field = dict()
            for field, sim_func in fields_and_sim_funcs.items():
                column = _COLUMN_ALIASES.get(field, field)
                query_value = query.problem[field]
                case_value = case.problem[column]
                if sim_func == sim.symbolic_sim:
                    score = sim_func(query_value, case_value, self.get_symbolic_sim(field))
                else:
                    # Covers sim.euclid_sim and any other two-argument
                    # similarity function; previously an unknown function
                    # reused the result of the preceding field.
                    score = sim_func(query_value, case_value)
                per_field[column] = score
                total += score

            # `best_case is None` guarantees a result even if the first
            # total is not comparable as "greater" (e.g. NaN).
            if best_case is None or total > best_sim:
                best_case, best_sim, best_per_field = case, total, per_field

        return RetrievedCase(
            best_case.problem,
            best_case.solution,
            best_sim,
            best_per_field,
        )

    def add_symbolic_sim(self, field: str, similarity_matrix: dict):
        """
        Add similarity matrix for field

        Raises:
            ValueError: if `field` is neither a problem nor a solution field
        """
        if field not in list(self.fields["problem"]) + list(self.fields["solution"]):
            raise ValueError(f"unknown field {field}")

        if self.__field_infos is None:
            self.__field_infos = {}

        self.__field_infos[field] = {
            "symbolic_sims": similarity_matrix
        }

    def get_symbolic_sim(self, field: str) -> dict:
        """
        Get similarity matrix by field

        Raises:
            KeyError: if no similarity matrix was added for `field`
        """
        infos = self.__field_infos
        if not infos or field not in infos:
            raise KeyError(f"no symbolic similarity registered for field {field!r}")
        return infos[field]["symbolic_sims"]