Initially added similarity matrices

This commit is contained in:
Administrator 2022-12-14 11:33:13 +01:00
parent 0ecf066315
commit aa9c025369
7 changed files with 162 additions and 27 deletions

View File

@ -0,0 +1,4 @@
type_street_slope, flat, ascending, decending
flat, 1.0, 0.5, 0.6
ascending, 0.5, 1.0, 0.3
decending, 0.6, 0.3, 1.0
1 type_street_slope flat ascending decending
2 flat 1.0 0.5 0.6
3 ascending 0.5 1.0 0.3
4 decending 0.6 0.3 1.0

3
data/street_type_sim.csv Normal file
View File

@ -0,0 +1,3 @@
type_street, country_road (separated), autobahn
country_road (separated), 1.0, 0.5
autobahn, 0.5, 1.0
1 type_street country_road (separated) autobahn
2 country_road (separated) 1.0 0.5
3 autobahn 0.5 1.0

5
data/time_type_sim.csv Normal file
View File

@ -0,0 +1,5 @@
type_time, night, dusk, day, dawn
night, 1.0, 0.5, 0.6, 0.8
dusk, 0.5, 1.0, 0.3, 0.5
day, 0.6, 0.3, 1.0, 0.6
dawn, 0.8, 0.5, 0.6, 1.0
1 type_time night dusk day dawn
2 night 1.0 0.5 0.6 0.8
3 dusk 0.5 1.0 0.3 0.5
4 day 0.6 0.3 1.0 0.6
5 dawn 0.8 0.5 0.6 1.0

View File

@ -0,0 +1,5 @@
type_vehicle, car, motorcycle, sportscar, truck
car, 1.0, 0.5, 0.6, 0.8
motorcycle, 0.5, 1.0, 0.3, 0.5
sportscar, 0.6, 0.3, 1.0, 0.6
truck, 0.8, 0.5, 0.6, 1.0
1 type_vehicle car motorcycle sportscar truck
2 car 1.0 0.5 0.6 0.8
3 motorcycle 0.5 1.0 0.3 0.5
4 sportscar 0.6 0.3 1.0 0.6
5 truck 0.8 0.5 0.6 1.0

View File

@ -0,0 +1,5 @@
type_weather, dry, rain, fog, snow_ice
dry, 1.0, 0.5, 0.6, 0.8
rain, 0.5, 1.0, 0.3, 0.5
fog, 0.6, 0.3, 1.0, 0.6
snow_ice, 0.8, 0.5, 0.6, 1.0
1 type_weather dry rain fog snow_ice
2 dry 1.0 0.5 0.6 0.8
3 rain 0.5 1.0 0.3 0.5
4 fog 0.6 0.3 1.0 0.6
5 snow_ice 0.8 0.5 0.6 1.0

View File

@ -2,7 +2,7 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -13,9 +13,50 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [
"def create_similarity_matrix(filename, key):\n",
" with open(filename) as file:\n",
" similarity_matrix = {}\n",
"\n",
" for line in csv.DictReader(file, skipinitialspace=True):\n",
" for k, v in line.items():\n",
" if k == key:\n",
" key_v = v\n",
" similarity_matrix[key_v] = {}\n",
" else:\n",
" similarity_matrix[key_v][k] = float(v)\n",
"\n",
" return similarity_matrix"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CaseBase(cases=499, fields=[('v', 'v_left', 'v_front', 'd_left', 'd_front', 'type_left', 'type_front', 'radius_curve(m)', 'slope_street', 'street_type', 'time', 'weather', 'type_vehicle', 'speed_limit(km/h)'), 'action'])\n"
]
},
{
"data": {
"text/plain": [
"[Case(problem={'v_left': 36.5, 'v_front': 23.0, 'd_left': -56.0, 'd_front': 45.0, 'type_left': 'sportscar', 'type_front': 'truck', 'radius_curve(m)': 2237.0, 'slope_street': 'flat', 'street_type': 'country_road (separated)', 'time': 'night', 'weather': 'dry', 'type_vehicle': 'car', 'speed_limit(km/h)': 100.0}, solution={'action': 'continue'}),\n",
" Case(problem={'v_left': 32.0, 'v_front': 28.0, 'd_left': -114.0, 'd_front': 44.0, 'type_left': 'motorcycle', 'type_front': 'car', 'radius_curve(m)': 3891.0, 'slope_street': 'ascending', 'street_type': 'autobahn', 'time': 'night', 'weather': 'dry', 'type_vehicle': 'truck', 'speed_limit(km/h)': 250.0}, solution={'action': 'continue'}),\n",
" Case(problem={'v_left': 43.0, 'v_front': 31.5, 'd_left': -98.0, 'd_front': 60.0, 'type_left': 'truck', 'type_front': 'car', 'radius_curve(m)': 1720.0, 'slope_street': 'flat', 'street_type': 'autobahn', 'time': 'dusk', 'weather': 'rain', 'type_vehicle': 'car', 'speed_limit(km/h)': 130.0}, solution={'action': 'continue'})]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"case_base = CaseBase.from_csv(\n", "case_base = CaseBase.from_csv(\n",
" \"data/SIM_001.csv\",\n", " \"data/SIM_001.csv\",\n",
@ -31,16 +72,78 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": [
"case_base.add_symbolic_sim(\n",
" field = \"type_left\",\n",
" similarity_matrix = create_similarity_matrix(\"data/vehicle_type_sim.csv\", \"type_vehicle\")\n",
")\n",
"\n",
"case_base.add_symbolic_sim(\n",
" field = \"type_front\",\n",
" similarity_matrix = create_similarity_matrix(\"data/vehicle_type_sim.csv\", \"type_vehicle\")\n",
")\n",
"\n",
"case_base.add_symbolic_sim(\n",
" field = \"type_vehicle\",\n",
" similarity_matrix = create_similarity_matrix(\"data/vehicle_type_sim.csv\", \"type_vehicle\")\n",
")\n",
"\n",
"case_base.add_symbolic_sim(\n",
" field = \"slope_street\",\n",
" similarity_matrix = create_similarity_matrix(\"data/street_slope_sim.csv\", \"type_street_slope\")\n",
")\n",
"\n",
"case_base.add_symbolic_sim(\n",
" field = \"time\",\n",
" similarity_matrix = create_similarity_matrix(\"data/time_type_sim.csv\", \"type_time\")\n",
")\n",
"\n",
"case_base.add_symbolic_sim(\n",
" field = \"weather\",\n",
" similarity_matrix = create_similarity_matrix(\"data/weather_type_sim.csv\", \"type_weather\")\n",
")\n"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Your Query:\n",
" - v = 28\n",
" - v_left = 37\n",
" - v_front = 22.5\n",
" - d_left = -20\n",
" - d_front = 51\n",
" - radius_curve = 2000\n",
" - speed_limit = 200\n",
" - type_vehicle = motorcycle\n",
" - type_left = motorcycle\n",
" - type_front = motorcycle\n",
"\n",
"I recommend you this car:\n",
"Accelerated_lane_change\n",
"\n",
"Explanation:\n",
" - v_left = 37.0 (similarity: 1.00)\n",
" - v_front = 22.5 (similarity: 1.00)\n",
" - d_left = -17.0 (similarity: 0.25)\n",
" - d_front = 51.0 (similarity: 1.00)\n",
" - radius_curve(m) = 3020.0 (similarity: 0.00)\n",
" - speed_limit(km/h) = 120.0 (similarity: 0.01)\n",
" - type_vehicle = car (similarity: 0.50)\n",
" - type_left = truck (similarity: 0.50)\n",
" - type_front = truck (similarity: 0.50)\n"
]
}
],
"source": [ "source": [
"query = Query.from_problems(\n", "query = Query.from_problems(\n",
" v = 28,\n", " v = 28,\n",
@ -48,15 +151,25 @@
" v_front = 22.5,\n", " v_front = 22.5,\n",
" d_left = -20,\n", " d_left = -20,\n",
" d_front = 51,\n", " d_front = 51,\n",
" radius_curve = 2000,\n",
" speed_limit = 200,\n",
" type_vehicle = \"motorcycle\",\n",
" type_left = \"motorcycle\",\n",
" type_front = \"motorcycle\",\n",
")\n", ")\n",
"# sim_funcs: manhattan_sim, euclid_sim\n",
"\n", "\n",
"# sim_funcs: manhattan_sim, euclid_sim\n",
"retrieved = case_base.retrieve(\n", "retrieved = case_base.retrieve(\n",
" query,\n", " query,\n",
" v_left = euclid_sim,\n", " v_left = euclid_sim,\n",
" v_front = euclid_sim,\n", " v_front = euclid_sim,\n",
" d_left = euclid_sim,\n", " d_left = euclid_sim,\n",
" d_front = euclid_sim,\n", " d_front = euclid_sim,\n",
" radius_curve = manhattan_sim,\n",
" speed_limit = manhattan_sim,\n",
" type_vehicle = symbolic_sim,\n",
" type_left = symbolic_sim,\n",
" type_front = symbolic_sim,\n",
")\n", ")\n",
"\n", "\n",
"print(\"Your Query:\")\n", "print(\"Your Query:\")\n",
@ -74,7 +187,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "venv",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
@ -93,7 +206,7 @@
"orig_nbformat": 4, "orig_nbformat": 4,
"vscode": { "vscode": {
"interpreter": { "interpreter": {
"hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a" "hash": "4c522f398908c053844cc48bcd755f88db468f52081f171c67c4d9e41e8d16a6"
} }
} }
}, },

View File

@ -200,22 +200,31 @@ class CaseBase:
_sim_per_field = dict() _sim_per_field = dict()
for field, sim_func in fields_and_sim_funcs.items(): for field, sim_func in fields_and_sim_funcs.items():
# Some columns contain special chars.
field_name = ""
if field == "radius_curve":
field_name = "radius_curve(m)"
elif field == "speed_limit":
field_name = "speed_limit(km/h)"
else:
field_name = field
if sim_func in sim.SYMBOLIC_SIMS: if sim_func in sim.SYMBOLIC_SIMS:
field_sim = ( field_sim = (
field, field_name,
sim_func( sim_func(
query.problem[field], query.problem[field],
case.problem[field], case.problem[field_name],
self.get_symbolic_sim(field) self.get_symbolic_sim(field)
) )
) )
elif sim_func in sim.METRIC_SIMS: elif sim_func in sim.METRIC_SIMS:
field_sim = ( field_sim = (
field, field_name,
sim_func( sim_func(
query.problem[field], query.problem[field],
case.problem[field] case.problem[field_name]
) )
) )
@ -256,24 +265,15 @@ class CaseBase:
def add_symbolic_sim(self, field: str, similarity_matrix: dict): def add_symbolic_sim(self, field: str, similarity_matrix: dict):
"""Add hardcoded similarities for symbolic values of `field`
structure of similarity_matrix:
{
"Audi": {"Audi": 1.0, "Citroen": 0.4, "Porsche": 0.9},
"Citroen": {"Audi": 0.4, "Citroen: 1.0, "Porsche": 0.2},
"Porsche": {"Audi": 0.7, "Citroen": 0.1, "Porsche": 1.0}
}
"""
if field not in list(self.fields["problem"]) + list(self.fields["solution"]): if field not in list(self.fields["problem"]) + list(self.fields["solution"]):
raise ValueError(f"unknown field {field}") raise ValueError(f"unknown field {field}")
self.__field_infos = { if self.__field_infos is None:
field: { self.__field_infos = {}
self.__field_infos[field] = {
"symbolic_sims": similarity_matrix "symbolic_sims": similarity_matrix
} }
}
def get_symbolic_sim(self, field: str) -> dict[str]: def get_symbolic_sim(self, field: str) -> dict[str]:
return self.__field_infos[field]["symbolic_sims"] return self.__field_infos[field]["symbolic_sims"]