|
531 | 531 | "metadata": {},
|
532 | 532 | "outputs": [],
|
533 | 533 | "source": [
|
| 534 | + "import os\n", |
534 | 535 | "from d2go.runner import GeneralizedRCNNRunner\n",
|
535 | 536 | "\n",
|
536 | 537 | "\n",
|
|
543 | 544 | " cfg.DATASETS.TEST = (\"balloon_val\",)\n",
|
544 | 545 | " cfg.DATALOADER.NUM_WORKERS = 2\n",
|
545 | 546 | " cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(\"faster_rcnn_fbnetv3a_C4.yaml\") # Let training initialize from model zoo\n",
|
| 547 | + " cfg.MODEL.DEVICE = \"cpu\" if ('CI' in os.environ) else \"cuda\"\n", |
546 | 548 | " cfg.SOLVER.IMS_PER_BATCH = 2\n",
|
547 | 549 | " cfg.SOLVER.BASE_LR = 0.00025 # pick a good LR\n",
|
548 |
| - " cfg.SOLVER.MAX_ITER = 600 # 600 iterations seems good enough for this toy dataset; you will need to train longer for a practical dataset\n", |
| 550 | + " cfg.SOLVER.MAX_ITER = 5 if ('CI' in os.environ) else 600 # 600 iterations seems good enough for this toy dataset; you will need to train longer for a practical dataset\n", |
549 | 551 | " cfg.SOLVER.STEPS = [] # do not decay learning rate\n",
|
550 | 552 | " cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128 # faster, and good enough for this toy dataset (default: 512)\n",
|
551 | 553 | " cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1 # only has one class (ballon). (see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets)\n",
|
|
609 | 611 | "from detectron2.data import build_detection_test_loader\n",
|
610 | 612 | "from d2go.export.exporter import convert_and_export_predictor\n",
|
611 | 613 | "from d2go.utils.testing.data_loader_helper import create_detection_data_loader_on_toy_dataset\n",
|
612 |
| - "from d2go.export.d2_meta_arch import patch_d2_meta_arch\n", |
613 | 614 | "\n",
|
614 | 615 | "import logging\n",
|
615 | 616 | "\n",
|
616 | 617 | "# disable all the warnings\n",
|
617 | 618 | "previous_level = logging.root.manager.disable\n",
|
618 | 619 | "logging.disable(logging.INFO)\n",
|
619 | 620 | "\n",
|
620 |
| - "patch_d2_meta_arch()\n", |
621 |
| - "\n", |
622 | 621 | "cfg_name = 'faster_rcnn_fbnetv3a_dsmask_C4.yaml'\n",
|
623 | 622 | "pytorch_model = model_zoo.get(cfg_name, trained=True, device='cpu')\n",
|
624 | 623 | "pytorch_model.eval()\n",
|
|
0 commit comments