Initially added similarity matrices
This commit is contained in:
parent
0ecf066315
commit
aa9c025369
4
data/street_slope_sim.csv
Normal file
4
data/street_slope_sim.csv
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
type_street_slope, flat, ascending, decending
|
||||||
|
flat, 1.0, 0.5, 0.6
|
||||||
|
ascending, 0.5, 1.0, 0.3
|
||||||
|
decending, 0.6, 0.3, 1.0
|
|
3
data/street_type_sim.csv
Normal file
3
data/street_type_sim.csv
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
type_street, country_road (separated), autobahn
|
||||||
|
country_road (separated), 1.0, 0.5
|
||||||
|
autobahn, 0.5, 1.0
|
|
5
data/time_type_sim.csv
Normal file
5
data/time_type_sim.csv
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
type_time, night, dusk, day, dawn
|
||||||
|
night, 1.0, 0.5, 0.6, 0.8
|
||||||
|
dusk, 0.5, 1.0, 0.3, 0.5
|
||||||
|
day, 0.6, 0.3, 1.0, 0.6
|
||||||
|
dawn, 0.8, 0.5, 0.6, 1.0
|
|
5
data/vehicle_type_sim.csv
Normal file
5
data/vehicle_type_sim.csv
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
type_vehicle, car, motorcycle, sportscar, truck
|
||||||
|
car, 1.0, 0.5, 0.6, 0.8
|
||||||
|
motorcycle, 0.5, 1.0, 0.3, 0.5
|
||||||
|
sportscar, 0.6, 0.3, 1.0, 0.6
|
||||||
|
truck, 0.8, 0.5, 0.6, 1.0
|
|
5
data/weather_type_sim.csv
Normal file
5
data/weather_type_sim.csv
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
type_weather, dry, rain, fog, snow_ice
|
||||||
|
dry, 1.0, 0.5, 0.6, 0.8
|
||||||
|
rain, 0.5, 1.0, 0.3, 0.5
|
||||||
|
fog, 0.6, 0.3, 1.0, 0.6
|
||||||
|
snow_ice, 0.8, 0.5, 0.6, 1.0
|
|
131
explore.ipynb
131
explore.ipynb
@ -2,7 +2,7 @@
|
|||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 1,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -13,9 +13,50 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def create_similarity_matrix(filename, key):\n",
|
||||||
|
" with open(filename) as file:\n",
|
||||||
|
" similarity_matrix = {}\n",
|
||||||
|
"\n",
|
||||||
|
" for line in csv.DictReader(file, skipinitialspace=True):\n",
|
||||||
|
" for k, v in line.items():\n",
|
||||||
|
" if k == key:\n",
|
||||||
|
" key_v = v\n",
|
||||||
|
" similarity_matrix[key_v] = {}\n",
|
||||||
|
" else:\n",
|
||||||
|
" similarity_matrix[key_v][k] = float(v)\n",
|
||||||
|
"\n",
|
||||||
|
" return similarity_matrix"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"CaseBase(cases=499, fields=[('v', 'v_left', 'v_front', 'd_left', 'd_front', 'type_left', 'type_front', 'radius_curve(m)', 'slope_street', 'street_type', 'time', 'weather', 'type_vehicle', 'speed_limit(km/h)'), 'action'])\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[Case(problem={'v_left': 36.5, 'v_front': 23.0, 'd_left': -56.0, 'd_front': 45.0, 'type_left': 'sportscar', 'type_front': 'truck', 'radius_curve(m)': 2237.0, 'slope_street': 'flat', 'street_type': 'country_road (separated)', 'time': 'night', 'weather': 'dry', 'type_vehicle': 'car', 'speed_limit(km/h)': 100.0}, solution={'action': 'continue'}),\n",
|
||||||
|
" Case(problem={'v_left': 32.0, 'v_front': 28.0, 'd_left': -114.0, 'd_front': 44.0, 'type_left': 'motorcycle', 'type_front': 'car', 'radius_curve(m)': 3891.0, 'slope_street': 'ascending', 'street_type': 'autobahn', 'time': 'night', 'weather': 'dry', 'type_vehicle': 'truck', 'speed_limit(km/h)': 250.0}, solution={'action': 'continue'}),\n",
|
||||||
|
" Case(problem={'v_left': 43.0, 'v_front': 31.5, 'd_left': -98.0, 'd_front': 60.0, 'type_left': 'truck', 'type_front': 'car', 'radius_curve(m)': 1720.0, 'slope_street': 'flat', 'street_type': 'autobahn', 'time': 'dusk', 'weather': 'rain', 'type_vehicle': 'car', 'speed_limit(km/h)': 130.0}, solution={'action': 'continue'})]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"case_base = CaseBase.from_csv(\n",
|
"case_base = CaseBase.from_csv(\n",
|
||||||
" \"data/SIM_001.csv\",\n",
|
" \"data/SIM_001.csv\",\n",
|
||||||
@ -31,16 +72,78 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": [
|
||||||
|
"case_base.add_symbolic_sim(\n",
|
||||||
|
" field = \"type_left\",\n",
|
||||||
|
" similarity_matrix = create_similarity_matrix(\"data/vehicle_type_sim.csv\", \"type_vehicle\")\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"case_base.add_symbolic_sim(\n",
|
||||||
|
" field = \"type_front\",\n",
|
||||||
|
" similarity_matrix = create_similarity_matrix(\"data/vehicle_type_sim.csv\", \"type_vehicle\")\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"case_base.add_symbolic_sim(\n",
|
||||||
|
" field = \"type_vehicle\",\n",
|
||||||
|
" similarity_matrix = create_similarity_matrix(\"data/vehicle_type_sim.csv\", \"type_vehicle\")\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"case_base.add_symbolic_sim(\n",
|
||||||
|
" field = \"slope_street\",\n",
|
||||||
|
" similarity_matrix = create_similarity_matrix(\"data/street_slope_sim.csv\", \"type_street_slope\")\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"case_base.add_symbolic_sim(\n",
|
||||||
|
" field = \"time\",\n",
|
||||||
|
" similarity_matrix = create_similarity_matrix(\"data/time_type_sim.csv\", \"type_time\")\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"case_base.add_symbolic_sim(\n",
|
||||||
|
" field = \"weather\",\n",
|
||||||
|
" similarity_matrix = create_similarity_matrix(\"data/weather_type_sim.csv\", \"type_weather\")\n",
|
||||||
|
")\n"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Your Query:\n",
|
||||||
|
" - v = 28\n",
|
||||||
|
" - v_left = 37\n",
|
||||||
|
" - v_front = 22.5\n",
|
||||||
|
" - d_left = -20\n",
|
||||||
|
" - d_front = 51\n",
|
||||||
|
" - radius_curve = 2000\n",
|
||||||
|
" - speed_limit = 200\n",
|
||||||
|
" - type_vehicle = motorcycle\n",
|
||||||
|
" - type_left = motorcycle\n",
|
||||||
|
" - type_front = motorcycle\n",
|
||||||
|
"\n",
|
||||||
|
"I recommend you this car:\n",
|
||||||
|
"Accelerated_lane_change\n",
|
||||||
|
"\n",
|
||||||
|
"Explanation:\n",
|
||||||
|
" - v_left = 37.0 (similarity: 1.00)\n",
|
||||||
|
" - v_front = 22.5 (similarity: 1.00)\n",
|
||||||
|
" - d_left = -17.0 (similarity: 0.25)\n",
|
||||||
|
" - d_front = 51.0 (similarity: 1.00)\n",
|
||||||
|
" - radius_curve(m) = 3020.0 (similarity: 0.00)\n",
|
||||||
|
" - speed_limit(km/h) = 120.0 (similarity: 0.01)\n",
|
||||||
|
" - type_vehicle = car (similarity: 0.50)\n",
|
||||||
|
" - type_left = truck (similarity: 0.50)\n",
|
||||||
|
" - type_front = truck (similarity: 0.50)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"query = Query.from_problems(\n",
|
"query = Query.from_problems(\n",
|
||||||
" v = 28,\n",
|
" v = 28,\n",
|
||||||
@ -48,15 +151,25 @@
|
|||||||
" v_front = 22.5,\n",
|
" v_front = 22.5,\n",
|
||||||
" d_left = -20,\n",
|
" d_left = -20,\n",
|
||||||
" d_front = 51,\n",
|
" d_front = 51,\n",
|
||||||
|
" radius_curve = 2000,\n",
|
||||||
|
" speed_limit = 200,\n",
|
||||||
|
" type_vehicle = \"motorcycle\",\n",
|
||||||
|
" type_left = \"motorcycle\",\n",
|
||||||
|
" type_front = \"motorcycle\",\n",
|
||||||
")\n",
|
")\n",
|
||||||
"# sim_funcs: manhattan_sim, euclid_sim\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
|
"# sim_funcs: manhattan_sim, euclid_sim\n",
|
||||||
"retrieved = case_base.retrieve(\n",
|
"retrieved = case_base.retrieve(\n",
|
||||||
" query,\n",
|
" query,\n",
|
||||||
" v_left = euclid_sim,\n",
|
" v_left = euclid_sim,\n",
|
||||||
" v_front = euclid_sim,\n",
|
" v_front = euclid_sim,\n",
|
||||||
" d_left = euclid_sim,\n",
|
" d_left = euclid_sim,\n",
|
||||||
" d_front = euclid_sim,\n",
|
" d_front = euclid_sim,\n",
|
||||||
|
" radius_curve = manhattan_sim,\n",
|
||||||
|
" speed_limit = manhattan_sim,\n",
|
||||||
|
" type_vehicle = symbolic_sim,\n",
|
||||||
|
" type_left = symbolic_sim,\n",
|
||||||
|
" type_front = symbolic_sim,\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(\"Your Query:\")\n",
|
"print(\"Your Query:\")\n",
|
||||||
@ -74,7 +187,7 @@
|
|||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3",
|
"display_name": "venv",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
@ -93,7 +206,7 @@
|
|||||||
"orig_nbformat": 4,
|
"orig_nbformat": 4,
|
||||||
"vscode": {
|
"vscode": {
|
||||||
"interpreter": {
|
"interpreter": {
|
||||||
"hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a"
|
"hash": "4c522f398908c053844cc48bcd755f88db468f52081f171c67c4d9e41e8d16a6"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
34
model.py
34
model.py
@ -200,22 +200,31 @@ class CaseBase:
|
|||||||
_sim_per_field = dict()
|
_sim_per_field = dict()
|
||||||
for field, sim_func in fields_and_sim_funcs.items():
|
for field, sim_func in fields_and_sim_funcs.items():
|
||||||
|
|
||||||
|
# Some columns contain special chars.
|
||||||
|
field_name = ""
|
||||||
|
if field == "radius_curve":
|
||||||
|
field_name = "radius_curve(m)"
|
||||||
|
elif field == "speed_limit":
|
||||||
|
field_name = "speed_limit(km/h)"
|
||||||
|
else:
|
||||||
|
field_name = field
|
||||||
|
|
||||||
if sim_func in sim.SYMBOLIC_SIMS:
|
if sim_func in sim.SYMBOLIC_SIMS:
|
||||||
field_sim = (
|
field_sim = (
|
||||||
field,
|
field_name,
|
||||||
sim_func(
|
sim_func(
|
||||||
query.problem[field],
|
query.problem[field],
|
||||||
case.problem[field],
|
case.problem[field_name],
|
||||||
self.get_symbolic_sim(field)
|
self.get_symbolic_sim(field)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
elif sim_func in sim.METRIC_SIMS:
|
elif sim_func in sim.METRIC_SIMS:
|
||||||
field_sim = (
|
field_sim = (
|
||||||
field,
|
field_name,
|
||||||
sim_func(
|
sim_func(
|
||||||
query.problem[field],
|
query.problem[field],
|
||||||
case.problem[field]
|
case.problem[field_name]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -256,24 +265,15 @@ class CaseBase:
|
|||||||
|
|
||||||
|
|
||||||
def add_symbolic_sim(self, field: str, similarity_matrix: dict):
|
def add_symbolic_sim(self, field: str, similarity_matrix: dict):
|
||||||
"""Add hardcoded similarities for symbolic values of `field`
|
|
||||||
|
|
||||||
structure of similarity_matrix:
|
|
||||||
{
|
|
||||||
"Audi": {"Audi": 1.0, "Citroen": 0.4, "Porsche": 0.9},
|
|
||||||
"Citroen": {"Audi": 0.4, "Citroen: 1.0, "Porsche": 0.2},
|
|
||||||
"Porsche": {"Audi": 0.7, "Citroen": 0.1, "Porsche": 1.0}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
if field not in list(self.fields["problem"]) + list(self.fields["solution"]):
|
if field not in list(self.fields["problem"]) + list(self.fields["solution"]):
|
||||||
raise ValueError(f"unknown field {field}")
|
raise ValueError(f"unknown field {field}")
|
||||||
|
|
||||||
self.__field_infos = {
|
if self.__field_infos is None:
|
||||||
field: {
|
self.__field_infos = {}
|
||||||
|
|
||||||
|
self.__field_infos[field] = {
|
||||||
"symbolic_sims": similarity_matrix
|
"symbolic_sims": similarity_matrix
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
def get_symbolic_sim(self, field: str) -> dict[str]:
|
def get_symbolic_sim(self, field: str) -> dict[str]:
|
||||||
return self.__field_infos[field]["symbolic_sims"]
|
return self.__field_infos[field]["symbolic_sims"]
|
Loading…
Reference in New Issue
Block a user