Refactoring, documentation
This commit is contained in:
parent
6c79ba9379
commit
bcb7f39372
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,2 +1,3 @@
|
|||||||
venv/
|
venv/
|
||||||
|
env/
|
||||||
__pycache__
|
__pycache__
|
||||||
|
176
explore.ipynb
176
explore.ipynb
@ -10,18 +10,18 @@
|
|||||||
"Es soll vorausgesagt werden, ob ein Fahrzeug in einer gegebenen Verkehrssituation das vorausfahrende Fahrzeug überholen wird. Dazu werden der Abstand und die Geschwindigkeit zum vorausfahrenden Fahrzeug gemessen. Zudem wird diese Information von einem Fahrzeug auf der linken Spur gemessen, welches sich von hinten nähert. Als weitere Informationen liegen die jeweiligen Fahrzeugtypen, das Wetter, die Tageszeit und die Straßenart und -krümmung vor.\n",
|
"Es soll vorausgesagt werden, ob ein Fahrzeug in einer gegebenen Verkehrssituation das vorausfahrende Fahrzeug überholen wird. Dazu werden der Abstand und die Geschwindigkeit zum vorausfahrenden Fahrzeug gemessen. Zudem wird diese Information von einem Fahrzeug auf der linken Spur gemessen, welches sich von hinten nähert. Als weitere Informationen liegen die jeweiligen Fahrzeugtypen, das Wetter, die Tageszeit und die Straßenart und -krümmung vor.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Mögliche vorherzusagende Reaktionen des Fahrzeugs sind: Geschwindigkeit und Spur beibehalten (continue), Verzögern (decelerate), Spurwechsel bei gleicher Geschwindigkeit (lane change) und Spurwechsel mit Beschleunigung (accelerated lane change).\n",
|
"Mögliche vorherzusagende Reaktionen des Fahrzeugs sind: Geschwindigkeit und Spur beibehalten (continue), Verzögern (decelerate), Spurwechsel bei gleicher Geschwindigkeit (lane change) und Spurwechsel mit Beschleunigung (accelerated lane change).\n",
|
||||||
"Entwerfen Sie dazu ein geeignetes Ähnlichkeitsmaß und ermitteln Sie eine Repräsentation des Modells mit Hilfe von Case Based Learning (Verfahren aus der Vorlesung)\n",
|
"Entwerfen Sie dazu ein geeignetes Ähnlichkeitsmaß und ermitteln Sie eine Repräsentation des Modells mithilfe von Case Based Learning (Verfahren aus der Vorlesung)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"## Aufgabenstellung\n",
|
"## Aufgabenstellung\n",
|
||||||
"Entwickeln Sie eine Software, welche bei Eingabe einer Verkehrssituation (im gleichen Format) die Aktion des Fahrzeugs voraussagen kann. Diskutieren Sie Ihre Konfiguration und das Ergebnis.\n",
|
"Entwickeln Sie eine Software, welche bei Eingabe einer Verkehrssituation (im gleichen Format) die Aktion des Fahrzeugs voraussagen kann. Diskutieren Sie Ihre Konfiguration und das Ergebnis.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"## Eingangsdaten\n",
|
"## Eingangsdaten\n",
|
||||||
"Ein Datensatz von Messungen in welchem ähnliche Situationen aufgezeichnet wurden -> data/SIM_001.csv"
|
"Ein Datensatz von Messungen, in welchem ähnliche Situationen aufgezeichnet wurden -> data/SIM_001.csv"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -33,14 +33,14 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 7,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def create_similarity_matrix(filename, key):\n",
|
"def create_similarity_matrix(filename, key):\n",
|
||||||
" \"\"\" \n",
|
" \"\"\" \n",
|
||||||
" Generates similarity matrix \n",
|
" Generates similarity matrix \n",
|
||||||
" \n",
|
"\n",
|
||||||
" Arguments:\n",
|
" Arguments:\n",
|
||||||
" filename: the csv file name as string\n",
|
" filename: the csv file name as string\n",
|
||||||
" key: the title of the csv which is the value in the first line and col\n",
|
" key: the title of the csv which is the value in the first line and col\n",
|
||||||
@ -59,12 +59,12 @@
|
|||||||
" else:\n",
|
" else:\n",
|
||||||
" similarity_matrix[key_v][k] = float(v)\n",
|
" similarity_matrix[key_v][k] = float(v)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return similarity_matrix"
|
" return similarity_matrix\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 8,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -83,7 +83,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 9,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -133,55 +133,146 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 10,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Your Query:\n",
|
||||||
|
" - v = 28.5\n",
|
||||||
|
" - v_left = 42.5\n",
|
||||||
|
" - v_front = 5\n",
|
||||||
|
" - d_left = -137\n",
|
||||||
|
" - d_front = 54\n",
|
||||||
|
" - type_left = motorcycle\n",
|
||||||
|
" - type_front = truck\n",
|
||||||
|
" - radius_curve = 2391\n",
|
||||||
|
" - slope_street = flat\n",
|
||||||
|
" - street_type = country_road (separated)\n",
|
||||||
|
" - time = day\n",
|
||||||
|
" - weather = dry\n",
|
||||||
|
" - type_vehicle = car\n",
|
||||||
|
" - speed_limit = 100\n",
|
||||||
|
"\n",
|
||||||
|
"Prediction: Decelerate\n",
|
||||||
|
"Probability: 92.75%\n",
|
||||||
|
"\n",
|
||||||
|
"Explanation:\n",
|
||||||
|
" - v_left = 42.5 (similarity: 1.00)\n",
|
||||||
|
" - v_front = 21.5 (similarity: 0.06)\n",
|
||||||
|
" - d_left = -137.0 (similarity: 1.00)\n",
|
||||||
|
" - d_front = 54.0 (similarity: 1.00)\n",
|
||||||
|
" - type_left = motorcycle (similarity: 1.00)\n",
|
||||||
|
" - type_front = truck (similarity: 1.00)\n",
|
||||||
|
" - radius_curve(m) = 2391.0 (similarity: 1.00)\n",
|
||||||
|
" - slope_street = flat (similarity: 1.00)\n",
|
||||||
|
" - street_type = country_road (separated) (similarity: 1.00)\n",
|
||||||
|
" - time = day (similarity: 1.00)\n",
|
||||||
|
" - weather = dry (similarity: 1.00)\n",
|
||||||
|
" - type_vehicle = car (similarity: 1.00)\n",
|
||||||
|
" - speed_limit(km/h) = 100.0 (similarity: 1.00)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# TODO Read query from cli\n",
|
"# Read query from cli\n",
|
||||||
"query = Query.from_problems(\n",
|
"query = Query.from_problems(\n",
|
||||||
" v = 28.5,\n",
|
" v=28.5,\n",
|
||||||
" v_left = 42.5,\n",
|
" v_left=42.5,\n",
|
||||||
" v_front = 5,\n",
|
" v_front=5,\n",
|
||||||
" d_left = -137,\n",
|
" d_left=-137,\n",
|
||||||
" d_front = 54,\n",
|
" d_front=54,\n",
|
||||||
" type_left = \"motorcycle\",\n",
|
" type_left=\"motorcycle\",\n",
|
||||||
" type_front = \"truck\",\n",
|
" type_front=\"truck\",\n",
|
||||||
" radius_curve = 2391,\n",
|
" radius_curve=2391,\n",
|
||||||
" slope_street = \"flat\",\n",
|
" slope_street=\"flat\",\n",
|
||||||
" street_type = \"country_road (separated)\",\n",
|
" street_type=\"country_road (separated)\",\n",
|
||||||
" time = \"day\",\n",
|
" time=\"day\",\n",
|
||||||
" weather = \"dry\",\n",
|
" weather=\"dry\",\n",
|
||||||
" type_vehicle = \"car\",\n",
|
" type_vehicle=\"car\",\n",
|
||||||
" speed_limit = 100,\n",
|
" speed_limit=100,\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# sim_funcs: manhattan_sim, euclid_sim\n",
|
|
||||||
"retrieved = case_base_obj.retrieve(\n",
|
"retrieved = case_base_obj.retrieve(\n",
|
||||||
" query,\n",
|
" query,\n",
|
||||||
" v_left = euclid_sim,\n",
|
" v_left=euclid_sim,\n",
|
||||||
" v_front = euclid_sim,\n",
|
" v_front=euclid_sim,\n",
|
||||||
" d_left = euclid_sim,\n",
|
" d_left=euclid_sim,\n",
|
||||||
" d_front = euclid_sim,\n",
|
" d_front=euclid_sim,\n",
|
||||||
" type_left = symbolic_sim,\n",
|
" type_left=symbolic_sim,\n",
|
||||||
" type_front = symbolic_sim,\n",
|
" type_front=symbolic_sim,\n",
|
||||||
" radius_curve = euclid_sim,\n",
|
" radius_curve=euclid_sim,\n",
|
||||||
" slope_street = symbolic_sim,\n",
|
" slope_street=symbolic_sim,\n",
|
||||||
" street_type = symbolic_sim,\n",
|
" street_type=symbolic_sim,\n",
|
||||||
" time = symbolic_sim,\n",
|
" time=symbolic_sim,\n",
|
||||||
" weather = symbolic_sim,\n",
|
" weather=symbolic_sim,\n",
|
||||||
" type_vehicle = symbolic_sim,\n",
|
" type_vehicle=symbolic_sim,\n",
|
||||||
" speed_limit = euclid_sim,\n",
|
" speed_limit=euclid_sim,\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(\"Your Query:\")\n",
|
"print(\"Your Query:\")\n",
|
||||||
"for k, v in query.problem.items():\n",
|
"for k, v in query.problem.items():\n",
|
||||||
" print(f\" - {k} = {v}\")\n",
|
" print(f\" - {k} = {v}\")\n",
|
||||||
"print()\n",
|
"print()\n",
|
||||||
"print(\"Prediction: \" + \" \".join(retrieved.solution.values()).capitalize())\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Prediction: {' ' .join(retrieved.solution.values()).capitalize()}\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Calculate probability of prediction\n",
|
||||||
|
"prob = 0\n",
|
||||||
|
"count = 0\n",
|
||||||
|
"for field, val in retrieved.sim_per_field.items():\n",
|
||||||
|
" count += 1\n",
|
||||||
|
" prob += val\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Probability: {((prob / count) * 100):.2f}%\")\n",
|
||||||
"print()\n",
|
"print()\n",
|
||||||
|
"\n",
|
||||||
"print(\"Explanation:\")\n",
|
"print(\"Explanation:\")\n",
|
||||||
"for field, val in retrieved.sim_per_field.items():\n",
|
"for field, val in retrieved.sim_per_field.items():\n",
|
||||||
" print(f\" - {field} =\", retrieved.problem[field], f\"(similarity: {val:.2f})\")"
|
" print(f\" - {field} =\",\n",
|
||||||
|
" retrieved.problem[field], f\"(similarity: {val:.2f})\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Entwurf und Umsetzung\n",
|
||||||
|
"\n",
|
||||||
|
"Allgemein ist ein Case-Based Reasoning (CBR) Algorithmus eine Methode des maschinellen Lernens, bei der Probleme durch den Vergleich mit früheren, ähnlichen Fällen gelöst werden. Ein CBR-System arbeitet, indem es die Merkmale des aktuellen Problems mit gespeicherten Fällen vergleicht und auf Basis dieser Vergleiche eine Lösung vorschlägt. Dadurch kann der Algorithmus in einer Vielzahl von Anwendungsbereichen, die auch einen hohen Grad an Expertise erfordern können, zum Einsatz kommen. \n",
|
||||||
|
"\n",
|
||||||
|
"Beim Entwurf des Algorithmus wird darauf geachtet, dass dieser auch einfach an andere Problemstellungen angepasst werden kann. Darüber hinaus wird die Lösung in der Sprache Python umgesetzt.\n",
|
||||||
|
"\n",
|
||||||
|
"![Alt-Text](./KI.png \"Entwurf des Algorithmus\")\n",
|
||||||
|
"\n",
|
||||||
|
"Wie in der obigen Abbildung zu erkennen, wird der Algorithmus in drei Bereiche unterteilt. Dies ist erstens das Notebook, dadurch kann der Nutzer komfortabel auf die von ihm benötigten Funktionen zurückgreifen. Die Funktionalität wird über die model.py und die similarity.py bereitgestellt. Dabei ist die model.py das eigentliche Modell, welches über das Laden der CSV Dateien kalibriert wird. In der similarity.py befinden sich die Algorithmen, die für den Vergleich der Eingabeparameter benötigt werden.\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"## Diskussion der Konfiguration\n",
|
||||||
|
"\n",
|
||||||
|
"Wie bereits unter dem Punkt \"Entwurf und Umsetzung\" erwähnt, wird das Modell in diesem Fall über CSV-Dateien kalibriert. Einerseits werden die generellen Datensätze geladen, die sich in der SIM_001.csv befinden, andererseits auch die Ähnlichkeitsmatrizen. Erst durch diese Ähnlichkeitsmatrizen wird dem Modell ermöglicht, die in Textform zur Verfügung stehenden Features im Datensatz zu interpretieren und auch sicher auf andere Situationen anzuwenden. Denn durch diese wird der Zusammenhang der verschiedenen Auftrittsmöglichkeiten festgelegt. \n",
|
||||||
|
"\n",
|
||||||
|
"Festgelegt werden müssen in diesem Anwendungsfall die Beziehungen zwischen den verschiedenen Ausprägungen der Straßenneigung (flat, ascending, decending), des Straßentyps (country_road (separated), autobahn), des Zeittyps (night, dusk, day, dawn), des Fahrzeugtyps (car, motorcycle, sportscar, truck) und der Wetterart (dry, rain, fog, snow_ice).\n",
|
||||||
|
"\n",
|
||||||
|
"Die Erstellung der Ähnlichkeitsmatrix für die Straßenneigung wird nun im Folgenden näher erläutert:\n",
|
||||||
|
"\n",
|
||||||
|
"| type_street_slope | flat | ascending | decending |\n",
|
||||||
|
"| ----| ----| ----| ----|\n",
|
||||||
|
"| <b>flat</b> | 1.0 | 0.3 | 0.7 |\n",
|
||||||
|
"| <b>ascending</b> | 0.3 | 1.0 | 0.1 |\n",
|
||||||
|
"| <b>decending</b> | 0.7 | 0.1 | 1.0 |\n",
|
||||||
|
"\n",
|
||||||
|
"Das Ähnlichkeitsmaß ist eine Zahl zwischen null und eins, wobei null das Gegenteil und eins das Gleiche widerspiegelt. Wie zu erkennen ist, befindet sich in der Diagonale an jeder Stelle eine eins. Dies ist der Fall, da flat mit flat verglichen komplett gleich ist. Nun könnte man annehmen, dass bei den möglichen Zuständen flach, aufsteigend und absteigend der Vergleich zwischen flach und aufsteigend mit 0.5 bewertet wird. Wie in der Matrix zu erkennen, beträgt er jedoch 0.3. Denn die Bewertung muss im Kontext des Anwendungsfalls erfolgen: Auf einer aufsteigenden Straße zu überholen ist deutlich risikoreicher als auf einer flachen oder sogar absteigenden Straße. Es ist damit zu rechnen, dass auf einer aufsteigenden Straße die Beschleunigung länger dauert, gleichzeitig rollt der Gegenverkehr den Berg herunter und fährt dadurch eventuell auch etwas zu schnell, da er weniger Widerstand hat. Dadurch ergibt sich eine Ähnlichkeit von 0.3. Die Ähnlichkeit zwischen einer flachen und einer absteigenden Straße haben wir mit 0.7 bewertet, mit der Begründung, dass ein Überholvorgang auf einer flachen und einer absteigenden Straße ein ähnlich hohes Risiko mit sich bringt. Auf die Beschreibung des Vergleiches der restlichen Werte wird an dieser Stelle verzichtet, da sich das Vorgehen nur wiederholen würde.\n",
|
||||||
|
"\n",
|
||||||
|
"## Diskussion des Ergebnisses\n",
|
||||||
|
"\n",
|
||||||
|
"Durch die Aufteilung des Datensatzes (SIM_001.csv) in Trainings- und Testdaten kann die einwandfreie Funktion des Algorithmus sichergestellt werden. Denn das zu erwartende Ergebnis ist bei der Eingabe der dazugehörenden Parameter bereits bekannt. Bei den Tests des Algorithmus wurden alle Testcases richtig vorhergesagt. Dadurch kann von einer guten Genauigkeit des Algorithmus ausgegangen werden. \n",
|
||||||
|
"\n",
|
||||||
|
"Zusammenfassend vereinen Case-Based Reasoning (CBR) Algorithmen einige große Vorteile, die auch in diesem Anwendungsfall zum Vorschein getreten sind. Dazu gehört beispielsweise die Möglichkeit, auch komplexe Probleme wie die Entscheidung für einen Überholvorgang durch den Vergleich mit ähnlichen Fällen aus der Vergangenheit lösen zu können. Dadurch wird aber auch einer der Nachteile eines solchen Algorithmus deutlich: Die Qualität hängt von der Güte und der Menge der zur Verfügung stehenden Daten ab. Auffallend ist, dass der Algorithmus sehr flexibel ist und schnell an spezifische Anforderungen und Umstände angepasst werden kann. "
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -203,7 +294,6 @@
|
|||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.0"
|
"version": "3.11.0"
|
||||||
},
|
},
|
||||||
"orig_nbformat": 4,
|
|
||||||
"vscode": {
|
"vscode": {
|
||||||
"interpreter": {
|
"interpreter": {
|
||||||
"hash": "c6b11de3c41b7cafaa0ac1297b550056ae3875bbf0c337fa48ab4f33656fc527"
|
"hash": "c6b11de3c41b7cafaa0ac1297b550056ae3875bbf0c337fa48ab4f33656fc527"
|
||||||
@ -211,5 +301,5 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 2
|
"nbformat_minor": 4
|
||||||
}
|
}
|
||||||
|
177
model.py
177
model.py
@ -1,8 +1,9 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import csv, json
|
import csv
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
import similarity as sim
|
import similarity as sim
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Entity:
|
class Entity:
|
||||||
problem: dict[str]
|
problem: dict[str]
|
||||||
@ -14,25 +15,18 @@ class Entity:
|
|||||||
|
|
||||||
|
|
||||||
class Query(Entity):
|
class Query(Entity):
|
||||||
"""
|
|
||||||
A query is represented as an Entity with:
|
|
||||||
- problem (e.g. {"attr_1": 123, "attr_2": 456})
|
|
||||||
- soution = {}
|
|
||||||
"""
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_problems(cls, **problem):
|
def from_problems(cls, **problem):
|
||||||
"""Returns a query object by providing problems as keyword arguments"""
|
"""
|
||||||
|
Returns a query object by providing problems as keyword arguments
|
||||||
|
"""
|
||||||
return cls(problem, solution=dict())
|
return cls(problem, solution=dict())
|
||||||
|
|
||||||
|
|
||||||
class Case(Entity):
|
class Case(Entity):
|
||||||
"""
|
|
||||||
A case is represented as an Entity with:
|
|
||||||
- problem (e.g. {"attr_1": 123, "attr_2": 456})
|
|
||||||
- solution (e.g. see problem)
|
|
||||||
"""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RetrievedCase(Case):
|
class RetrievedCase(Case):
|
||||||
similarity: float
|
similarity: float
|
||||||
@ -41,13 +35,12 @@ class RetrievedCase(Case):
|
|||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return " ".join(self.solution.values()).capitalize()
|
return " ".join(self.solution.values()).capitalize()
|
||||||
|
|
||||||
###############################################################################
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CaseBase:
|
class CaseBase:
|
||||||
"""A CaseBase object represents a collection of all known cases"""
|
"""
|
||||||
|
A CaseBase object represents a collection of all known cases
|
||||||
|
"""
|
||||||
cases: list[Case]
|
cases: list[Case]
|
||||||
config: dict[str]
|
config: dict[str]
|
||||||
fields: tuple[str]
|
fields: tuple[str]
|
||||||
@ -57,11 +50,10 @@ class CaseBase:
|
|||||||
key_vals = f"cases={len(self.cases)}, fields={list(self.fields.values())}"
|
key_vals = f"cases={len(self.cases)}, fields={list(self.fields.values())}"
|
||||||
return f"{self.__class__.__name__}({key_vals})"
|
return f"{self.__class__.__name__}({key_vals})"
|
||||||
|
|
||||||
def __getitem__(self, key: int) -> Case:
|
|
||||||
return self.cases[key]
|
|
||||||
|
|
||||||
def _default_cfg(new_cfg: dict) -> dict:
|
def _default_cfg(new_cfg: dict) -> dict:
|
||||||
"""apply default configuration if not overwritten"""
|
"""
|
||||||
|
apply default configuration if not overwritten
|
||||||
|
"""
|
||||||
default_config = {
|
default_config = {
|
||||||
"encoding": "utf-8",
|
"encoding": "utf-8",
|
||||||
"delimiter": ",",
|
"delimiter": ",",
|
||||||
@ -71,7 +63,9 @@ class CaseBase:
|
|||||||
return {k: v if k not in new_cfg else new_cfg[k] for k, v in default_config}
|
return {k: v if k not in new_cfg else new_cfg[k] for k, v in default_config}
|
||||||
|
|
||||||
def _loader(reader_obj: list[dict], set_int: bool, **kwargs):
|
def _loader(reader_obj: list[dict], set_int: bool, **kwargs):
|
||||||
"""helper function for reading. Only use internally!"""
|
"""
|
||||||
|
helper function for reading. Only use internally!
|
||||||
|
"""
|
||||||
_cases = list()
|
_cases = list()
|
||||||
for elem in reader_obj:
|
for elem in reader_obj:
|
||||||
_problem, _solution = dict(), dict()
|
_problem, _solution = dict(), dict()
|
||||||
@ -87,7 +81,6 @@ class CaseBase:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
_problem[key] = value
|
_problem[key] = value
|
||||||
|
|
||||||
|
|
||||||
elif key in kwargs["solution_fields"]:
|
elif key in kwargs["solution_fields"]:
|
||||||
if set_int:
|
if set_int:
|
||||||
try:
|
try:
|
||||||
@ -96,11 +89,11 @@ class CaseBase:
|
|||||||
try:
|
try:
|
||||||
_solution[key] = float(value.replace(",", "."))
|
_solution[key] = float(value.replace(",", "."))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
_solution[key] = value
|
_solution[key] = value
|
||||||
_cases.append(
|
_cases.append(
|
||||||
Case.from_dict(_problem, _solution)
|
Case.from_dict(_problem, _solution)
|
||||||
)
|
)
|
||||||
|
|
||||||
return _cases
|
return _cases
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -121,7 +114,7 @@ class CaseBase:
|
|||||||
- a tuple with
|
- a tuple with
|
||||||
[0] -> problem_fields
|
[0] -> problem_fields
|
||||||
[1] -> solution_fields
|
[1] -> solution_fields
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if the passed path isnt a .csv file
|
ValueError: if the passed path isnt a .csv file
|
||||||
"""
|
"""
|
||||||
@ -133,74 +126,37 @@ class CaseBase:
|
|||||||
|
|
||||||
with open(path, encoding=cfg["encoding"]) as file:
|
with open(path, encoding=cfg["encoding"]) as file:
|
||||||
cases = cls._loader(
|
cases = cls._loader(
|
||||||
reader_obj = csv.DictReader(file, delimiter = cfg["delimiter"]),
|
reader_obj=csv.DictReader(file, delimiter=cfg["delimiter"]),
|
||||||
set_int = cfg["set_int"],
|
set_int=cfg["set_int"],
|
||||||
problem_fields = problem_fields,
|
problem_fields=problem_fields,
|
||||||
solution_fields = solution_fields
|
solution_fields=solution_fields
|
||||||
)
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
cases = cases,
|
cases=cases,
|
||||||
config = cfg,
|
config=cfg,
|
||||||
fields = {
|
fields={
|
||||||
"problem": problem_fields,
|
"problem": problem_fields,
|
||||||
"solution": solution_fields
|
"solution": solution_fields
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_json(cls, path: str, problem_fields: list, solution_fields: list, **cfg):
|
|
||||||
"""
|
|
||||||
read a json file and load every column by `problem_fields` and `solution_fields`
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: path to a valid .json file (array of json-objects)
|
|
||||||
problem_fields: list with columns to be considered as problem
|
|
||||||
solution_fields: list with columns to be considered as solution
|
|
||||||
**cfg: overwrite default configuration (see CaseBase._default_cfg())
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A CaseBase object with:
|
|
||||||
- a list of cases
|
|
||||||
- the used configuration
|
|
||||||
- a tuple with
|
|
||||||
[0] -> problem_fields
|
|
||||||
[1] -> solution_fields
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if the passed path isnt a .csv file
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not path.endswith(".json"):
|
|
||||||
raise ValueError("invalid file format:", path)
|
|
||||||
|
|
||||||
cfg = cls._default_cfg(cfg)
|
|
||||||
|
|
||||||
with open(path, encoding=cfg["encoding"]) as file:
|
|
||||||
cases = cls._loader(
|
|
||||||
reader_obj = json.load(file),
|
|
||||||
set_int = cfg["set_int"],
|
|
||||||
problem_fields = problem_fields,
|
|
||||||
solution_fields = solution_fields
|
|
||||||
)
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
cases = cases,
|
|
||||||
config = cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
def retrieve(self, query: Query, **fields_and_sim_funcs: dict[str, Callable]) -> RetrievedCase:
|
def retrieve(self, query: Query, **fields_and_sim_funcs: dict[str, Callable]) -> RetrievedCase:
|
||||||
"""Search for case most similar to query"""
|
"""
|
||||||
|
Search for case most similar to query
|
||||||
|
"""
|
||||||
r = {"case": None, "sim": -1.0, "sim_per_field": dict()}
|
|
||||||
for case in self.cases:
|
|
||||||
|
|
||||||
|
r = {"case": None, "sim": -1.0, "sim_per_field": dict()}
|
||||||
|
|
||||||
|
for case in self.cases:
|
||||||
_sim = 0.0
|
_sim = 0.0
|
||||||
_sim_per_field = dict()
|
_sim_per_field = dict()
|
||||||
|
|
||||||
for field, sim_func in fields_and_sim_funcs.items():
|
for field, sim_func in fields_and_sim_funcs.items():
|
||||||
|
|
||||||
# Some columns contain special chars.
|
# Some columns contain special chars like spaces or brackets:
|
||||||
|
# radius_curve -> radius_curve(m)
|
||||||
|
# speed_limit -> speed_limit(km/h)
|
||||||
field_name = ""
|
field_name = ""
|
||||||
if field == "radius_curve":
|
if field == "radius_curve":
|
||||||
field_name = "radius_curve(m)"
|
field_name = "radius_curve(m)"
|
||||||
@ -209,21 +165,21 @@ class CaseBase:
|
|||||||
else:
|
else:
|
||||||
field_name = field
|
field_name = field
|
||||||
|
|
||||||
if sim_func in sim.SYMBOLIC_SIMS:
|
if sim_func == sim.symbolic_sim:
|
||||||
field_sim = (
|
field_sim = (
|
||||||
field_name,
|
field_name,
|
||||||
sim_func(
|
sim_func(
|
||||||
query.problem[field],
|
query.problem[field],
|
||||||
case.problem[field_name],
|
case.problem[field_name],
|
||||||
self.get_symbolic_sim(field)
|
self.get_symbolic_sim(field)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
elif sim_func in sim.METRIC_SIMS:
|
elif sim_func == sim.euclid_sim:
|
||||||
field_sim = (
|
field_sim = (
|
||||||
field_name,
|
field_name,
|
||||||
sim_func(
|
sim_func(
|
||||||
query.problem[field],
|
query.problem[field],
|
||||||
case.problem[field_name]
|
case.problem[field_name]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -237,56 +193,31 @@ class CaseBase:
|
|||||||
"sim": _sim,
|
"sim": _sim,
|
||||||
"sim_per_field": _sim_per_field
|
"sim_per_field": _sim_per_field
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return RetrievedCase(
|
return RetrievedCase(
|
||||||
r["case"].problem,
|
r["case"].problem,
|
||||||
r["case"].solution,
|
r["case"].solution,
|
||||||
r["sim"],
|
r["sim"],
|
||||||
r["sim_per_field"]
|
r["sim_per_field"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_values_by_field(self, field: str) -> set[str]:
|
|
||||||
"""
|
|
||||||
get the distinct values by the fields
|
|
||||||
|
|
||||||
Args:
|
|
||||||
self: List with corresponding values
|
|
||||||
field: a string value that must be contained in self
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
the distinct values
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if the field value not in the 'problem' or 'solution' list
|
|
||||||
"""
|
|
||||||
|
|
||||||
if field not in list(self.fields["problem"]) + list(self.fields["solution"]):
|
|
||||||
raise ValueError(f"unknown field {field}")
|
|
||||||
|
|
||||||
distinct_values = set()
|
|
||||||
for elem in self.cases:
|
|
||||||
|
|
||||||
if field in elem.problem:
|
|
||||||
distinct_values.add(elem.problem[field])
|
|
||||||
|
|
||||||
elif field in elem.solution:
|
|
||||||
distinct_values.add(elem.solution[field])
|
|
||||||
|
|
||||||
return distinct_values
|
|
||||||
|
|
||||||
|
|
||||||
def add_symbolic_sim(self, field: str, similarity_matrix: dict):
|
def add_symbolic_sim(self, field: str, similarity_matrix: dict):
|
||||||
|
"""
|
||||||
|
Add similarity matrix for field
|
||||||
|
"""
|
||||||
|
|
||||||
if field not in list(self.fields["problem"]) + list(self.fields["solution"]):
|
if field not in list(self.fields["problem"]) + list(self.fields["solution"]):
|
||||||
raise ValueError(f"unknown field {field}")
|
raise ValueError(f"unknown field {field}")
|
||||||
|
|
||||||
if self.__field_infos is None:
|
if self.__field_infos is None:
|
||||||
self.__field_infos = {}
|
self.__field_infos = {}
|
||||||
|
|
||||||
self.__field_infos[field] = {
|
self.__field_infos[field] = {
|
||||||
"symbolic_sims": similarity_matrix
|
"symbolic_sims": similarity_matrix
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_symbolic_sim(self, field: str) -> dict[str]:
|
def get_symbolic_sim(self, field: str) -> dict[str]:
|
||||||
return self.__field_infos[field]["symbolic_sims"]
|
"""
|
||||||
|
Get similarity matrix by field
|
||||||
|
"""
|
||||||
|
return self.__field_infos[field]["symbolic_sims"]
|
||||||
|
@ -1,83 +1,8 @@
|
|||||||
import math
|
import math
|
||||||
from typing import Any, Callable
|
|
||||||
|
|
||||||
"""
|
|
||||||
METRIC SIMILARITY FUNCTIONS
|
|
||||||
"""
|
|
||||||
def manhattan_sim(q_val: float, c_val: float) -> float:
|
|
||||||
m_dist = lambda x, y: abs(x - y)
|
|
||||||
return 1 / (1 + m_dist(q_val, c_val))
|
|
||||||
|
|
||||||
def euclid_sim(q_val: float, c_val: float) -> float:
|
def euclid_sim(q_val: float, c_val: float) -> float:
|
||||||
e_dist = lambda x, y: math.sqrt((x - y)**2)
|
e_dist = lambda x, y: math.sqrt((x - y)**2)
|
||||||
return 1 / (1 + e_dist(q_val, c_val))
|
return 1 / (1 + e_dist(q_val, c_val))
|
||||||
|
|
||||||
|
|
||||||
METRIC_SIMS = [manhattan_sim, euclid_sim]
|
|
||||||
|
|
||||||
"""
|
|
||||||
SYMBOLIC SIMILARITY
|
|
||||||
"""
|
|
||||||
def symbolic_sim(q_field_name: str, c_field_name: str, sim_matrix: dict) -> float:
|
def symbolic_sim(q_field_name: str, c_field_name: str, sim_matrix: dict) -> float:
|
||||||
return sim_matrix[q_field_name][c_field_name]
|
return sim_matrix[q_field_name][c_field_name]
|
||||||
|
|
||||||
|
|
||||||
SYMBOLIC_SIMS = [symbolic_sim]
|
|
||||||
|
|
||||||
"""
|
|
||||||
CHARACTER EDIT DISTANCE
|
|
||||||
"""
|
|
||||||
def edit_distance(word_1: str, word_2: str, to_same_case: bool = True) -> int:
|
|
||||||
|
|
||||||
if word_1 == word_2:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
if to_same_case:
|
|
||||||
word_1, word_2 = [word.upper() for word in (word_1, word_2)]
|
|
||||||
|
|
||||||
word_1, word_2 = list(word_1), list(word_2)
|
|
||||||
longer_word = word_1 if len(word_1) > len(word_2) else word_2
|
|
||||||
|
|
||||||
i, count = 0, 0
|
|
||||||
while i < len(longer_word):
|
|
||||||
|
|
||||||
# word_2 is longer -> add current char of word_2
|
|
||||||
if i >= len(word_1):
|
|
||||||
word_1.append(word_2[i])
|
|
||||||
count += 1
|
|
||||||
#continue
|
|
||||||
|
|
||||||
# word_1 is longer -> remove current char of word_1
|
|
||||||
if i >= len(word_2):
|
|
||||||
word_1.pop(i)
|
|
||||||
count += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# same char -> skip word
|
|
||||||
if word_1[i] == word_2[i]:
|
|
||||||
i += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# not in the beginning or the end
|
|
||||||
if i > 0 and i < len(word_1):
|
|
||||||
"""
|
|
||||||
previous char is same and current char of word_1 is same as next char of word_2
|
|
||||||
-> fill current char of word_2 between last and next char of word_1
|
|
||||||
e.g. word_1[i-1] = "M" ; word_1[i] = "R"
|
|
||||||
word_2[i-1] = "M" ; word_1[i] = "A" ; word_2[i+1] = "R"
|
|
||||||
"""
|
|
||||||
if word_1[i-1] == word_2[i-1] and word_1[i] == word_2[i+1]:
|
|
||||||
word_1.insert(i, word_2[i])
|
|
||||||
count += 1
|
|
||||||
i += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
if word_1[i] != word_2[i]:
|
|
||||||
word_1.pop(i)
|
|
||||||
count += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
return "".join(word_1), count
|
|
||||||
|
|
||||||
|
|
||||||
STRING_SIMS = [edit_distance]
|
|
Loading…
Reference in New Issue
Block a user