Skip to content

Commit b33ece9

Browse files
Merge pull request #3 from kushalbakshi/dev
Black formatting
2 parents 30a1b9c + 85e38f2 commit b33ece9

File tree

2 files changed

+82
-63
lines changed

2 files changed

+82
-63
lines changed

element_deeplabcut/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@
1818
"DLC_PROCESSED_DATA_DIR", dj.config["custom"].get("dlc_processed_data_dir", "")
1919
)
2020

21-
db_prefix = dj.config["custom"].get("database.prefix", "")
21+
db_prefix = dj.config["custom"].get("database.prefix", "")

notebooks/tutorial.ipynb

+81-62
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,12 @@
120120
"outputs": [],
121121
"source": [
122122
"import os\n",
123-
"if os.path.basename(os.getcwd())=='notebooks': os.chdir('..')\n",
124-
"assert os.path.basename(os.getcwd())=='element-deeplabcut', (\"Please move to the \"\n",
125-
" + \"element directory\")"
123+
"\n",
124+
"if os.path.basename(os.getcwd()) == \"notebooks\":\n",
125+
" os.chdir(\"..\")\n",
126+
"assert os.path.basename(os.getcwd()) == \"element-deeplabcut\", (\n",
127+
" \"Please move to the \" + \"element directory\"\n",
128+
")"
126129
]
127130
},
128131
{
@@ -201,7 +204,7 @@
201204
}
202205
],
203206
"source": [
204-
"from tutorial_pipeline import lab, subject, session, train, model "
207+
"from tutorial_pipeline import lab, subject, session, train, model"
205208
]
206209
},
207210
{
@@ -990,10 +993,10 @@
990993
],
991994
"source": [
992995
"(\n",
993-
" dj.Diagram(subject) \n",
994-
" + dj.Diagram(lab) \n",
995-
" + dj.Diagram(session) \n",
996-
" + dj.Diagram(model) \n",
996+
" dj.Diagram(subject)\n",
997+
" + dj.Diagram(lab)\n",
998+
" + dj.Diagram(session)\n",
999+
" + dj.Diagram(model)\n",
9971000
" + dj.Diagram(train)\n",
9981001
")"
9991002
]
@@ -1274,7 +1277,9 @@
12741277
"metadata": {},
12751278
"outputs": [],
12761279
"source": [
1277-
"config_file_rel = \"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/config.yaml\""
1280+
"config_file_rel = (\n",
1281+
" \"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/config.yaml\"\n",
1282+
")"
12781283
]
12791284
},
12801285
{
@@ -1358,11 +1363,13 @@
13581363
}
13591364
],
13601365
"source": [
1361-
"model.Model.insert_new_model(model_name='from_top_tracking_model_test',\n",
1362-
" dlc_config=config_file_rel,\n",
1363-
" shuffle=1,\n",
1364-
" trainingsetindex=0,\n",
1365-
" model_description='Model in example data: from_top_tracking model')"
1366+
"model.Model.insert_new_model(\n",
1367+
" model_name=\"from_top_tracking_model_test\",\n",
1368+
" dlc_config=config_file_rel,\n",
1369+
" shuffle=1,\n",
1370+
" trainingsetindex=0,\n",
1371+
" model_description=\"Model in example data: from_top_tracking model\",\n",
1372+
")"
13661373
]
13671374
},
13681375
{
@@ -1668,14 +1675,14 @@
16681675
"metadata": {},
16691676
"outputs": [],
16701677
"source": [
1671-
"#Definition of the dictionary named \"session_keys\"\n",
1678+
"# Definition of the dictionary named \"session_keys\"\n",
16721679
"session_keys = [\n",
16731680
" dict(subject=\"subject6\", session_datetime=\"2021-06-02 14:04:22\"),\n",
16741681
" dict(subject=\"subject6\", session_datetime=\"2021-06-03 14:43:10\"),\n",
16751682
"]\n",
16761683
"\n",
1677-
"#Insert this dictionary in the Session table\n",
1678-
"session.Session.insert(session_keys, skip_duplicates=True)\n"
1684+
"# Insert this dictionary in the Session table\n",
1685+
"session.Session.insert(session_keys, skip_duplicates=True)"
16791686
]
16801687
},
16811688
{
@@ -1791,10 +1798,14 @@
17911798
"metadata": {},
17921799
"outputs": [],
17931800
"source": [
1794-
"recording_key = {'subject': 'subject6',\n",
1795-
" 'session_datetime': '2021-06-02 14:04:22',\n",
1796-
" 'recording_id': '1'}\n",
1797-
"model.VideoRecording.insert1({**recording_key, 'device': 'Camera1'}, skip_duplicates=True)"
1801+
"recording_key = {\n",
1802+
" \"subject\": \"subject6\",\n",
1803+
" \"session_datetime\": \"2021-06-02 14:04:22\",\n",
1804+
" \"recording_id\": \"1\",\n",
1805+
"}\n",
1806+
"model.VideoRecording.insert1(\n",
1807+
" {**recording_key, \"device\": \"Camera1\"}, skip_duplicates=True\n",
1808+
")"
17981809
]
17991810
},
18001811
{
@@ -1810,12 +1821,14 @@
18101821
"metadata": {},
18111822
"outputs": [],
18121823
"source": [
1813-
"video_files = [\"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/videos/train1.mp4\"]\n",
1824+
"video_files = [\n",
1825+
" \"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/videos/train1.mp4\"\n",
1826+
"]\n",
18141827
"\n",
1815-
"model.VideoRecording.File.insert({\n",
1816-
" **recording_key, \n",
1817-
" 'file_id': v_idx, \n",
1818-
" 'file_path': Path(f)} for v_idx, f in enumerate(video_files))"
1828+
"model.VideoRecording.File.insert(\n",
1829+
" {**recording_key, \"file_id\": v_idx, \"file_path\": Path(f)}\n",
1830+
" for v_idx, f in enumerate(video_files)\n",
1831+
")"
18191832
]
18201833
},
18211834
{
@@ -2054,7 +2067,7 @@
20542067
"metadata": {},
20552068
"outputs": [],
20562069
"source": [
2057-
"task_key = {**recording_key, 'model_name': 'from_top_tracking_model_test'}"
2070+
"task_key = {**recording_key, \"model_name\": \"from_top_tracking_model_test\"}"
20582071
]
20592072
},
20602073
{
@@ -2071,10 +2084,12 @@
20712084
"outputs": [],
20722085
"source": [
20732086
"model.PoseEstimationTask.insert1(\n",
2074-
" {**task_key,\n",
2075-
" 'task_mode': 'load',\n",
2076-
" 'pose_estimation_output_dir': './example_data/outbox/from_top_tracking-DataJoint-2023-10-11/videos/device_1_recording_1_model_from_top_tracking_100000_maxiters'\n",
2077-
" })"
2087+
" {\n",
2088+
" **task_key,\n",
2089+
" \"task_mode\": \"load\",\n",
2090+
" \"pose_estimation_output_dir\": \"./example_data/outbox/from_top_tracking-DataJoint-2023-10-11/videos/device_1_recording_1_model_from_top_tracking_100000_maxiters\",\n",
2091+
" }\n",
2092+
")"
20782093
]
20792094
},
20802095
{
@@ -2471,7 +2486,11 @@
24712486
"metadata": {},
24722487
"outputs": [],
24732488
"source": [
2474-
"df = (model.PoseEstimation.BodyPartPosition & task_key).fetch(format='frame').reset_index()"
2489+
"df = (\n",
2490+
" (model.PoseEstimation.BodyPartPosition & task_key)\n",
2491+
" .fetch(format=\"frame\")\n",
2492+
" .reset_index()\n",
2493+
")"
24752494
]
24762495
},
24772496
{
@@ -2836,7 +2855,7 @@
28362855
}
28372856
],
28382857
"source": [
2839-
"df = df.explode(['frame_index', 'x_pos', 'y_pos', 'likelihood']).reset_index()\n",
2858+
"df = df.explode([\"frame_index\", \"x_pos\", \"y_pos\", \"likelihood\"]).reset_index()\n",
28402859
"df"
28412860
]
28422861
},
@@ -2871,8 +2890,8 @@
28712890
"source": [
28722891
"import matplotlib.pyplot as plt\n",
28732892
"\n",
2874-
"head_data = df[df['body_part'] == 'head']\n",
2875-
"tail_data = df[df['body_part'] == 'tailbase']"
2893+
"head_data = df[df[\"body_part\"] == \"head\"]\n",
2894+
"tail_data = df[df[\"body_part\"] == \"tailbase\"]"
28762895
]
28772896
},
28782897
{
@@ -2892,18 +2911,18 @@
28922911
}
28932912
],
28942913
"source": [
2895-
"fig, axs = plt.subplots(2,1, figsize=(12, 4))\n",
2914+
"fig, axs = plt.subplots(2, 1, figsize=(12, 4))\n",
28962915
"\n",
2897-
"axs[0].set_title('x position - Head pose estimation')\n",
2898-
"axs[0].plot(head_data['x_pos'], label='x_pos')\n",
2899-
"axs[0].set_xlabel('time (frames)')\n",
2900-
"axs[0].set_ylabel('pos (pixels)')\n",
2916+
"axs[0].set_title(\"x position - Head pose estimation\")\n",
2917+
"axs[0].plot(head_data[\"x_pos\"], label=\"x_pos\")\n",
2918+
"axs[0].set_xlabel(\"time (frames)\")\n",
2919+
"axs[0].set_ylabel(\"pos (pixels)\")\n",
29012920
"axs[0].legend()\n",
29022921
"\n",
2903-
"axs[1].set_title('y position - Head pose estimation')\n",
2904-
"axs[1].plot(head_data['y_pos'], label='y_pos')\n",
2905-
"axs[1].set_xlabel('time (frames)')\n",
2906-
"axs[1].set_ylabel('pos (pixels)')\n",
2922+
"axs[1].set_title(\"y position - Head pose estimation\")\n",
2923+
"axs[1].plot(head_data[\"y_pos\"], label=\"y_pos\")\n",
2924+
"axs[1].set_xlabel(\"time (frames)\")\n",
2925+
"axs[1].set_ylabel(\"pos (pixels)\")\n",
29072926
"axs[1].legend()\n",
29082927
"\n",
29092928
"plt.tight_layout()\n",
@@ -2927,17 +2946,17 @@
29272946
}
29282947
],
29292948
"source": [
2930-
"fig, axs = plt.subplots(2,1, figsize=(12, 4))\n",
2931-
"axs[0].set_title('x position - Tailbase pose estimation')\n",
2932-
"axs[0].plot(head_data['x_pos'], label='x_pos',color='orange')\n",
2933-
"axs[0].set_xlabel('time (frames)')\n",
2934-
"axs[0].set_ylabel('pos (pixels)')\n",
2949+
"fig, axs = plt.subplots(2, 1, figsize=(12, 4))\n",
2950+
"axs[0].set_title(\"x position - Tailbase pose estimation\")\n",
2951+
"axs[0].plot(head_data[\"x_pos\"], label=\"x_pos\", color=\"orange\")\n",
2952+
"axs[0].set_xlabel(\"time (frames)\")\n",
2953+
"axs[0].set_ylabel(\"pos (pixels)\")\n",
29352954
"axs[0].legend()\n",
29362955
"\n",
2937-
"axs[1].set_title('y position - Tailbase pose estimation')\n",
2938-
"axs[1].plot(head_data['y_pos'], label='y_pos',color='orange')\n",
2939-
"axs[1].set_xlabel('time (frames)')\n",
2940-
"axs[1].set_ylabel('pos (pixels)')\n",
2956+
"axs[1].set_title(\"y position - Tailbase pose estimation\")\n",
2957+
"axs[1].plot(head_data[\"y_pos\"], label=\"y_pos\", color=\"orange\")\n",
2958+
"axs[1].set_xlabel(\"time (frames)\")\n",
2959+
"axs[1].set_ylabel(\"pos (pixels)\")\n",
29412960
"axs[1].legend()\n",
29422961
"\n",
29432962
"plt.tight_layout()\n",
@@ -2968,18 +2987,18 @@
29682987
}
29692988
],
29702989
"source": [
2971-
"fig, axs = plt.subplots(2,1, figsize=(6,10))\n",
2990+
"fig, axs = plt.subplots(2, 1, figsize=(6, 10))\n",
29722991
"\n",
2973-
"axs[0].set_title('Head pose estimation')\n",
2974-
"axs[0].plot(head_data['x_pos'], head_data['y_pos'],label='head',color='blue')\n",
2975-
"axs[0].set_xlabel('x position (pixels)')\n",
2976-
"axs[0].set_ylabel('y position (pixels)')\n",
2992+
"axs[0].set_title(\"Head pose estimation\")\n",
2993+
"axs[0].plot(head_data[\"x_pos\"], head_data[\"y_pos\"], label=\"head\", color=\"blue\")\n",
2994+
"axs[0].set_xlabel(\"x position (pixels)\")\n",
2995+
"axs[0].set_ylabel(\"y position (pixels)\")\n",
29772996
"axs[0].legend()\n",
29782997
"\n",
2979-
"axs[1].set_title('Tailbase pose estimation')\n",
2980-
"axs[1].plot(tail_data['x_pos'], tail_data['y_pos'], label='tailbase',color='orange')\n",
2981-
"axs[1].set_xlabel('x position (pixels)')\n",
2982-
"axs[1].set_ylabel('y position (pixels)')\n",
2998+
"axs[1].set_title(\"Tailbase pose estimation\")\n",
2999+
"axs[1].plot(tail_data[\"x_pos\"], tail_data[\"y_pos\"], label=\"tailbase\", color=\"orange\")\n",
3000+
"axs[1].set_xlabel(\"x position (pixels)\")\n",
3001+
"axs[1].set_ylabel(\"y position (pixels)\")\n",
29833002
"axs[1].legend()\n",
29843003
"\n",
29853004
"plt.tight_layout()\n",

0 commit comments

Comments
 (0)