Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Sibozhu committed Aug 30, 2020
1 parent af86648 commit 61f7bcd
Showing 1 changed file with 20 additions and 33 deletions.
53 changes: 20 additions & 33 deletions RektNet/keypoints_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"study_name=\"tutorial\"\n",
"\n",
"current_month = datetime.now().strftime('%B').lower()\n",
Expand All @@ -174,28 +173,25 @@
"sys.stdout = Logger(save_file_name + '.log')\n",
"sys.stderr = Logger(save_file_name + '.error')\n",
"\n",
"INPUT_SIZE = (80, 80)\n",
"KPT_KEYS = [\"top\", \"mid_L_top\", \"mid_R_top\", \"mid_L_bot\", \"mid_R_bot\", \"bot_L\", \"bot_R\"]\n",
"\n",
"# Training related config\n",
"INPUT_SIZE = (80, 80) # dataset size\n",
"KPT_KEYS = [\"top\", \"mid_L_top\", \"mid_R_top\", \"mid_L_bot\", \"mid_R_bot\", \"bot_L\", \"bot_R\"] # set up geometry loss keys\n",
"intervals = int(2) # for normal training, set it to 4\n",
"val_split = float(0.15)\n",
"\n",
"batch_size= int(32)\n",
"val_split = float(0.15) # training validation split ratio\n",
"batch_size= int(8)\n",
"num_epochs= int(4) # for normal training, set it to 1024\n",
"\n",
"# Load the train data.\n",
"train_csv = \"dataset/rektnet_label.csv\"\n",
"dataset_path = \"dataset/RektNet_Dataset/\"\n",
"vis_dataloader = False\n",
"vis_dataloader = False # visualize dataset\n",
"save_checkpoints = True\n",
"save_checkpoints=True\n",
"evaluate_mode=False\n",
"\n",
"# Training related hyperparameter\n",
"lr = 1e-1\n",
"lr_gamma = 0.999\n",
"geo_loss = True\n",
"geo_loss_gamma_vert = 0\n",
"geo_loss_gamma_horz = 0\n",
"loss_type = \"l1_softargmax\"\n",
"loss_type = \"l1_softargmax\" # loss function type: l2_softargmax|l2_heatmap|l1_softargmax\n",
"best_val_loss = float('inf')\n",
"best_epoch = 0\n",
"max_tolerance = 8\n",
Expand Down Expand Up @@ -247,6 +243,13 @@
"loss_func = CrossRatioLoss(loss_type, geo_loss, geo_loss_gamma_horz, geo_loss_gamma_vert)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -439,26 +442,10 @@
"image = cv2.imread(image_filepath)\n",
"h, w, _ = image.shape\n",
"\n",
"image = vis_tensor_and_save(image=image, h=h, w=w, tensor_output=output[1][0].cpu().data, image_name=img_name, output_uri=output_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"image = vis_tensor_and_save(image=image, h=h, w=w, tensor_output=output[1][0].cpu().data, image_name=img_name, output_uri=output_path)\n",
"\n",
"image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
"\n",
"pt.fig = pt.figure(figsize=(5, 5))\n",
"\n",
"pt.imshow(image)\n",
Expand Down

0 comments on commit 61f7bcd

Please sign in to comment.