|
120 | 120 | "outputs": [],
|
121 | 121 | "source": [
|
122 | 122 | "import os\n",
|
123 |
| - "if os.path.basename(os.getcwd())=='notebooks': os.chdir('..')\n", |
124 |
| - "assert os.path.basename(os.getcwd())=='element-deeplabcut', (\"Please move to the \"\n", |
125 |
| - " + \"element directory\")" |
| 123 | + "\n", |
| 124 | + "if os.path.basename(os.getcwd()) == \"notebooks\":\n", |
| 125 | + " os.chdir(\"..\")\n", |
| 126 | + "assert os.path.basename(os.getcwd()) == \"element-deeplabcut\", (\n", |
| 127 | + " \"Please move to the \" + \"element directory\"\n", |
| 128 | + ")" |
126 | 129 | ]
|
127 | 130 | },
|
128 | 131 | {
|
|
201 | 204 | }
|
202 | 205 | ],
|
203 | 206 | "source": [
|
204 |
| - "from tutorial_pipeline import lab, subject, session, train, model " |
| 207 | + "from tutorial_pipeline import lab, subject, session, train, model" |
205 | 208 | ]
|
206 | 209 | },
|
207 | 210 | {
|
|
990 | 993 | ],
|
991 | 994 | "source": [
|
992 | 995 | "(\n",
|
993 |
| - " dj.Diagram(subject) \n", |
994 |
| - " + dj.Diagram(lab) \n", |
995 |
| - " + dj.Diagram(session) \n", |
996 |
| - " + dj.Diagram(model) \n", |
| 996 | + " dj.Diagram(subject)\n", |
| 997 | + " + dj.Diagram(lab)\n", |
| 998 | + " + dj.Diagram(session)\n", |
| 999 | + " + dj.Diagram(model)\n", |
997 | 1000 | " + dj.Diagram(train)\n",
|
998 | 1001 | ")"
|
999 | 1002 | ]
|
|
1274 | 1277 | "metadata": {},
|
1275 | 1278 | "outputs": [],
|
1276 | 1279 | "source": [
|
1277 |
| - "config_file_rel = \"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/config.yaml\"" |
| 1280 | + "config_file_rel = (\n", |
| 1281 | + " \"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/config.yaml\"\n", |
| 1282 | + ")" |
1278 | 1283 | ]
|
1279 | 1284 | },
|
1280 | 1285 | {
|
|
1358 | 1363 | }
|
1359 | 1364 | ],
|
1360 | 1365 | "source": [
|
1361 |
| - "model.Model.insert_new_model(model_name='from_top_tracking_model_test',\n", |
1362 |
| - " dlc_config=config_file_rel,\n", |
1363 |
| - " shuffle=1,\n", |
1364 |
| - " trainingsetindex=0,\n", |
1365 |
| - " model_description='Model in example data: from_top_tracking model')" |
| 1366 | + "model.Model.insert_new_model(\n", |
| 1367 | + " model_name=\"from_top_tracking_model_test\",\n", |
| 1368 | + " dlc_config=config_file_rel,\n", |
| 1369 | + " shuffle=1,\n", |
| 1370 | + " trainingsetindex=0,\n", |
| 1371 | + " model_description=\"Model in example data: from_top_tracking model\",\n", |
| 1372 | + ")" |
1366 | 1373 | ]
|
1367 | 1374 | },
|
1368 | 1375 | {
|
|
1668 | 1675 | "metadata": {},
|
1669 | 1676 | "outputs": [],
|
1670 | 1677 | "source": [
|
1671 |
| - "#Definition of the dictionary named \"session_keys\"\n", |
| 1678 | + "# Definition of the dictionary named \"session_keys\"\n", |
1672 | 1679 | "session_keys = [\n",
|
1673 | 1680 | " dict(subject=\"subject6\", session_datetime=\"2021-06-02 14:04:22\"),\n",
|
1674 | 1681 | " dict(subject=\"subject6\", session_datetime=\"2021-06-03 14:43:10\"),\n",
|
1675 | 1682 | "]\n",
|
1676 | 1683 | "\n",
|
1677 |
| - "#Insert this dictionary in the Session table\n", |
1678 |
| - "session.Session.insert(session_keys, skip_duplicates=True)\n" |
| 1684 | + "# Insert this dictionary in the Session table\n", |
| 1685 | + "session.Session.insert(session_keys, skip_duplicates=True)" |
1679 | 1686 | ]
|
1680 | 1687 | },
|
1681 | 1688 | {
|
|
1791 | 1798 | "metadata": {},
|
1792 | 1799 | "outputs": [],
|
1793 | 1800 | "source": [
|
1794 |
| - "recording_key = {'subject': 'subject6',\n", |
1795 |
| - " 'session_datetime': '2021-06-02 14:04:22',\n", |
1796 |
| - " 'recording_id': '1'}\n", |
1797 |
| - "model.VideoRecording.insert1({**recording_key, 'device': 'Camera1'}, skip_duplicates=True)" |
| 1801 | + "recording_key = {\n", |
| 1802 | + " \"subject\": \"subject6\",\n", |
| 1803 | + " \"session_datetime\": \"2021-06-02 14:04:22\",\n", |
| 1804 | + " \"recording_id\": \"1\",\n", |
| 1805 | + "}\n", |
| 1806 | + "model.VideoRecording.insert1(\n", |
| 1807 | + " {**recording_key, \"device\": \"Camera1\"}, skip_duplicates=True\n", |
| 1808 | + ")" |
1798 | 1809 | ]
|
1799 | 1810 | },
|
1800 | 1811 | {
|
|
1810 | 1821 | "metadata": {},
|
1811 | 1822 | "outputs": [],
|
1812 | 1823 | "source": [
|
1813 |
| - "video_files = [\"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/videos/train1.mp4\"]\n", |
| 1824 | + "video_files = [\n", |
| 1825 | + " \"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/videos/train1.mp4\"\n", |
| 1826 | + "]\n", |
1814 | 1827 | "\n",
|
1815 |
| - "model.VideoRecording.File.insert({\n", |
1816 |
| - " **recording_key, \n", |
1817 |
| - " 'file_id': v_idx, \n", |
1818 |
| - " 'file_path': Path(f)} for v_idx, f in enumerate(video_files))" |
| 1828 | + "model.VideoRecording.File.insert(\n", |
| 1829 | + " {**recording_key, \"file_id\": v_idx, \"file_path\": Path(f)}\n", |
| 1830 | + " for v_idx, f in enumerate(video_files)\n", |
| 1831 | + ")" |
1819 | 1832 | ]
|
1820 | 1833 | },
|
1821 | 1834 | {
|
|
2054 | 2067 | "metadata": {},
|
2055 | 2068 | "outputs": [],
|
2056 | 2069 | "source": [
|
2057 |
| - "task_key = {**recording_key, 'model_name': 'from_top_tracking_model_test'}" |
| 2070 | + "task_key = {**recording_key, \"model_name\": \"from_top_tracking_model_test\"}" |
2058 | 2071 | ]
|
2059 | 2072 | },
|
2060 | 2073 | {
|
|
2071 | 2084 | "outputs": [],
|
2072 | 2085 | "source": [
|
2073 | 2086 | "model.PoseEstimationTask.insert1(\n",
|
2074 |
| - " {**task_key,\n", |
2075 |
| - " 'task_mode': 'load',\n", |
2076 |
| - " 'pose_estimation_output_dir': './example_data/outbox/from_top_tracking-DataJoint-2023-10-11/videos/device_1_recording_1_model_from_top_tracking_100000_maxiters'\n", |
2077 |
| - " })" |
| 2087 | + " {\n", |
| 2088 | + " **task_key,\n", |
| 2089 | + " \"task_mode\": \"load\",\n", |
| 2090 | + " \"pose_estimation_output_dir\": \"./example_data/outbox/from_top_tracking-DataJoint-2023-10-11/videos/device_1_recording_1_model_from_top_tracking_100000_maxiters\",\n", |
| 2091 | + " }\n", |
| 2092 | + ")" |
2078 | 2093 | ]
|
2079 | 2094 | },
|
2080 | 2095 | {
|
|
2471 | 2486 | "metadata": {},
|
2472 | 2487 | "outputs": [],
|
2473 | 2488 | "source": [
|
2474 |
| - "df = (model.PoseEstimation.BodyPartPosition & task_key).fetch(format='frame').reset_index()" |
| 2489 | + "df = (\n", |
| 2490 | + " (model.PoseEstimation.BodyPartPosition & task_key)\n", |
| 2491 | + " .fetch(format=\"frame\")\n", |
| 2492 | + " .reset_index()\n", |
| 2493 | + ")" |
2475 | 2494 | ]
|
2476 | 2495 | },
|
2477 | 2496 | {
|
|
2836 | 2855 | }
|
2837 | 2856 | ],
|
2838 | 2857 | "source": [
|
2839 |
| - "df = df.explode(['frame_index', 'x_pos', 'y_pos', 'likelihood']).reset_index()\n", |
| 2858 | + "df = df.explode([\"frame_index\", \"x_pos\", \"y_pos\", \"likelihood\"]).reset_index()\n", |
2840 | 2859 | "df"
|
2841 | 2860 | ]
|
2842 | 2861 | },
|
|
2871 | 2890 | "source": [
|
2872 | 2891 | "import matplotlib.pyplot as plt\n",
|
2873 | 2892 | "\n",
|
2874 |
| - "head_data = df[df['body_part'] == 'head']\n", |
2875 |
| - "tail_data = df[df['body_part'] == 'tailbase']" |
| 2893 | + "head_data = df[df[\"body_part\"] == \"head\"]\n", |
| 2894 | + "tail_data = df[df[\"body_part\"] == \"tailbase\"]" |
2876 | 2895 | ]
|
2877 | 2896 | },
|
2878 | 2897 | {
|
|
2892 | 2911 | }
|
2893 | 2912 | ],
|
2894 | 2913 | "source": [
|
2895 |
| - "fig, axs = plt.subplots(2,1, figsize=(12, 4))\n", |
| 2914 | + "fig, axs = plt.subplots(2, 1, figsize=(12, 4))\n", |
2896 | 2915 | "\n",
|
2897 |
| - "axs[0].set_title('x position - Head pose estimation')\n", |
2898 |
| - "axs[0].plot(head_data['x_pos'], label='x_pos')\n", |
2899 |
| - "axs[0].set_xlabel('time (frames)')\n", |
2900 |
| - "axs[0].set_ylabel('pos (pixels)')\n", |
| 2916 | + "axs[0].set_title(\"x position - Head pose estimation\")\n", |
| 2917 | + "axs[0].plot(head_data[\"x_pos\"], label=\"x_pos\")\n", |
| 2918 | + "axs[0].set_xlabel(\"time (frames)\")\n", |
| 2919 | + "axs[0].set_ylabel(\"pos (pixels)\")\n", |
2901 | 2920 | "axs[0].legend()\n",
|
2902 | 2921 | "\n",
|
2903 |
| - "axs[1].set_title('y position - Head pose estimation')\n", |
2904 |
| - "axs[1].plot(head_data['y_pos'], label='y_pos')\n", |
2905 |
| - "axs[1].set_xlabel('time (frames)')\n", |
2906 |
| - "axs[1].set_ylabel('pos (pixels)')\n", |
| 2922 | + "axs[1].set_title(\"y position - Head pose estimation\")\n", |
| 2923 | + "axs[1].plot(head_data[\"y_pos\"], label=\"y_pos\")\n", |
| 2924 | + "axs[1].set_xlabel(\"time (frames)\")\n", |
| 2925 | + "axs[1].set_ylabel(\"pos (pixels)\")\n", |
2907 | 2926 | "axs[1].legend()\n",
|
2908 | 2927 | "\n",
|
2909 | 2928 | "plt.tight_layout()\n",
|
|
2927 | 2946 | }
|
2928 | 2947 | ],
|
2929 | 2948 | "source": [
|
2930 |
| - "fig, axs = plt.subplots(2,1, figsize=(12, 4))\n", |
2931 |
| - "axs[0].set_title('x position - Tailbase pose estimation')\n", |
2932 |
| - "axs[0].plot(head_data['x_pos'], label='x_pos',color='orange')\n", |
2933 |
| - "axs[0].set_xlabel('time (frames)')\n", |
2934 |
| - "axs[0].set_ylabel('pos (pixels)')\n", |
| 2949 | + "fig, axs = plt.subplots(2, 1, figsize=(12, 4))\n", |
| 2950 | + "axs[0].set_title(\"x position - Tailbase pose estimation\")\n", |
| 2951 | + "axs[0].plot(head_data[\"x_pos\"], label=\"x_pos\", color=\"orange\")\n", |
| 2952 | + "axs[0].set_xlabel(\"time (frames)\")\n", |
| 2953 | + "axs[0].set_ylabel(\"pos (pixels)\")\n", |
2935 | 2954 | "axs[0].legend()\n",
|
2936 | 2955 | "\n",
|
2937 |
| - "axs[1].set_title('y position - Tailbase pose estimation')\n", |
2938 |
| - "axs[1].plot(head_data['y_pos'], label='y_pos',color='orange')\n", |
2939 |
| - "axs[1].set_xlabel('time (frames)')\n", |
2940 |
| - "axs[1].set_ylabel('pos (pixels)')\n", |
| 2956 | + "axs[1].set_title(\"y position - Tailbase pose estimation\")\n", |
| 2957 | + "axs[1].plot(head_data[\"y_pos\"], label=\"y_pos\", color=\"orange\")\n", |
| 2958 | + "axs[1].set_xlabel(\"time (frames)\")\n", |
| 2959 | + "axs[1].set_ylabel(\"pos (pixels)\")\n", |
2941 | 2960 | "axs[1].legend()\n",
|
2942 | 2961 | "\n",
|
2943 | 2962 | "plt.tight_layout()\n",
|
|
2968 | 2987 | }
|
2969 | 2988 | ],
|
2970 | 2989 | "source": [
|
2971 |
| - "fig, axs = plt.subplots(2,1, figsize=(6,10))\n", |
| 2990 | + "fig, axs = plt.subplots(2, 1, figsize=(6, 10))\n", |
2972 | 2991 | "\n",
|
2973 |
| - "axs[0].set_title('Head pose estimation')\n", |
2974 |
| - "axs[0].plot(head_data['x_pos'], head_data['y_pos'],label='head',color='blue')\n", |
2975 |
| - "axs[0].set_xlabel('x position (pixels)')\n", |
2976 |
| - "axs[0].set_ylabel('y position (pixels)')\n", |
| 2992 | + "axs[0].set_title(\"Head pose estimation\")\n", |
| 2993 | + "axs[0].plot(head_data[\"x_pos\"], head_data[\"y_pos\"], label=\"head\", color=\"blue\")\n", |
| 2994 | + "axs[0].set_xlabel(\"x position (pixels)\")\n", |
| 2995 | + "axs[0].set_ylabel(\"y position (pixels)\")\n", |
2977 | 2996 | "axs[0].legend()\n",
|
2978 | 2997 | "\n",
|
2979 |
| - "axs[1].set_title('Tailbase pose estimation')\n", |
2980 |
| - "axs[1].plot(tail_data['x_pos'], tail_data['y_pos'], label='tailbase',color='orange')\n", |
2981 |
| - "axs[1].set_xlabel('x position (pixels)')\n", |
2982 |
| - "axs[1].set_ylabel('y position (pixels)')\n", |
| 2998 | + "axs[1].set_title(\"Tailbase pose estimation\")\n", |
| 2999 | + "axs[1].plot(tail_data[\"x_pos\"], tail_data[\"y_pos\"], label=\"tailbase\", color=\"orange\")\n", |
| 3000 | + "axs[1].set_xlabel(\"x position (pixels)\")\n", |
| 3001 | + "axs[1].set_ylabel(\"y position (pixels)\")\n", |
2983 | 3002 | "axs[1].legend()\n",
|
2984 | 3003 | "\n",
|
2985 | 3004 | "plt.tight_layout()\n",
|
|
0 commit comments