Skip to content

Commit 78ac980

Browse files
authored
Merge pull request #87 from MilagrosMarin/train_bug
Fix project path in the pose config file
2 parents 4046d89 + 81eb296 commit 78ac980

File tree

3 files changed

+32
-19
lines changed

3 files changed

+32
-19
lines changed

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
44
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
55

6+
## [0.2.7] - 2023-08-04
7+
8+
+ Fix - Update the project path in the pose config file to train the model
9+
610
## [0.2.6] - 2023-05-22
711

812
+ Add - DeepLabCut, NWB, and DANDI citations
@@ -68,6 +72,7 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
6872
graciously provided by the Mathis Lab.
6973
+ Add - Support for 2d single-animal models
7074

75+
[0.2.7]: https://github.com/datajoint/element-deeplabcut/releases/tag/0.2.7
7176
[0.2.6]: https://github.com/datajoint/element-deeplabcut/releases/tag/0.2.6
7277
[0.2.5]: https://github.com/datajoint/element-deeplabcut/releases/tag/0.2.5
7378
[0.2.4]: https://github.com/datajoint/element-deeplabcut/releases/tag/0.2.4

element_deeplabcut/train.py

+26-18
Original file line numberDiff line numberDiff line change
@@ -241,13 +241,17 @@ class ModelTraining(dj.Computed):
241241
# https://github.com/DeepLabCut/DeepLabCut/issues/70
242242

243243
def make(self, key):
244-
from deeplabcut import train_network # isort:skip
244+
from deeplabcut import train_network # isort:skip
245+
245246
try:
246-
from deeplabcut.utils.auxiliaryfunctions import get_model_folder # isort:skip
247+
from deeplabcut.utils.auxiliaryfunctions import (
248+
get_model_folder,
249+
edit_config,
250+
) # isort:skip
247251
except ImportError:
248252
from deeplabcut.utils.auxiliaryfunctions import (
249-
GetModelFolder as get_model_folder
250-
) # isort:skip
253+
GetModelFolder as get_model_folder,
254+
) # isort:skip
251255

252256
"""Launch training for each train.TrainingTask training_id via `.populate()`."""
253257
project_path, model_prefix = (TrainingTask & key).fetch1(
@@ -275,11 +279,26 @@ def make(self, key):
275279
# Write dlc config file to base project folder
276280
dlc_cfg_filepath = dlc_reader.save_yaml(project_path, dlc_config)
277281

282+
# ---- Update the project path in the DLC pose configuration (yaml) files ----
283+
model_folder = get_model_folder(
284+
trainFraction=dlc_config["train_fraction"],
285+
shuffle=dlc_config["shuffle"],
286+
cfg=dlc_config,
287+
modelprefix=dlc_config["modelprefix"],
288+
)
289+
model_train_folder = project_path / model_folder / "train"
290+
291+
edit_config(
292+
model_train_folder / "pose_cfg.yaml",
293+
{"project_path": project_path.as_posix()},
294+
)
295+
278296
# ---- Trigger DLC model training job ----
279297
train_network_input_args = list(inspect.signature(train_network).parameters)
280298
train_network_kwargs = {
281-
k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v
282-
for k, v in dlc_config.items() if k in train_network_input_args
299+
k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v
300+
for k, v in dlc_config.items()
301+
if k in train_network_input_args
283302
}
284303
for k in ["shuffle", "trainingsetindex", "maxiters"]:
285304
train_network_kwargs[k] = int(train_network_kwargs[k])
@@ -289,18 +308,7 @@ def make(self, key):
289308
except KeyboardInterrupt: # Instructions indicate to train until interrupt
290309
print("DLC training stopped via Keyboard Interrupt")
291310

292-
snapshots = list(
293-
(
294-
project_path
295-
/ get_model_folder(
296-
trainFraction=dlc_config["train_fraction"],
297-
shuffle=dlc_config["shuffle"],
298-
cfg=dlc_config,
299-
modelprefix=dlc_config["modelprefix"],
300-
)
301-
/ "train"
302-
).glob("*index*")
303-
)
311+
snapshots = list(model_train_folder.glob("*index*"))
304312
max_modified_time = 0
305313
# DLC goes by snapshot magnitude when judging 'latest' for evaluation
306314
# Here, we mean most recently generated

element_deeplabcut/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""
22
Package metadata
33
"""
4-
__version__ = "0.2.6"
4+
__version__ = "0.2.7"

0 commit comments

Comments
 (0)