Skip to content

Commit

Permalink
fixed args
Browse files Browse the repository at this point in the history
  • Loading branch information
alexsoleg committed Oct 6, 2024
1 parent 3a3262b commit 919cb07
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 19 deletions.
12 changes: 6 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def montecarlo(model, loader):
pred, _ = model(batch_copy)
inference_output["pred"].append(pred.detach().to("cpu"))
inference_output["true"].append(pseudo_true.detach().to("cpu"))
inference_output["iou"].append(compute_3D_IoU(_pred, _true).detach().to("cpu"))
inference_output["mae"].append(compute_loss(_pred, _true)[0].detach().to("cpu"))
inference_output["iou"].append(compute_3D_IoU(pred, pseudo_true).detach().to("cpu"))
inference_output["mae"].append(compute_loss(pred, pseudo_true)[0].detach().to("cpu"))
pickle.dump(inference_output, open(cfg.inference_output.replace(".pkl", "_montecarlo_"+str(i)+".pkl"), "wb"))
iou_montecarlo+=inference_output["iou"]
mae_montecarlo+=inference_output["mae"]
Expand Down Expand Up @@ -85,7 +85,7 @@ def montecarlo(model, loader):
parser.add_argument("--wandb_entity", type=str, default="aiquaneuro", help="Name of the wandb entity")
parser.add_argument("--loss", type=str, default="MAE", help="Loss function")
parser.add_argument("--epochs", type=int, default=50, help="Number of epochs")
parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate")
parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
parser.add_argument("--warmup", type=float, default=0.01, help="Warmup")
parser.add_argument('--model', type=str, default="CartNet", help="Model Name")
parser.add_argument("--max_neighbours", type=int, default=25, help="Max neighbours (only for iComformer/eComformer)")
Expand All @@ -104,7 +104,7 @@ def montecarlo(model, loader):

set_cfg(cfg)

args, _ = parser.parse_known_args()
args = parser.parse_args()
cfg.seed = args.seed
cfg.name = args.name
cfg.run_dir = "results/"+cfg.name+"/"+str(cfg.seed)
Expand All @@ -118,7 +118,7 @@ def montecarlo(model, loader):
cfg.wandb_entity = args.wandb_entity
cfg.loss = args.loss
cfg.optim.max_epoch = args.epochs
cfg.learning_rate = args.learning_rate
cfg.lr = args.lr
cfg.warmup = args.warmup
cfg.model = args.model
cfg.max_neighbours = None if cfg.model== "CartNet" else args.max_neighbours
Expand Down Expand Up @@ -152,7 +152,7 @@ def montecarlo(model, loader):
cfg.params_count = params_count(model)
logging.info(f"Number of parameters: {cfg.params_count}")

optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

loggers = create_logger()

Expand Down
2 changes: 1 addition & 1 deletion train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def train(model, loaders, optimizer, loggers):
perf = [[] for _ in range(num_splits-1)]
ckpt_dir = osp.join(cfg.run_dir,"ckpt/")

scheduler = OneCycleLR(optimizer, max_lr=cfg.learning_rate, total_steps=cfg.optim.max_epoch *len(loaders[0])//cfg.batch_accumulation + cfg.optim.max_epoch , pct_start=cfg.warmup)
scheduler = OneCycleLR(optimizer, max_lr=cfg.lr, total_steps=cfg.optim.max_epoch * len(loaders[0]) // cfg.batch_accumulation + cfg.optim.max_epoch , pct_start=cfg.warmup)

for cur_epoch in range(cfg.optim.max_epoch):
start_time = time.perf_counter()
Expand Down
16 changes: 8 additions & 8 deletions train_scripts/train_cartnet_adp.sh
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@

CUDA_VISIBLE_DEVICES=0 python main.py --seed 0 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
--wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 50 \
--wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
--augment &
CUDA_VISIBLE_DEVICES=1 python main.py --seed 1 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
--wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 50 \
--wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
--augment &

CUDA_VISIBLE_DEVICES=2 python main.py --seed 2 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
--wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 50 \
--wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
--augment &

CUDA_VISIBLE_DEVICES=3 python main.py --seed 3 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
--wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 50 \
--wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
--augment &

CUDA_VISIBLE_DEVICES=4 python main.py --seed 4 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
--wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 50 \
--wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
--augment &

CUDA_VISIBLE_DEVICES=5 python main.py --seed 5 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
--wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 50 \
--wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
--augment &

CUDA_VISIBLE_DEVICES=6 python main.py --seed 6 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
--wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 50 \
--wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
--augment &

CUDA_VISIBLE_DEVICES=7 python main.py --seed 7 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
--wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 50 \
--wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
--augment &

wait
Expand Down
8 changes: 4 additions & 4 deletions train_scripts/train_icomformer_adp.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
CUDA_VISIBLE_DEVICES=0 python main.py --seed 0 --name "icomformer" --model "icomformer" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
--wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 50 &
--wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 &
CUDA_VISIBLE_DEVICES=4 python main.py --seed 1 --name "icomformer" --model "icomformer" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
--wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 50 &
--wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 &

CUDA_VISIBLE_DEVICES=2 python main.py --seed 2 --name "icomformer" --model "icomformer" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
--wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 50 &
--wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 &

CUDA_VISIBLE_DEVICES=3 python main.py --seed 3 --name "icomformer" --model "icomformer" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
--wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 50 &
--wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 &

0 comments on commit 919cb07

Please sign in to comment.