CaseBasedReasoning/model.py

292 lines
9.1 KiB
Python

from dataclasses import dataclass
import csv, json
from typing import Callable
import similarity as sim
@dataclass
class Entity:
    """Base record pairing a problem description with its solution.

    Both parts are plain dicts that map a field (column) name to its value.
    """

    problem: dict[str]
    solution: dict[str]

    @classmethod
    def from_dict(cls, problem: dict, solution: dict = None):
        """Build an entity from a problem dict and an optional solution dict.

        A missing or empty (falsy) ``solution`` is replaced by a fresh empty
        dict; a non-empty one is stored as-is (not copied).
        """
        if not solution:
            solution = {}
        return cls(problem, solution)
class Query(Entity):
    """
    A query is an Entity whose solution is always empty:
    - problem: the attributes to match (e.g. {"attr_1": 123, "attr_2": 456})
    - solution: {}
    """

    @classmethod
    def from_problems(cls, **problem):
        """Build a query from keyword arguments, one per problem attribute."""
        empty_solution = {}
        return cls(problem, solution=empty_solution)
class Case(Entity):
    """
    A known case, stored as an Entity with both parts filled in:
    - problem: describing attributes (e.g. {"attr_1": 123, "attr_2": 456})
    - solution: the outcome attributes, in the same dict form as problem
    """
@dataclass
class RetrievedCase(Case):
    """A case returned by ``CaseBase.retrieve`` together with its similarity.

    Attributes:
        similarity: the summed similarity over all compared fields.
        sim_per_field: similarity of each compared column, keyed by column name
            (``CaseBase.retrieve`` builds this as a dict).
    """

    similarity: float
    sim_per_field: dict[str, float]

    def __str__(self) -> str:
        # Solution values may be numbers (e.g. after float conversion with
        # set_int=True, or numeric JSON values), so stringify each one;
        # str.join only accepts str items.
        return " ".join(map(str, self.solution.values())).capitalize()
###############################################################################
@dataclass
class CaseBase:
    """A CaseBase object represents a collection of all known cases.

    Attributes:
        cases: the loaded cases, in input order.
        config: the effective loader configuration (see ``_default_cfg``).
        fields: dict with the keys ``"problem"`` and ``"solution"``, each
            holding the list of column names assigned to that part.
    """

    cases: list[Case]
    config: dict[str, object]
    fields: dict[str, list]
    # Per-field metadata written by add_symbolic_sim, currently
    # {field: {"symbolic_sims": <similarity matrix>}}. The double underscore
    # name-mangles it to _CaseBase__field_infos.
    __field_infos: dict[str, dict] = None

    # Query attribute names that differ from the column names in the case
    # base (some column headers carry units / special characters).
    # Not annotated, so dataclass does not treat it as a field.
    _FIELD_ALIASES = {
        "radius_curve": "radius_curve(m)",
        "speed_limit": "speed_limit(km/h)",
    }

    def __repr__(self) -> str:
        key_vals = f"cases={len(self.cases)}, fields={list(self.fields.values())}"
        return f"{self.__class__.__name__}({key_vals})"

    def __getitem__(self, key: int) -> Case:
        return self.cases[key]

    @staticmethod
    def _default_cfg(new_cfg: dict) -> dict:
        """Apply the default configuration where not overwritten.

        Only the known keys ("encoding", "delimiter", "set_int") are kept;
        any other key in ``new_cfg`` is ignored.
        """
        defaults = {
            "encoding": "utf-8",
            "delimiter": ",",
            # Despite the name, enables conversion of values to float.
            "set_int": False,
        }
        return {key: new_cfg.get(key, default) for key, default in defaults.items()}

    @staticmethod
    def _parse_number(value):
        """Return ``value`` as a float if it parses as one, else unchanged.

        A decimal comma (e.g. "1,5") is accepted as a fallback for strings.
        """
        try:
            return float(value)
        except (TypeError, ValueError):
            pass
        if isinstance(value, str):
            try:
                return float(value.replace(",", "."))
            except ValueError:
                pass
        return value

    @staticmethod
    def _loader(reader_obj: list[dict], set_int: bool, **kwargs):
        """Helper for reading rows into cases. Only use internally!

        Args:
            reader_obj: iterable of row dicts (column name -> value).
            set_int: convert values to float where possible.
            **kwargs: must contain ``problem_fields`` and ``solution_fields``.
                A column listed in both is treated as a problem field.
        Returns:
            A list of Case objects; columns in neither list are dropped.
        """
        problem_fields = kwargs["problem_fields"]
        solution_fields = kwargs["solution_fields"]
        convert = CaseBase._parse_number if set_int else (lambda value: value)
        cases = []
        for row in reader_obj:
            problem, solution = {}, {}
            for key, value in row.items():
                if key in problem_fields:
                    problem[key] = convert(value)
                elif key in solution_fields:
                    solution[key] = convert(value)
            cases.append(Case.from_dict(problem, solution))
        return cases

    @classmethod
    def _from_rows(cls, rows, problem_fields: list, solution_fields: list, cfg: dict):
        """Build a CaseBase from row dicts; shared by the file loaders."""
        cases = cls._loader(
            reader_obj=rows,
            set_int=cfg["set_int"],
            problem_fields=problem_fields,
            solution_fields=solution_fields,
        )
        return cls(
            cases=cases,
            config=cfg,
            fields={"problem": problem_fields, "solution": solution_fields},
        )

    @classmethod
    def from_csv(cls, path: str, problem_fields: list, solution_fields: list, **cfg):
        """
        Read a csv file and load every column listed in `problem_fields` and
        `solution_fields`.

        Args:
            path: path to a valid .csv file
            problem_fields: list with columns to be considered as problem
            solution_fields: list with columns to be considered as solution
            **cfg: overwrite default configuration (see CaseBase._default_cfg())
        Returns:
            A CaseBase object with:
            - a list of cases
            - the used configuration
            - a dict {"problem": problem_fields, "solution": solution_fields}
        Raises:
            ValueError: if the passed path isn't a .csv file
        """
        if not str(path).endswith(".csv"):
            raise ValueError(f"invalid file format: {path}")
        cfg = cls._default_cfg(cfg)
        # newline="" lets the csv module handle line endings and quoted
        # fields that contain newlines, as the csv documentation requires.
        with open(path, encoding=cfg["encoding"], newline="") as file:
            reader = csv.DictReader(file, delimiter=cfg["delimiter"])
            return cls._from_rows(reader, problem_fields, solution_fields, cfg)

    @classmethod
    def from_json(cls, path: str, problem_fields: list, solution_fields: list, **cfg):
        """
        Read a json file and load every attribute listed in `problem_fields`
        and `solution_fields`.

        Args:
            path: path to a valid .json file (array of json-objects)
            problem_fields: list with attributes to be considered as problem
            solution_fields: list with attributes to be considered as solution
            **cfg: overwrite default configuration (see CaseBase._default_cfg());
                "delimiter" is ignored for json
        Returns:
            A CaseBase object with:
            - a list of cases
            - the used configuration
            - a dict {"problem": problem_fields, "solution": solution_fields}
        Raises:
            ValueError: if the passed path isn't a .json file or the file
                does not contain a json array
        """
        if not str(path).endswith(".json"):
            raise ValueError(f"invalid file format: {path}")
        cfg = cls._default_cfg(cfg)
        with open(path, encoding=cfg["encoding"]) as file:
            rows = json.load(file)
        if not isinstance(rows, list):
            raise ValueError(f"expected a json array of objects in {path}")
        return cls._from_rows(rows, problem_fields, solution_fields, cfg)

    def retrieve(self, query: Query, **fields_and_sim_funcs: Callable) -> RetrievedCase:
        """Search for the case most similar to `query`.

        The overall similarity of a case is the sum of its per-field
        similarities. On a tie the earliest case wins.

        Args:
            query: the query whose ``problem`` holds the attributes to compare.
            **fields_and_sim_funcs: maps a query attribute name to a similarity
                function. A function in ``sim.SYMBOLIC_SIMS`` is called as
                ``f(query_value, case_value, matrix)``, where ``matrix`` was
                registered via add_symbolic_sim for that attribute name. A
                function in ``sim.METRIC_SIMS`` is called as
                ``f(query_value, case_value)``. Query names listed in
                ``_FIELD_ALIASES`` are mapped to their case column names.
        Returns:
            The best case as a RetrievedCase, with per-column similarities.
        Raises:
            ValueError: if the case base is empty or a function is in neither
                similarity registry.
            KeyError: if an attribute is missing from the query or a case, or
                a symbolic similarity was not registered.
        """
        if not self.cases:
            raise ValueError("cannot retrieve from an empty case base")
        best_case, best_sim, best_per_field = None, None, {}
        for case in self.cases:
            total = 0.0
            per_field = {}
            for field, sim_func in fields_and_sim_funcs.items():
                column = self._FIELD_ALIASES.get(field, field)
                if sim_func in sim.SYMBOLIC_SIMS:
                    value = sim_func(
                        query.problem[field],
                        case.problem[column],
                        self.get_symbolic_sim(field),
                    )
                elif sim_func in sim.METRIC_SIMS:
                    value = sim_func(query.problem[field], case.problem[column])
                else:
                    raise ValueError(
                        f"unknown similarity function for field {field!r}: {sim_func!r}"
                    )
                per_field[column] = value
                total += value
            # Strict ">" keeps the first of equally similar cases.
            if best_sim is None or total > best_sim:
                best_case, best_sim, best_per_field = case, total, per_field
        return RetrievedCase(
            best_case.problem,
            best_case.solution,
            best_sim,
            best_per_field,
        )

    def _check_known_field(self, field: str) -> None:
        """Raise ValueError unless `field` is a configured problem/solution column."""
        if field not in (*self.fields["problem"], *self.fields["solution"]):
            raise ValueError(f"unknown field {field}")

    def get_values_by_field(self, field: str) -> set:
        """
        Get the distinct values of a field across all cases.

        Args:
            field: a column name listed in the problem or solution fields
        Returns:
            the set of distinct values (looked up in each case's problem
            first, then its solution)
        Raises:
            ValueError: if the field is not in the 'problem' or 'solution' list
        """
        self._check_known_field(field)
        distinct_values = set()
        for case in self.cases:
            if field in case.problem:
                distinct_values.add(case.problem[field])
            elif field in case.solution:
                distinct_values.add(case.solution[field])
        return distinct_values

    def add_symbolic_sim(self, field: str, similarity_matrix: dict):
        """Register the similarity matrix used by symbolic similarity functions.

        Replaces any metadata previously stored for `field`.

        Raises:
            ValueError: if the field is not in the 'problem' or 'solution' list
        """
        self._check_known_field(field)
        if self.__field_infos is None:
            self.__field_infos = {}
        self.__field_infos[field] = {"symbolic_sims": similarity_matrix}

    def get_symbolic_sim(self, field: str) -> dict:
        """Return the similarity matrix registered for `field`.

        Raises:
            KeyError: if no matrix was registered for `field`.
        """
        try:
            return (self.__field_infos or {})[field]["symbolic_sims"]
        except KeyError:
            raise KeyError(f"no symbolic similarity registered for field {field!r}") from None