Skip to content

Commit 360ca5b

Browse files
committed
Ruff: format Jupyter notebooks too
1 parent ccccadd commit 360ca5b

5 files changed

+5511
-31
lines changed

examples/binary_segmentation_intro.ipynb

+4,211-1
Large diffs are not rendered by default.

examples/camvid_segmentation_multiclass.ipynb

+6-8
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@
492492
" augmentation=get_validation_augmentation(),\n",
493493
")\n",
494494
"\n",
495-
"#Change to > 0 if not on Windows machine\n",
495+
"# Change to > 0 if not on Windows machine\n",
496496
"train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)\n",
497497
"valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=0)\n",
498498
"test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)"
@@ -545,12 +545,10 @@
545545
"import pytorch_lightning as pl\n",
546546
"import segmentation_models_pytorch as smp\n",
547547
"import torch\n",
548-
"import torch.nn.functional as F\n",
549548
"from torch.optim import lr_scheduler\n",
550549
"\n",
551550
"\n",
552551
"class CamVidModel(pl.LightningModule):\n",
553-
"\n",
554552
" def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):\n",
555553
" super().__init__()\n",
556554
" self.model = smp.create_model(\n",
@@ -591,13 +589,14 @@
591589
" mask = mask.long()\n",
592590
"\n",
593591
" # Mask shape\n",
594-
" assert mask.ndim == 3 # [batch_size, H, W]\n",
592+
" assert mask.ndim == 3 # [batch_size, H, W]\n",
595593
"\n",
596594
" # Predict mask logits\n",
597595
" logits_mask = self.forward(image)\n",
598-
" \n",
599-
" assert logits_mask.shape[1] == self.number_of_classes # [batch_size, number_of_classes, H, W]\n",
600-
" \n",
596+
"\n",
597+
" assert (\n",
598+
" logits_mask.shape[1] == self.number_of_classes\n",
599+
" ) # [batch_size, number_of_classes, H, W]\n",
601600
"\n",
602601
" # Ensure the logits mask is contiguous\n",
603602
" logits_mask = logits_mask.contiguous()\n",
@@ -1678,7 +1677,6 @@
16781677
}
16791678
],
16801679
"source": [
1681-
"import matplotlib.pyplot as plt\n",
16821680
"import numpy as np\n",
16831681
"\n",
16841682
"# Fetch a batch from the test loader\n",

examples/cars segmentation (camvid).ipynb

+1,270-1
Large diffs are not rendered by default.

examples/save_load_model_and_share_with_hf_hub.ipynb

+20-21
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,13 @@
6464
"# save the model\n",
6565
"model.save_pretrained(\n",
6666
" \"saved-model-dir/unet-with-metadata/\",\n",
67-
"\n",
6867
" # additional information to be saved with the model\n",
6968
" # only \"dataset\" and \"metrics\" are supported\n",
7069
" dataset=\"PASCAL VOC\", # only string name is supported\n",
71-
" metrics={ # should be a dictionary with metric name as key and metric value as value\n",
70+
" metrics={ # should be a dictionary with metric name as key and metric value as value\n",
7271
" \"mIoU\": 0.95,\n",
73-
" \"accuracy\": 0.96\n",
74-
" }\n",
72+
" \"accuracy\": 0.96,\n",
73+
" },\n",
7574
")"
7675
]
7776
},
@@ -222,13 +221,10 @@
222221
"# save the model and share it on the HF Hub (https://huggingface.co/models)\n",
223222
"model.save_pretrained(\n",
224223
" \"qubvel-hf/unet-with-metadata/\",\n",
225-
" push_to_hub=True, # <---------- push the model to the hub\n",
226-
" private=False, # <---------- make the model private or or public\n",
224+
" push_to_hub=True, # <---------- push the model to the hub\n",
225+
" private=False, # <---------- make the model private or or public\n",
227226
" dataset=\"PASCAL VOC\",\n",
228-
" metrics={\n",
229-
" \"mIoU\": 0.95,\n",
230-
" \"accuracy\": 0.96\n",
231-
" }\n",
227+
" metrics={\"mIoU\": 0.95, \"accuracy\": 0.96},\n",
232228
")\n",
233229
"\n",
234230
"# see result here https://huggingface.co/qubvel-hf/unet-with-metadata"
@@ -267,10 +263,7 @@
267263
"outputs": [],
268264
"source": [
269265
"# define a preprocessing transform for image that would be used during inference\n",
270-
"preprocessing_transform = A.Compose([\n",
271-
" A.Resize(256, 256),\n",
272-
" A.Normalize()\n",
273-
"])\n",
266+
"preprocessing_transform = A.Compose([A.Resize(256, 256), A.Normalize()])\n",
274267
"\n",
275268
"model = smp.Unet()"
276269
]
@@ -367,15 +360,21 @@
367360
"# You can also save training augmentations to the Hub too (and load it back)!\n",
368361
"#! Just make sure to provide key=\"train\" when saving and loading the augmentations.\n",
369362
"\n",
370-
"train_augmentations = A.Compose([\n",
371-
" A.HorizontalFlip(p=0.5),\n",
372-
" A.RandomBrightnessContrast(p=0.2),\n",
373-
" A.ShiftScaleRotate(p=0.5),\n",
374-
"])\n",
363+
"train_augmentations = A.Compose(\n",
364+
" [\n",
365+
" A.HorizontalFlip(p=0.5),\n",
366+
" A.RandomBrightnessContrast(p=0.2),\n",
367+
" A.ShiftScaleRotate(p=0.5),\n",
368+
" ]\n",
369+
")\n",
375370
"\n",
376-
"train_augmentations.save_pretrained(directory_or_repo_on_the_hub, key=\"train\", push_to_hub=True)\n",
371+
"train_augmentations.save_pretrained(\n",
372+
" directory_or_repo_on_the_hub, key=\"train\", push_to_hub=True\n",
373+
")\n",
377374
"\n",
378-
"restored_train_augmentations = A.Compose.from_pretrained(directory_or_repo_on_the_hub, key=\"train\")\n",
375+
"restored_train_augmentations = A.Compose.from_pretrained(\n",
376+
" directory_or_repo_on_the_hub, key=\"train\"\n",
377+
")\n",
379378
"print(restored_train_augmentations)"
380379
]
381380
},

pyproject.toml

+4
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ test = [
4646
[project.urls]
4747
Homepage = 'https://github.com/qubvel-org/segmentation_models.pytorch'
4848

49+
[tool.ruff]
50+
extend-include = ['*.ipynb']
51+
fix = true
52+
4953
[tool.setuptools.dynamic]
5054
version = {attr = 'segmentation_models_pytorch.__version__.__version__'}
5155

0 commit comments

Comments
 (0)