CaseBasedReasoning/model.py

224 lines
6.7 KiB
Python
Raw Normal View History

2022-12-13 13:25:31 +00:00
from dataclasses import dataclass
2023-01-11 06:58:45 +00:00
import csv
2022-12-13 13:25:31 +00:00
from typing import Callable
import similarity as sim
2023-01-11 06:58:45 +00:00
2022-12-13 13:25:31 +00:00
@dataclass
class Entity:
    """
    Base record of the model: a `problem` description and its `solution`,
    both held as dicts keyed by field name.
    """
    problem: dict[str]
    solution: dict[str]

    @classmethod
    def from_dict(cls, problem: dict, solution: dict = None):
        """
        Build an instance from a problem dict and an optional solution dict.

        A missing (None) or empty `solution` is replaced by a fresh empty dict.
        """
        chosen_solution = solution or {}
        return cls(problem, chosen_solution)
class Query(Entity):
    """A problem description that is matched against the known cases."""

    @classmethod
    def from_problems(cls, **problem):
        """
        Returns a query object by providing problems as keyword arguments.

        Every keyword becomes one entry of `problem`; the solution starts empty.
        """
        empty_solution = {}
        return cls(problem, solution=empty_solution)
class Case(Entity):
    """A known problem together with its solution, as held by a CaseBase."""
2023-01-11 06:58:45 +00:00
2022-12-13 13:25:31 +00:00
@dataclass
class RetrievedCase(Case):
    """
    A case selected by CaseBase.retrieve(), together with its similarity
    to the query.

    Attributes (in addition to `problem` and `solution`):
        similarity: the summed similarity over all compared fields
        sim_per_field: similarity of each compared field, keyed by the
            case column name
    """
    similarity: float
    sim_per_field: dict[str, float]

    def __str__(self) -> str:
        # Solution values are floats when the case base was loaded with
        # set_int=True, so stringify each value before joining.
        return " ".join(str(value) for value in self.solution.values()).capitalize()
@dataclass
class CaseBase:
    """
    A CaseBase object represents a collection of all known cases.

    Attributes:
        cases: the loaded cases
        config: the effective loader configuration (see CaseBase._default_cfg())
        fields: dict with the keys "problem" and "solution", each holding the
            column names loaded into that part of a case
    """
    cases: list[Case]
    config: dict[str, object]
    fields: dict[str, list]
    # field name -> {"symbolic_sims": similarity matrix}; filled by add_symbolic_sim()
    __field_infos: dict[str, dict] = None

    # Some CSV column names contain characters that cannot appear in a keyword
    # argument (brackets, slashes), so retrieve() maps query field names onto
    # the actual case columns:
    #   radius_curve -> radius_curve(m)
    #   speed_limit  -> speed_limit(km/h)
    # Fields not listed here are looked up under their own name.
    _FIELD_ALIASES = {
        "radius_curve": "radius_curve(m)",
        "speed_limit": "speed_limit(km/h)",
    }

    def __repr__(self) -> str:
        key_vals = f"cases={len(self.cases)}, fields={list(self.fields.values())}"
        return f"{self.__class__.__name__}({key_vals})"

    @staticmethod
    def _default_cfg(new_cfg: dict) -> dict:
        """
        Apply the default configuration for every key not overwritten by `new_cfg`.

        Defaults: encoding="utf-8", delimiter=",", set_int=False.
        Keys of `new_cfg` that are not one of these are ignored.
        """
        default_config = {
            "encoding": "utf-8",
            "delimiter": ",",
            "set_int": False,
        }
        return {key: new_cfg.get(key, default) for key, default in default_config.items()}

    @staticmethod
    def _parse_value(value, set_int: bool):
        """
        Convert one raw CSV cell for storage in a case. Only use internally!

        Without `set_int` the value is returned unchanged. With `set_int` it is
        converted to float, also accepting a decimal comma ("1,5" -> 1.5);
        cells that are not numeric (or missing, i.e. None) are returned unchanged.
        """
        if not set_int:
            return value
        try:
            return float(value)
        except (TypeError, ValueError):
            pass
        try:
            # decimal comma, e.g. "1,5" -> 1.5
            return float(value.replace(",", "."))
        except (AttributeError, ValueError):
            return value

    @classmethod
    def _loader(cls, reader_obj: list[dict], set_int: bool, **kwargs) -> list[Case]:
        """
        Helper function for reading. Only use internally!

        Args:
            reader_obj: iterable of row dicts (e.g. a csv.DictReader)
            set_int: convert numeric cells to float (see CaseBase._parse_value())
            **kwargs: must contain `problem_fields` and `solution_fields`,
                the column names assigned to each part of a case
        Returns:
            one Case per row; columns in neither list are skipped
        """
        problem_fields = kwargs["problem_fields"]
        solution_fields = kwargs["solution_fields"]
        cases = []
        for row in reader_obj:
            problem, solution = {}, {}
            for key, value in row.items():
                if key in problem_fields:
                    problem[key] = cls._parse_value(value, set_int)
                elif key in solution_fields:
                    solution[key] = cls._parse_value(value, set_int)
            cases.append(Case.from_dict(problem, solution))
        return cases

    @classmethod
    def from_csv(cls, path: str, problem_fields: list, solution_fields: list, **cfg):
        """
        Read a csv file and load every column by `problem_fields` and `solution_fields`.

        Args:
            path: path to a valid .csv file (extension checked case-insensitively)
            problem_fields: list with columns to be considered as problem
            solution_fields: list with columns to be considered as solution
            **cfg: overwrite default configuration (see CaseBase._default_cfg()):
                encoding, delimiter, set_int (convert numeric cells to float)
        Returns:
            A CaseBase object with:
                - a list of cases
                - the used configuration
                - a dict {"problem": problem_fields, "solution": solution_fields}
        Raises:
            ValueError: if the passed path isn't a .csv file
        """
        if not path.lower().endswith(".csv"):
            raise ValueError(f"invalid file format: {path}")
        cfg = cls._default_cfg(cfg)
        # newline="" is required by the csv module so that quoted fields
        # containing line breaks are read correctly.
        with open(path, encoding=cfg["encoding"], newline="") as file:
            cases = cls._loader(
                reader_obj=csv.DictReader(file, delimiter=cfg["delimiter"]),
                set_int=cfg["set_int"],
                problem_fields=problem_fields,
                solution_fields=solution_fields,
            )
        return cls(
            cases=cases,
            config=cfg,
            fields={
                "problem": problem_fields,
                "solution": solution_fields,
            },
        )

    def retrieve(self, query: Query, **fields_and_sim_funcs: Callable) -> RetrievedCase:
        """
        Search for the case most similar to `query`.

        Args:
            query: the query whose `problem` values are compared
            **fields_and_sim_funcs: query field name -> similarity function.
                `sim.symbolic_sim` is called as
                sim_func(query_value, case_value, matrix), with the matrix
                registered through add_symbolic_sim(); any other function is
                called as sim_func(query_value, case_value).
        Returns:
            RetrievedCase holding the best case, its summed similarity and
            the per-field similarities keyed by case column name. On ties
            the earliest case wins.
        Raises:
            ValueError: if the case base holds no cases
        """
        if not self.cases:
            raise ValueError("cannot retrieve from an empty case base")

        best_case, best_sim, best_sim_per_field = None, None, {}
        for case in self.cases:
            total_sim = 0.0
            sim_per_field = {}
            for field, sim_func in fields_and_sim_funcs.items():
                column = self._FIELD_ALIASES.get(field, field)
                query_value = query.problem[field]
                case_value = case.problem[column]
                if sim_func is sim.symbolic_sim:
                    field_sim = sim_func(query_value, case_value, self.get_symbolic_sim(field))
                else:
                    # euclid_sim and any other two-argument similarity function
                    field_sim = sim_func(query_value, case_value)
                sim_per_field[column] = field_sim
                total_sim += field_sim
            if best_sim is None or total_sim > best_sim:
                best_case, best_sim, best_sim_per_field = case, total_sim, sim_per_field

        return RetrievedCase(
            best_case.problem,
            best_case.solution,
            best_sim,
            best_sim_per_field,
        )

    def add_symbolic_sim(self, field: str, similarity_matrix: dict):
        """
        Add the similarity matrix used by `sim.symbolic_sim` for `field`.

        A matrix previously registered for the same field is replaced.

        Raises:
            ValueError: if `field` is neither a problem nor a solution column
        """
        known_fields = [*self.fields["problem"], *self.fields["solution"]]
        if field not in known_fields:
            raise ValueError(f"unknown field {field}")
        if self.__field_infos is None:
            self.__field_infos = {}
        self.__field_infos[field] = {
            "symbolic_sims": similarity_matrix
        }

    def get_symbolic_sim(self, field: str) -> dict:
        """
        Get the similarity matrix registered for `field` via add_symbolic_sim().

        Raises:
            KeyError: if no matrix was registered for `field`
        """
        if not self.__field_infos or field not in self.__field_infos:
            raise KeyError(f"no similarity matrix registered for field {field!r}")
        return self.__field_infos[field]["symbolic_sims"]