2022-12-13 13:25:31 +00:00
|
|
|
from dataclasses import dataclass
|
2023-01-11 06:58:45 +00:00
|
|
|
import csv
|
2022-12-13 13:25:31 +00:00
|
|
|
from typing import Callable
|
|
|
|
import similarity as sim
|
|
|
|
|
2023-01-11 06:58:45 +00:00
|
|
|
|
2022-12-13 13:25:31 +00:00
|
|
|
@dataclass
class Entity:
    """Base record pairing a problem description with its solution.

    Both parts are dictionaries keyed by field (column) name.
    """

    problem: dict[str, object]
    solution: dict[str, object]

    @classmethod
    def from_dict(cls, problem: dict, solution: "dict | None" = None):
        """Create an entity from a problem dict and an optional solution dict.

        Args:
            problem: mapping of problem field name -> value.
            solution: mapping of solution field name -> value; a missing
                (None) or empty solution is replaced by a new empty dict.

        Returns:
            An instance of ``cls``.
        """
        return cls(problem, solution if solution else {})
|
|
|
|
|
|
|
|
|
|
|
|
class Query(Entity):
    """A problem description that is matched against the cases of a case base."""

    @classmethod
    def from_problems(cls, **problem):
        """Build a query whose problem fields are the given keyword arguments.

        The solution part of a query always starts out empty.
        """
        empty_solution = {}
        return cls(problem, solution=empty_solution)
|
|
|
|
|
|
|
|
|
|
|
|
class Case(Entity):
    """A known problem/solution pair stored in a CaseBase."""
|
|
|
|
|
2023-01-11 06:58:45 +00:00
|
|
|
|
2022-12-13 13:25:31 +00:00
|
|
|
@dataclass
class RetrievedCase(Case):
    """A case returned by ``CaseBase.retrieve`` together with its similarity scores."""

    similarity: float  # sum of the per-field similarities
    sim_per_field: dict[str, float]  # column name -> similarity of that field

    def __str__(self) -> str:
        """Join the solution values with spaces and capitalize the result."""
        # Solution values are floats when the case base was loaded with
        # set_int=True, so stringify each one before joining.
        return " ".join(str(value) for value in self.solution.values()).capitalize()
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class CaseBase:
    """A collection of all known cases plus the configuration used to load them.

    Attributes:
        cases: the loaded cases.
        config: the effective loader configuration (see ``_default_cfg``).
        fields: ``{"problem": [...], "solution": [...]}`` holding the column
            names assigned to each part of a case.
    """

    cases: list["Case"]
    config: dict[str, object]
    fields: dict[str, list]
    # Per-field extra information (currently only symbolic similarity
    # matrices), created lazily by add_symbolic_sim. Because of name
    # mangling the dataclass field is called `_CaseBase__field_infos`.
    __field_infos: dict[str, dict] = None

    # Query field names whose CSV column header carries a unit suffix.
    # It is not annotated, so the dataclass does not treat it as a field.
    _COLUMN_ALIASES = {
        "radius_curve": "radius_curve(m)",
        "speed_limit": "speed_limit(km/h)",
    }

    def __repr__(self) -> str:
        key_vals = f"cases={len(self.cases)}, fields={list(self.fields.values())}"
        return f"{self.__class__.__name__}({key_vals})"

    @staticmethod
    def _default_cfg(new_cfg: dict) -> dict:
        """Apply the default configuration where ``new_cfg`` does not override it.

        Only the known keys (encoding, delimiter, set_int) are kept; any other
        key in ``new_cfg`` is ignored.
        """
        defaults = {"encoding": "utf-8", "delimiter": ",", "set_int": False}
        return {key: new_cfg.get(key, value) for key, value in defaults.items()}

    @staticmethod
    def _parse_value(value):
        """Convert ``value`` to float if possible, else return it unchanged.

        A decimal comma (e.g. "1,5") is accepted as a fallback. Non-string
        values that float() rejects (e.g. None for a missing CSV cell) are
        returned as-is instead of raising.
        """
        try:
            return float(value)
        except (TypeError, ValueError):
            pass
        if isinstance(value, str):
            try:
                return float(value.replace(",", "."))
            except ValueError:
                pass
        return value

    @staticmethod
    def _loader(reader_obj: list[dict], set_int: bool, **kwargs):
        """Helper for reading rows into cases. Only use internally!

        Args:
            reader_obj: iterable of row dicts (e.g. a csv.DictReader).
            set_int: convert numeric-looking values to float when True;
                otherwise keep the raw values.
            **kwargs: must contain ``problem_fields`` and ``solution_fields``.
                Columns in neither are dropped; a column in both counts as a
                problem field.

        Returns:
            A list of Case objects, one per row.
        """
        convert = CaseBase._parse_value if set_int else (lambda value: value)
        problem_fields = kwargs["problem_fields"]
        solution_fields = kwargs["solution_fields"]

        cases = []
        for row in reader_obj:
            problem, solution = {}, {}
            for key, value in row.items():
                if key in problem_fields:
                    problem[key] = convert(value)
                elif key in solution_fields:
                    solution[key] = convert(value)
            cases.append(Case.from_dict(problem, solution))
        return cases

    @classmethod
    def from_csv(cls, path: str, problem_fields: list, solution_fields: list, **cfg):
        """Read a csv file and load every column listed in `problem_fields` and `solution_fields`.

        Args:
            path: path to a valid .csv file
            problem_fields: list with columns to be considered as problem
            solution_fields: list with columns to be considered as solution
            **cfg: overwrite default configuration (see CaseBase._default_cfg())

        Returns:
            A CaseBase object with:
                - a list of cases
                - the used configuration
                - a dict with
                    "problem" -> problem_fields
                    "solution" -> solution_fields

        Raises:
            ValueError: if the passed path isn't a .csv file
        """
        if not str(path).endswith(".csv"):
            raise ValueError(f"invalid file format: {path}")

        cfg = cls._default_cfg(cfg)

        # newline="" lets the csv module handle line breaks inside quoted fields.
        with open(path, encoding=cfg["encoding"], newline="") as file:
            cases = cls._loader(
                reader_obj=csv.DictReader(file, delimiter=cfg["delimiter"]),
                set_int=cfg["set_int"],
                problem_fields=problem_fields,
                solution_fields=solution_fields,
            )

        return cls(
            cases=cases,
            config=cfg,
            fields={
                "problem": problem_fields,
                "solution": solution_fields,
            },
        )

    def retrieve(self, query: "Query", **fields_and_sim_funcs: Callable) -> "RetrievedCase":
        """Search for the case most similar to ``query``.

        Args:
            query: the query whose ``problem`` values are compared.
            **fields_and_sim_funcs: query field name -> similarity function.
                ``sim.symbolic_sim`` is called with
                (query_value, case_value, similarity_matrix); any other
                function is called with (query_value, case_value).

        Returns:
            A RetrievedCase for the case with the highest summed similarity.
            Ties go to the earliest case.

        Raises:
            ValueError: if the case base contains no cases.
        """
        if not self.cases:
            raise ValueError("cannot retrieve from an empty case base")

        best_case, best_sim, best_per_field = None, None, {}

        for case in self.cases:
            total = 0.0
            per_field = {}

            for field, sim_func in fields_and_sim_funcs.items():
                # Some columns contain special chars like spaces or brackets:
                # radius_curve -> radius_curve(m), speed_limit -> speed_limit(km/h)
                column = self._COLUMN_ALIASES.get(field, field)
                query_value = query.problem[field]
                case_value = case.problem[column]

                if sim_func is sim.symbolic_sim:
                    field_sim = sim_func(query_value, case_value, self.get_symbolic_sim(field))
                else:
                    field_sim = sim_func(query_value, case_value)

                per_field[column] = field_sim
                total += field_sim

            if best_sim is None or total > best_sim:
                best_case, best_sim, best_per_field = case, total, per_field

        return RetrievedCase(
            best_case.problem,
            best_case.solution,
            best_sim,
            best_per_field,
        )

    def add_symbolic_sim(self, field: str, similarity_matrix: dict):
        """Register the similarity matrix used by symbolic_sim for ``field``.

        Raises:
            ValueError: if ``field`` is neither a problem nor a solution field.
        """
        known_fields = [*self.fields["problem"], *self.fields["solution"]]
        if field not in known_fields:
            raise ValueError(f"unknown field {field}")

        if self.__field_infos is None:
            self.__field_infos = {}

        self.__field_infos[field] = {"symbolic_sims": similarity_matrix}

    def get_symbolic_sim(self, field: str) -> dict:
        """Get the similarity matrix registered for ``field``.

        Raises:
            KeyError: if no matrix was registered for ``field``.
        """
        infos = self.__field_infos or {}
        try:
            return infos[field]["symbolic_sims"]
        except KeyError:
            raise KeyError(f"no similarity matrix registered for field {field!r}") from None
|