Skip to content

Commit 217b858

Browse files
authored
Merge pull request #39 from ttngu207/main
chore: minor code cleanup
2 parents 5465016 + 475ca27 commit 217b858

File tree

1 file changed

+23
-41
lines changed

1 file changed

+23
-41
lines changed

element_facemap/facemap_inference.py

+23-41
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,10 @@ class BodyPartPosition(dj.Part):
324324
"""
325325

326326
def make(self, key):
327-
""".populate() method will launch training for each FacemapInferenceTask"""
327+
"""
328+
Calls facemap.pose.Pose to run pose estimation on the video files using the specified model.
329+
Video files and model are specified in the FacemapInferenceTask table.
330+
"""
328331
# ID model and directories
329332
task_mode, output_dir = (FacemapInferenceTask & key).fetch1(
330333
"task_mode", "facemap_inference_output_dir"
@@ -353,27 +356,9 @@ def make(self, key):
353356
full_metadata_path = output_dir / f"{vid_name}_FacemapPose_metadata.pkl"
354357

355358
# Load or Trigger Facemap Pose Estimation Inference
356-
if (
359+
if task_mode == "trigger" and not (
357360
facemap_result_path.exists() & full_metadata_path.exists()
358-
) or task_mode == "load": # Load results and do not rerun processing
359-
(
360-
body_part_position_entry,
361-
inference_duration,
362-
total_frame_count,
363-
creation_time,
364-
) = _load_facemap_results(key, facemap_result_path, full_metadata_path)
365-
self.insert1(
366-
{
367-
**key,
368-
"inference_completion_time": creation_time,
369-
"inference_run_duration": inference_duration,
370-
"total_frame_count": total_frame_count,
371-
}
372-
)
373-
self.BodyPartPosition.insert(body_part_position_entry)
374-
return
375-
376-
elif task_mode == "trigger":
361+
):
377362
from facemap.pose import pose as facemap_pose, model_loader
378363

379364
bbox = (FacemapInferenceTask & key).fetch1("bbox") or []
@@ -382,9 +367,10 @@ def make(self, key):
382367
facemap_model_name = (
383368
FacemapModel.File & f'model_id="{key["model_id"]}"'
384369
).fetch1("model_file")
385-
386370
facemap_model_path = Path.cwd() / facemap_model_name
371+
# copy this model file to the facemap model root directory (~/.facemap/models/)
387372
models_root_dir = model_loader.get_models_dir()
373+
shutil.copy(facemap_model_path, models_root_dir)
388374

389375
# Create Symbolic Links to raw video data files from outbox directory
390376
video_symlinks = []
@@ -395,9 +381,6 @@ def make(self, key):
395381
video_symlink.symlink_to(video_file)
396382
video_symlinks.append(video_symlink.as_posix())
397383

398-
# copy this model file to the facemap model root directory (~/.facemap/models/)
399-
shutil.copy(facemap_model_path, models_root_dir)
400-
401384
# Instantiate Pose object, with filenames specified as video files, and bounding specified in params
402385
# Assumes GUI to be none as we are running CLI implementation
403386
pose = facemap_pose.Pose(
@@ -408,21 +391,21 @@ def make(self, key):
408391
)
409392
pose.run()
410393

411-
(
412-
body_part_position_entry,
413-
inference_duration,
414-
total_frame_count,
415-
creation_time,
416-
) = _load_facemap_results(key, facemap_result_path, full_metadata_path)
417-
self.insert1(
418-
{
419-
**key,
420-
"inference_completion_time": creation_time,
421-
"inference_run_duration": inference_duration,
422-
"total_frame_count": total_frame_count,
423-
}
424-
)
425-
self.BodyPartPosition.insert(body_part_position_entry)
394+
(
395+
body_part_position_entry,
396+
inference_duration,
397+
total_frame_count,
398+
creation_time,
399+
) = _load_facemap_results(key, facemap_result_path, full_metadata_path)
400+
self.insert1(
401+
{
402+
**key,
403+
"inference_completion_time": creation_time,
404+
"inference_run_duration": inference_duration,
405+
"total_frame_count": total_frame_count,
406+
}
407+
)
408+
self.BodyPartPosition.insert(body_part_position_entry)
426409

427410
@classmethod
428411
def get_trajectory(cls, key: dict, body_parts: list = "all") -> pd.DataFrame:
@@ -468,7 +451,6 @@ def get_trajectory(cls, key: dict, body_parts: list = "all") -> pd.DataFrame:
468451

469452
def _load_facemap_results(key, facemap_result_path, full_metadata_path):
470453
"""Load facemap results from h5 and metadata files."""
471-
472454
from facemap import utils
473455

474456
with open(full_metadata_path, "rb") as f:

0 commit comments

Comments
 (0)