Initially added similarity matrices
This commit is contained in:
parent
0ecf066315
commit
aa9c025369
4
data/street_slope_sim.csv
Normal file
4
data/street_slope_sim.csv
Normal file
@ -0,0 +1,4 @@
|
||||
type_street_slope, flat, ascending, decending
|
||||
flat, 1.0, 0.5, 0.6
|
||||
ascending, 0.5, 1.0, 0.3
|
||||
decending, 0.6, 0.3, 1.0
|
|
3
data/street_type_sim.csv
Normal file
3
data/street_type_sim.csv
Normal file
@ -0,0 +1,3 @@
|
||||
type_street, country_road (separated), autobahn
|
||||
country_road (separated), 1.0, 0.5
|
||||
autobahn, 0.5, 1.0
|
|
5
data/time_type_sim.csv
Normal file
5
data/time_type_sim.csv
Normal file
@ -0,0 +1,5 @@
|
||||
type_time, night, dusk, day, dawn
|
||||
night, 1.0, 0.5, 0.6, 0.8
|
||||
dusk, 0.5, 1.0, 0.3, 0.5
|
||||
day, 0.6, 0.3, 1.0, 0.6
|
||||
dawn, 0.8, 0.5, 0.6, 1.0
|
|
5
data/vehicle_type_sim.csv
Normal file
5
data/vehicle_type_sim.csv
Normal file
@ -0,0 +1,5 @@
|
||||
type_vehicle, car, motorcycle, sportscar, truck
|
||||
car, 1.0, 0.5, 0.6, 0.8
|
||||
motorcycle, 0.5, 1.0, 0.3, 0.5
|
||||
sportscar, 0.6, 0.3, 1.0, 0.6
|
||||
truck, 0.8, 0.5, 0.6, 1.0
|
|
5
data/weather_type_sim.csv
Normal file
5
data/weather_type_sim.csv
Normal file
@ -0,0 +1,5 @@
|
||||
type_weather, dry, rain, fog, snow_ice
|
||||
dry, 1.0, 0.5, 0.6, 0.8
|
||||
rain, 0.5, 1.0, 0.3, 0.5
|
||||
fog, 0.6, 0.3, 1.0, 0.6
|
||||
snow_ice, 0.8, 0.5, 0.6, 1.0
|
|
131
explore.ipynb
131
explore.ipynb
@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -13,9 +13,50 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def create_similarity_matrix(filename, key):\n",
|
||||
" with open(filename) as file:\n",
|
||||
" similarity_matrix = {}\n",
|
||||
"\n",
|
||||
" for line in csv.DictReader(file, skipinitialspace=True):\n",
|
||||
" for k, v in line.items():\n",
|
||||
" if k == key:\n",
|
||||
" key_v = v\n",
|
||||
" similarity_matrix[key_v] = {}\n",
|
||||
" else:\n",
|
||||
" similarity_matrix[key_v][k] = float(v)\n",
|
||||
"\n",
|
||||
" return similarity_matrix"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"CaseBase(cases=499, fields=[('v', 'v_left', 'v_front', 'd_left', 'd_front', 'type_left', 'type_front', 'radius_curve(m)', 'slope_street', 'street_type', 'time', 'weather', 'type_vehicle', 'speed_limit(km/h)'), 'action'])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Case(problem={'v_left': 36.5, 'v_front': 23.0, 'd_left': -56.0, 'd_front': 45.0, 'type_left': 'sportscar', 'type_front': 'truck', 'radius_curve(m)': 2237.0, 'slope_street': 'flat', 'street_type': 'country_road (separated)', 'time': 'night', 'weather': 'dry', 'type_vehicle': 'car', 'speed_limit(km/h)': 100.0}, solution={'action': 'continue'}),\n",
|
||||
" Case(problem={'v_left': 32.0, 'v_front': 28.0, 'd_left': -114.0, 'd_front': 44.0, 'type_left': 'motorcycle', 'type_front': 'car', 'radius_curve(m)': 3891.0, 'slope_street': 'ascending', 'street_type': 'autobahn', 'time': 'night', 'weather': 'dry', 'type_vehicle': 'truck', 'speed_limit(km/h)': 250.0}, solution={'action': 'continue'}),\n",
|
||||
" Case(problem={'v_left': 43.0, 'v_front': 31.5, 'd_left': -98.0, 'd_front': 60.0, 'type_left': 'truck', 'type_front': 'car', 'radius_curve(m)': 1720.0, 'slope_street': 'flat', 'street_type': 'autobahn', 'time': 'dusk', 'weather': 'rain', 'type_vehicle': 'car', 'speed_limit(km/h)': 130.0}, solution={'action': 'continue'})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"case_base = CaseBase.from_csv(\n",
|
||||
" \"data/SIM_001.csv\",\n",
|
||||
@ -31,16 +72,78 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"case_base.add_symbolic_sim(\n",
|
||||
" field = \"type_left\",\n",
|
||||
" similarity_matrix = create_similarity_matrix(\"data/vehicle_type_sim.csv\", \"type_vehicle\")\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"case_base.add_symbolic_sim(\n",
|
||||
" field = \"type_front\",\n",
|
||||
" similarity_matrix = create_similarity_matrix(\"data/vehicle_type_sim.csv\", \"type_vehicle\")\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"case_base.add_symbolic_sim(\n",
|
||||
" field = \"type_vehicle\",\n",
|
||||
" similarity_matrix = create_similarity_matrix(\"data/vehicle_type_sim.csv\", \"type_vehicle\")\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"case_base.add_symbolic_sim(\n",
|
||||
" field = \"slope_street\",\n",
|
||||
" similarity_matrix = create_similarity_matrix(\"data/street_slope_sim.csv\", \"type_street_slope\")\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"case_base.add_symbolic_sim(\n",
|
||||
" field = \"time\",\n",
|
||||
" similarity_matrix = create_similarity_matrix(\"data/time_type_sim.csv\", \"type_time\")\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"case_base.add_symbolic_sim(\n",
|
||||
" field = \"weather\",\n",
|
||||
" similarity_matrix = create_similarity_matrix(\"data/weather_type_sim.csv\", \"type_weather\")\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Your Query:\n",
|
||||
" - v = 28\n",
|
||||
" - v_left = 37\n",
|
||||
" - v_front = 22.5\n",
|
||||
" - d_left = -20\n",
|
||||
" - d_front = 51\n",
|
||||
" - radius_curve = 2000\n",
|
||||
" - speed_limit = 200\n",
|
||||
" - type_vehicle = motorcycle\n",
|
||||
" - type_left = motorcycle\n",
|
||||
" - type_front = motorcycle\n",
|
||||
"\n",
|
||||
"I recommend you this car:\n",
|
||||
"Accelerated_lane_change\n",
|
||||
"\n",
|
||||
"Explanation:\n",
|
||||
" - v_left = 37.0 (similarity: 1.00)\n",
|
||||
" - v_front = 22.5 (similarity: 1.00)\n",
|
||||
" - d_left = -17.0 (similarity: 0.25)\n",
|
||||
" - d_front = 51.0 (similarity: 1.00)\n",
|
||||
" - radius_curve(m) = 3020.0 (similarity: 0.00)\n",
|
||||
" - speed_limit(km/h) = 120.0 (similarity: 0.01)\n",
|
||||
" - type_vehicle = car (similarity: 0.50)\n",
|
||||
" - type_left = truck (similarity: 0.50)\n",
|
||||
" - type_front = truck (similarity: 0.50)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"query = Query.from_problems(\n",
|
||||
" v = 28,\n",
|
||||
@ -48,15 +151,25 @@
|
||||
" v_front = 22.5,\n",
|
||||
" d_left = -20,\n",
|
||||
" d_front = 51,\n",
|
||||
" radius_curve = 2000,\n",
|
||||
" speed_limit = 200,\n",
|
||||
" type_vehicle = \"motorcycle\",\n",
|
||||
" type_left = \"motorcycle\",\n",
|
||||
" type_front = \"motorcycle\",\n",
|
||||
")\n",
|
||||
"# sim_funcs: manhattan_sim, euclid_sim\n",
|
||||
"\n",
|
||||
"# sim_funcs: manhattan_sim, euclid_sim\n",
|
||||
"retrieved = case_base.retrieve(\n",
|
||||
" query,\n",
|
||||
" v_left = euclid_sim,\n",
|
||||
" v_front = euclid_sim,\n",
|
||||
" d_left = euclid_sim,\n",
|
||||
" d_front = euclid_sim,\n",
|
||||
" radius_curve = manhattan_sim,\n",
|
||||
" speed_limit = manhattan_sim,\n",
|
||||
" type_vehicle = symbolic_sim,\n",
|
||||
" type_left = symbolic_sim,\n",
|
||||
" type_front = symbolic_sim,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"Your Query:\")\n",
|
||||
@ -74,7 +187,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -93,7 +206,7 @@
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a"
|
||||
"hash": "4c522f398908c053844cc48bcd755f88db468f52081f171c67c4d9e41e8d16a6"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
36
model.py
36
model.py
@ -200,22 +200,31 @@ class CaseBase:
|
||||
_sim_per_field = dict()
|
||||
for field, sim_func in fields_and_sim_funcs.items():
|
||||
|
||||
# Some columns contain special chars.
|
||||
field_name = ""
|
||||
if field == "radius_curve":
|
||||
field_name = "radius_curve(m)"
|
||||
elif field == "speed_limit":
|
||||
field_name = "speed_limit(km/h)"
|
||||
else:
|
||||
field_name = field
|
||||
|
||||
if sim_func in sim.SYMBOLIC_SIMS:
|
||||
field_sim = (
|
||||
field,
|
||||
field_name,
|
||||
sim_func(
|
||||
query.problem[field],
|
||||
case.problem[field],
|
||||
case.problem[field_name],
|
||||
self.get_symbolic_sim(field)
|
||||
)
|
||||
)
|
||||
|
||||
elif sim_func in sim.METRIC_SIMS:
|
||||
field_sim = (
|
||||
field,
|
||||
field_name,
|
||||
sim_func(
|
||||
query.problem[field],
|
||||
case.problem[field]
|
||||
case.problem[field_name]
|
||||
)
|
||||
)
|
||||
|
||||
@ -256,24 +265,15 @@ class CaseBase:
|
||||
|
||||
|
||||
def add_symbolic_sim(self, field: str, similarity_matrix: dict):
|
||||
"""Add hardcoded similarities for symbolic values of `field`
|
||||
|
||||
structure of similarity_matrix:
|
||||
{
|
||||
"Audi": {"Audi": 1.0, "Citroen": 0.4, "Porsche": 0.9},
|
||||
"Citroen": {"Audi": 0.4, "Citroen: 1.0, "Porsche": 0.2},
|
||||
"Porsche": {"Audi": 0.7, "Citroen": 0.1, "Porsche": 1.0}
|
||||
}
|
||||
"""
|
||||
|
||||
if field not in list(self.fields["problem"]) + list(self.fields["solution"]):
|
||||
raise ValueError(f"unknown field {field}")
|
||||
|
||||
self.__field_infos = {
|
||||
field: {
|
||||
if self.__field_infos is None:
|
||||
self.__field_infos = {}
|
||||
|
||||
self.__field_infos[field] = {
|
||||
"symbolic_sims": similarity_matrix
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_symbolic_sim(self, field: str) -> dict[str]:
|
||||
return self.__field_infos[field]["symbolic_sims"]
|
Loading…
Reference in New Issue
Block a user