@@ -241,13 +241,17 @@ class ModelTraining(dj.Computed):
241
241
# https://github.com/DeepLabCut/DeepLabCut/issues/70
242
242
243
243
def make (self , key ):
244
- from deeplabcut import train_network # isort:skip
244
+ from deeplabcut import train_network # isort:skip
245
+
245
246
try :
246
- from deeplabcut .utils .auxiliaryfunctions import get_model_folder # isort:skip
247
+ from deeplabcut .utils .auxiliaryfunctions import (
248
+ get_model_folder ,
249
+ edit_config ,
250
+ ) # isort:skip
247
251
except ImportError :
248
252
from deeplabcut .utils .auxiliaryfunctions import (
249
- GetModelFolder as get_model_folder
250
- ) # isort:skip
253
+ GetModelFolder as get_model_folder ,
254
+ ) # isort:skip
251
255
252
256
"""Launch training for each train.TrainingTask training_id via `.populate()`."""
253
257
project_path , model_prefix = (TrainingTask & key ).fetch1 (
@@ -275,11 +279,26 @@ def make(self, key):
275
279
# Write dlc config file to base project folder
276
280
dlc_cfg_filepath = dlc_reader .save_yaml (project_path , dlc_config )
277
281
282
+ # ---- Update the project path in the DLC pose configuration (yaml) files ----
283
+ model_folder = get_model_folder (
284
+ trainFraction = dlc_config ["train_fraction" ],
285
+ shuffle = dlc_config ["shuffle" ],
286
+ cfg = dlc_config ,
287
+ modelprefix = dlc_config ["modelprefix" ],
288
+ )
289
+ model_train_folder = project_path / model_folder / "train"
290
+
291
+ edit_config (
292
+ model_train_folder / "pose_cfg.yaml" ,
293
+ {"project_path" : project_path .as_posix ()},
294
+ )
295
+
278
296
# ---- Trigger DLC model training job ----
279
297
train_network_input_args = list (inspect .signature (train_network ).parameters )
280
298
train_network_kwargs = {
281
- k : int (v ) if k in ("shuffle" , "trainingsetindex" , "maxiters" ) else v
282
- for k , v in dlc_config .items () if k in train_network_input_args
299
+ k : int (v ) if k in ("shuffle" , "trainingsetindex" , "maxiters" ) else v
300
+ for k , v in dlc_config .items ()
301
+ if k in train_network_input_args
283
302
}
284
303
for k in ["shuffle" , "trainingsetindex" , "maxiters" ]:
285
304
train_network_kwargs [k ] = int (train_network_kwargs [k ])
@@ -289,18 +308,7 @@ def make(self, key):
289
308
except KeyboardInterrupt : # Instructions indicate to train until interrupt
290
309
print ("DLC training stopped via Keyboard Interrupt" )
291
310
292
- snapshots = list (
293
- (
294
- project_path
295
- / get_model_folder (
296
- trainFraction = dlc_config ["train_fraction" ],
297
- shuffle = dlc_config ["shuffle" ],
298
- cfg = dlc_config ,
299
- modelprefix = dlc_config ["modelprefix" ],
300
- )
301
- / "train"
302
- ).glob ("*index*" )
303
- )
311
+ snapshots = list (model_train_folder .glob ("*index*" ))
304
312
max_modified_time = 0
305
313
# DLC goes by snapshot magnitude when judging 'latest' for evaluation
306
314
# Here, we mean most recently generated
0 commit comments