Add granularity parameter to control segmented object size #785

Open
wants to merge 1 commit into
base: main
38 changes: 38 additions & 0 deletions README.md
@@ -77,6 +77,23 @@ Additionally, masks can be generated for images from the command line:
python scripts/amg.py --checkpoint <path/to/checkpoint> --model-type <model_type> --input <image_or_folder> --output <path/to/output>
```

To control the granularity of segmented objects, set the `granularity` parameter (a float, default `0.5`) when initializing `SamAutomaticMaskGenerator`, or pass `--granularity` to the `amg.py` script.

Example usage:

```python
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
mask_generator = SamAutomaticMaskGenerator(sam, granularity=0.7)
masks = mask_generator.generate(<your_image>)
```

From the command line, the same value is passed as `--granularity`:

```
python scripts/amg.py --checkpoint <path/to/checkpoint> --model-type <model_type> --input <image_or_folder> --output <path/to/output> --granularity 0.7
```

See the example notebooks on [using SAM with prompts](/notebooks/predictor_example.ipynb) and [automatically generating masks](/notebooks/automatic_mask_generator_example.ipynb) for more details.

<p float="left">
@@ -181,3 +198,24 @@ If you use SAM or SA-1B in your research, please use the following BibTeX entry.
year={2023}
}
```

## Granularity Parameter

`SamAutomaticMaskGenerator` accepts a `granularity` parameter (a float, default `0.5`) that controls the granularity of segmented objects. Set it when initializing the class, or pass it as `--granularity` to the `amg.py` script.

### Example Usage

Initialize `SamAutomaticMaskGenerator` with the desired `granularity` value:

```python
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
mask_generator = SamAutomaticMaskGenerator(sam, granularity=0.7)
masks = mask_generator.generate(<your_image>)
```

Or pass `--granularity` to the `amg.py` script:

```
python scripts/amg.py --checkpoint <path/to/checkpoint> --model-type <model_type> --input <image_or_folder> --output <path/to/output> --granularity 0.7
```
8 changes: 8 additions & 0 deletions scripts/amg.py
@@ -63,6 +63,13 @@
),
)

parser.add_argument(
    "--granularity",
    type=float,
    default=0.5,
    help="Set the granularity of segmented objects.",
)
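
The argument above accepts any float. If `granularity` is meant to lie in a fixed range, the parser could reject out-of-range values up front. The sketch below assumes a valid range of [0, 1], which the PR does not state, and uses a hypothetical helper named `unit_interval`:

```python
import argparse


def unit_interval(value: str) -> float:
    """Parse a float and require it to lie in [0, 1] (an assumed range)."""
    try:
        number = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"{value!r} is not a number")
    # NaN fails this comparison as well, so it is rejected too.
    if not 0.0 <= number <= 1.0:
        raise argparse.ArgumentTypeError(f"{number} is outside [0, 1]")
    return number


parser = argparse.ArgumentParser()
parser.add_argument(
    "--granularity",
    type=unit_interval,  # hypothetical validator, not part of the PR
    default=0.5,
    help="Set the granularity of segmented objects.",
)
```

With a validator like this, `--granularity 1.5` fails at argument parsing with a usage message instead of reaching the mask generator.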

amg_settings = parser.add_argument_group("AMG Settings")

amg_settings.add_argument(
@@ -187,6 +194,7 @@ def get_amg_kwargs(args):
"crop_overlap_ratio": args.crop_overlap_ratio,
"crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
"min_mask_region_area": args.min_mask_region_area,
"granularity": args.granularity,
}
amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
return amg_kwargs
22 changes: 22 additions & 0 deletions segment_anything/automatic_mask_generator.py
@@ -49,6 +49,7 @@ def __init__(
point_grids: Optional[List[np.ndarray]] = None,
min_mask_region_area: int = 0,
output_mode: str = "binary_mask",
granularity: float = 0.5,
) -> None:
"""
Using a SAM model, generates masks for the entire image.
@@ -93,6 +94,7 @@ def __init__(
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
For large resolutions, 'binary_mask' may consume large amounts of
memory.
granularity (float): Controls the granularity of segmented objects.
Values greater than 0 enable the segment-blending step in generate().
"""

assert (points_per_side is None) != (
@@ -132,6 +134,7 @@ def __init__(
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
self.min_mask_region_area = min_mask_region_area
self.output_mode = output_mode
self.granularity = granularity

@torch.no_grad()
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
@@ -170,6 +173,10 @@ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
max(self.box_nms_thresh, self.crop_nms_thresh),
)

# Blend smaller segments into larger ones based on granularity
if self.granularity > 0:
    mask_data = self.blend_segments(mask_data, self.granularity)

# Encode masks
if self.output_mode == "coco_rle":
mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
@@ -370,3 +377,18 @@ def postprocess_small_regions(
mask_data.filter(keep_by_nms)

return mask_data

def blend_segments(self, mask_data: MaskData, granularity: float) -> MaskData:
    """
    Blends smaller segments into larger ones based on the granularity parameter.

    Arguments:
      mask_data (MaskData): The mask data containing the segments.
      granularity (float): The granularity parameter to control the blending.

    Returns:
      MaskData: The updated mask data with blended segments.
    """
    # Placeholder: no blending is performed yet, so mask_data is returned
    # unchanged. Replace this with the actual blending logic.
    return mask_data
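
Because `blend_segments` is still a placeholder, the PR does not define what blending means. As one possible reading, here is a minimal sketch that works on plain NumPy boolean masks rather than `MaskData`. Everything in it is an assumption: the function name `blend_small_masks`, the [0, 1] range for `granularity`, and the rule that a mask is "small" when its area is below `granularity` times the largest mask's area.

```python
import numpy as np


def blend_small_masks(masks: np.ndarray, granularity: float) -> np.ndarray:
    """Merge small boolean masks into the kept mask they overlap most.

    Assumptions (not from the PR): `masks` has shape (N, H, W), `granularity`
    lies in [0, 1], and a mask is "small" when its area is below
    `granularity` times the area of the largest mask.
    """
    if masks.ndim != 3:
        raise ValueError("masks must have shape (N, H, W)")
    if not 0.0 <= granularity <= 1.0:
        raise ValueError("granularity must be in [0, 1]")
    if len(masks) == 0:
        return masks.copy()

    masks = masks.astype(bool)
    areas = masks.reshape(len(masks), -1).sum(axis=1)
    min_area = granularity * areas.max()

    # Visit masks from largest to smallest so potential hosts are kept first.
    order = np.argsort(-areas, kind="stable")
    kept = []
    for idx in order:
        mask = masks[idx]
        if areas[idx] >= min_area or not kept:
            kept.append(mask.copy())
            continue
        # Pixel overlap between this small mask and every kept mask.
        overlaps = [np.logical_and(mask, host).sum() for host in kept]
        best = int(np.argmax(overlaps))
        if overlaps[best] > 0:
            kept[best] |= mask  # absorb the small mask into its best host
        else:
            kept.append(mask.copy())  # isolated small masks are left alone
    return np.stack(kept)
```

Under this reading, `granularity=0` keeps every mask, which matches the `if self.granularity > 0` guard in `generate`. A real implementation would also need to update the other per-mask fields held in `MaskData`, such as boxes and scores.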