Commit d799fb6

Add background removal pipeline (huggingface#1216)
* Add background removal pipeline
* Add background-removal unit test
* Add background removal task to docs
1 parent 31dfd43

File tree: 6 files changed, +218 -12 lines

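In short, this commit wires a new `background-removal` task into the pipeline factory. A minimal usage sketch, assembled from the docstring example added to `src/pipelines.js` below (the model ID, image URL, and output shape come from that example; the `@huggingface/transformers` import path is an assumption about the consuming environment):

```javascript
import { pipeline } from '@huggingface/transformers';

// 'Xenova/modnet' is also the task default registered in SUPPORTED_TASKS below
const segmenter = await pipeline('background-removal', 'Xenova/modnet');

const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/portrait-of-woman_small.jpg';
const [image] = await segmenter(url);
// `image` is a RawImage whose alpha channel encodes the removed background (channels === 4)
```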

README.md (+1)

```diff
@@ -235,6 +235,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 
 | Task | ID | Description | Supported? |
 |--------------------------|----|-------------|------------|
+| [Background Removal](https://huggingface.co/tasks/image-segmentation#background-removal) | `background-removal` | Isolating the main subject of an image by removing or making the background transparent. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.BackgroundRemovalPipeline)<br>[(models)](https://huggingface.co/models?other=background-removal&library=transformers.js) |
 | [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.DepthEstimationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=depth-estimation&library=transformers.js) |
 | [Image Classification](https://huggingface.co/tasks/image-classification) | `image-classification` | Assigning a label or class to an entire image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js) |
 | [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | `image-segmentation` | Divides an image into segments where each pixel is mapped to an object. This task has multiple variants such as instance segmentation, panoptic segmentation and semantic segmentation. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageSegmentationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js) |
```

docs/snippets/5_supported-tasks.snippet (+1)

```diff
@@ -22,6 +22,7 @@
 
 | Task | ID | Description | Supported? |
 |--------------------------|----|-------------|------------|
+| [Background Removal](https://huggingface.co/tasks/image-segmentation#background-removal) | `background-removal` | Isolating the main subject of an image by removing or making the background transparent. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.BackgroundRemovalPipeline)<br>[(models)](https://huggingface.co/models?other=background-removal&library=transformers.js) |
 | [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.DepthEstimationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=depth-estimation&library=transformers.js) |
 | [Image Classification](https://huggingface.co/tasks/image-classification) | `image-classification` | Assigning a label or class to an entire image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js) |
 | [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | `image-segmentation` | Divides an image into segments where each pixel is mapped to an object. This task has multiple variants such as instance segmentation, panoptic segmentation and semantic segmentation. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageSegmentationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js) |
```

src/models.js (+39 -5)
```diff
@@ -5238,6 +5238,7 @@ export class SwinForImageClassification extends SwinPreTrainedModel {
         return new SequenceClassifierOutput(await super._call(model_inputs));
     }
 }
+export class SwinForSemanticSegmentation extends SwinPreTrainedModel { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
@@ -6840,6 +6841,8 @@ export class MobileNetV1ForImageClassification extends MobileNetV1PreTrainedMode
         return new SequenceClassifierOutput(await super._call(model_inputs));
     }
 }
+
+export class MobileNetV1ForSemanticSegmentation extends MobileNetV1PreTrainedModel { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
@@ -6863,6 +6866,7 @@ export class MobileNetV2ForImageClassification extends MobileNetV2PreTrainedMode
         return new SequenceClassifierOutput(await super._call(model_inputs));
     }
 }
+export class MobileNetV2ForSemanticSegmentation extends MobileNetV2PreTrainedModel { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
@@ -6886,6 +6890,7 @@ export class MobileNetV3ForImageClassification extends MobileNetV3PreTrainedMode
         return new SequenceClassifierOutput(await super._call(model_inputs));
     }
 }
+export class MobileNetV3ForSemanticSegmentation extends MobileNetV3PreTrainedModel { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
@@ -6909,6 +6914,7 @@ export class MobileNetV4ForImageClassification extends MobileNetV4PreTrainedMode
         return new SequenceClassifierOutput(await super._call(model_inputs));
     }
 }
+export class MobileNetV4ForSemanticSegmentation extends MobileNetV4PreTrainedModel { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
```
```diff
@@ -7322,20 +7328,29 @@ export class PretrainedMixin {
         if (!this.MODEL_CLASS_MAPPINGS) {
             throw new Error("`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: " + this.name);
         }
-
+        const model_type = options.config.model_type;
         for (const MODEL_CLASS_MAPPING of this.MODEL_CLASS_MAPPINGS) {
-            const modelInfo = MODEL_CLASS_MAPPING.get(options.config.model_type);
+            let modelInfo = MODEL_CLASS_MAPPING.get(model_type);
             if (!modelInfo) {
-                continue; // Item not found in this mapping
+                // As a fallback, we check if model_type is specified as the exact class
+                for (const cls of MODEL_CLASS_MAPPING.values()) {
+                    if (cls[0] === model_type) {
+                        modelInfo = cls;
+                        break;
+                    }
+                }
+                if (!modelInfo) continue; // Item not found in this mapping
             }
             return await modelInfo[1].from_pretrained(pretrained_model_name_or_path, options);
         }
 
         if (this.BASE_IF_FAIL) {
-            console.warn(`Unknown model class "${options.config.model_type}", attempting to construct from base class.`);
+            if (!(CUSTOM_ARCHITECTURES.has(model_type))) {
+                console.warn(`Unknown model class "${model_type}", attempting to construct from base class.`);
+            }
             return await PreTrainedModel.from_pretrained(pretrained_model_name_or_path, options);
         } else {
-            throw Error(`Unsupported model type: ${model_type}`)
+            throw Error(`Unsupported model type: ${model_type}`)
         }
     }
 }
```
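Two behavioral changes land in this hunk: the mapping lookup now falls back to matching `model_type` against architecture class names (not just mapping keys), and the base-class warning is skipped for the custom architectures registered further down. A self-contained sketch of the new lookup order, using an illustrative one-entry mapping rather than the library's real tables:

```javascript
// Hypothetical one-entry mapping in the same shape as the library's tables:
// key = model_type, value = [class name, class]
const MAPPING = new Map([
    ['segformer', ['SegformerForSemanticSegmentation', class SegformerForSemanticSegmentation { }]],
]);

function lookup(model_type) {
    // 1. Fast path: model_type is a mapping key (e.g. 'segformer')
    let modelInfo = MAPPING.get(model_type);
    if (!modelInfo) {
        // 2. New fallback: model_type names the architecture class itself
        //    (e.g. 'SegformerForSemanticSegmentation')
        for (const cls of MAPPING.values()) {
            if (cls[0] === model_type) {
                modelInfo = cls;
                break;
            }
        }
    }
    return modelInfo ?? null;
}

console.log(lookup('segformer') !== null);                        // true
console.log(lookup('SegformerForSemanticSegmentation') !== null); // true (new in this commit)
console.log(lookup('unknown'));                                   // null
```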
```diff
@@ -7693,6 +7708,12 @@ const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
 const MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = new Map([
     ['segformer', ['SegformerForSemanticSegmentation', SegformerForSemanticSegmentation]],
     ['sapiens', ['SapiensForSemanticSegmentation', SapiensForSemanticSegmentation]],
+
+    ['swin', ['SwinForSemanticSegmentation', SwinForSemanticSegmentation]],
+    ['mobilenet_v1', ['MobileNetV1ForSemanticSegmentation', MobileNetV1ForSemanticSegmentation]],
+    ['mobilenet_v2', ['MobileNetV2ForSemanticSegmentation', MobileNetV2ForSemanticSegmentation]],
+    ['mobilenet_v3', ['MobileNetV3ForSemanticSegmentation', MobileNetV3ForSemanticSegmentation]],
+    ['mobilenet_v4', ['MobileNetV4ForSemanticSegmentation', MobileNetV4ForSemanticSegmentation]],
 ]);
 
 const MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = new Map([
```
```diff
@@ -7845,6 +7866,19 @@ for (const [name, model, type] of CUSTOM_MAPPING) {
     MODEL_NAME_TO_CLASS_MAPPING.set(name, model);
 }
 
+const CUSTOM_ARCHITECTURES = new Map([
+    ['modnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
+    ['birefnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
+    ['isnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
+    ['ben', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
+]);
+for (const [name, mapping] of CUSTOM_ARCHITECTURES.entries()) {
+    mapping.set(name, ['PreTrainedModel', PreTrainedModel])
+    MODEL_TYPE_MAPPING.set(name, MODEL_TYPES.EncoderOnly);
+    MODEL_CLASS_TO_NAME_MAPPING.set(PreTrainedModel, name);
+    MODEL_NAME_TO_CLASS_MAPPING.set(name, PreTrainedModel);
+}
+
 
 /**
  * Helper class which is used to instantiate pretrained models with the `from_pretrained` function.
```
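MODNet, BiRefNet, ISNet, and BEN have no dedicated `*For*` architecture classes, so the loop above registers each model type under the image-segmentation mapping with the generic `PreTrainedModel`; together with the `CUSTOM_ARCHITECTURES.has(model_type)` guard in `PretrainedMixin`, their loads stay warning-free. A hedged sketch of what this enables (assuming, consistent with the mapping above, that `Xenova/modnet`'s config declares `"model_type": "modnet"`):

```javascript
import { AutoModelForImageSegmentation } from '@huggingface/transformers';

// Resolves through the new CUSTOM_ARCHITECTURES registration:
// 'modnet' -> ['PreTrainedModel', PreTrainedModel], loaded as a generic
// encoder-only model, with no "Unknown model class" warning.
const model = await AutoModelForImageSegmentation.from_pretrained('Xenova/modnet');
```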

src/pipelines.js (+106 -7)
```diff
@@ -2096,7 +2096,7 @@ export class ImageClassificationPipeline extends (/** @type {new (options: Image
 
 /**
  * @typedef {Object} ImageSegmentationPipelineOutput
- * @property {string} label The label of the segment.
+ * @property {string|null} label The label of the segment.
  * @property {number|null} score The score of the segment.
  * @property {RawImage} mask The mask of the segment.
  *
```
```diff
@@ -2166,14 +2166,30 @@ export class ImageSegmentationPipeline extends (/** @type {new (options: ImagePi
         const preparedImages = await prepareImages(images);
         const imageSizes = preparedImages.map(x => [x.height, x.width]);
 
-        const { pixel_values, pixel_mask } = await this.processor(preparedImages);
-        const output = await this.model({ pixel_values, pixel_mask });
+        const inputs = await this.processor(preparedImages);
+
+        const { inputNames, outputNames } = this.model.sessions['model'];
+        if (!inputNames.includes('pixel_values')) {
+            if (inputNames.length !== 1) {
+                throw Error(`Expected a single input name, but got ${inputNames.length} inputs: ${inputNames}.`);
+            }
+
+            const newName = inputNames[0];
+            if (newName in inputs) {
+                throw Error(`Input name ${newName} already exists in the inputs.`);
+            }
+            // To ensure compatibility with certain background-removal models,
+            // we may need to perform a mapping of input to output names
+            inputs[newName] = inputs.pixel_values;
+        }
+
+        const output = await this.model(inputs);
 
         let fn = null;
         if (subtask !== null) {
             fn = this.subtasks_mapping[subtask];
-        } else {
-            for (let [task, func] of Object.entries(this.subtasks_mapping)) {
+        } else if (this.processor.image_processor) {
+            for (const [task, func] of Object.entries(this.subtasks_mapping)) {
                 if (func in this.processor.image_processor) {
                     fn = this.processor.image_processor[func].bind(this.processor.image_processor);
                     subtask = task;
```
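The motivation for the remap: some converted background-removal models do not name their single graph input `pixel_values`, so the pipeline now forwards the processor's tensor under whatever name the ONNX session actually declares. A stripped-down sketch of the idea (the concrete input name `'input'` is an assumption for illustration):

```javascript
const someTensor = /* stand-in for the processed pixel tensor */ null;
const inputs = { pixel_values: someTensor };
const inputNames = ['input']; // as reported by this.model.sessions['model'] at runtime

if (!inputNames.includes('pixel_values') && inputNames.length === 1) {
    // Feed the same tensor under the only name the model declares
    inputs[inputNames[0]] = inputs.pixel_values;
}
```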
```diff
@@ -2187,7 +2203,23 @@ export class ImageSegmentationPipeline extends (/** @type {new (options: ImagePi
 
         /** @type {ImageSegmentationPipelineOutput[]} */
         const annotation = [];
-        if (subtask === 'panoptic' || subtask === 'instance') {
+        if (!subtask) {
+            // Perform standard image segmentation
+            const result = output[outputNames[0]];
+            for (let i = 0; i < imageSizes.length; ++i) {
+                const size = imageSizes[i];
+                const item = result[i];
+                if (item.data.some(x => x < 0 || x > 1)) {
+                    item.sigmoid_();
+                }
+                const mask = await RawImage.fromTensor(item.mul_(255).to('uint8')).resize(size[1], size[0]);
+                annotation.push({
+                    label: null,
+                    score: null,
+                    mask
+                });
+            }
+        } else if (subtask === 'panoptic' || subtask === 'instance') {
             const processed = fn(
                 output,
                 threshold,
```
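A note on the new no-subtask branch: `item.data.some(x => x < 0 || x > 1)` is a heuristic that treats any output values outside [0, 1] as raw logits and normalizes them with an in-place sigmoid, so models exporting either probabilities or logits both end up as a valid single-channel mask after `mul_(255).to('uint8')` and the resize back to the source dimensions.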
````diff
@@ -2243,6 +2275,63 @@ export class ImageSegmentationPipeline extends (/** @type {new (options: ImagePi
     }
 }
 
+
+/**
+ * @typedef {Object} BackgroundRemovalPipelineOptions Parameters specific to image segmentation pipelines.
+ *
+ * @callback BackgroundRemovalPipelineCallback Segment the input images.
+ * @param {ImagePipelineInputs} images The input images.
+ * @param {BackgroundRemovalPipelineOptions} [options] The options to use for image segmentation.
+ * @returns {Promise<RawImage[]>} The images with the background removed.
+ *
+ * @typedef {ImagePipelineConstructorArgs & BackgroundRemovalPipelineCallback & Disposable} BackgroundRemovalPipelineType
+ */
+
+/**
+ * Background removal pipeline using certain `AutoModelForXXXSegmentation`.
+ * This pipeline removes the backgrounds of images.
+ *
+ * **Example:** Perform background removal with `Xenova/modnet`.
+ * ```javascript
+ * const segmenter = await pipeline('background-removal', 'Xenova/modnet');
+ * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/portrait-of-woman_small.jpg';
+ * const output = await segmenter(url);
+ * // [
+ * //   RawImage { data: Uint8ClampedArray(648000) [ ... ], width: 360, height: 450, channels: 4 }
+ * // ]
+ * ```
+ */
+export class BackgroundRemovalPipeline extends (/** @type {new (options: ImagePipelineConstructorArgs) => ImageSegmentationPipelineType} */ (ImageSegmentationPipeline)) {
+    /**
+     * Create a new BackgroundRemovalPipeline.
+     * @param {ImagePipelineConstructorArgs} options An object used to instantiate the pipeline.
+     */
+    constructor(options) {
+        super(options);
+    }
+
+    /** @type {BackgroundRemovalPipelineCallback} */
+    async _call(images, options = {}) {
+        const isBatched = Array.isArray(images);
+
+        if (isBatched && images.length !== 1) {
+            throw Error("Background removal pipeline currently only supports a batch size of 1.");
+        }
+
+        const preparedImages = await prepareImages(images);
+
+        // @ts-expect-error TS2339
+        const masks = await super._call(images, options);
+        const result = preparedImages.map((img, i) => {
+            const cloned = img.clone();
+            cloned.putAlpha(masks[i].mask);
+            return cloned;
+        });
+
+        return result;
+    }
+}
+
 /**
  * @typedef {Object} ZeroShotImageClassificationOutput
  * @property {string} label The label identified by the model. It is one of the suggested `candidate_label`.
````
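Continuing the docstring example above: since the pipeline returns `RawImage` instances, the composited result can be written straight to disk. A short sketch reusing the `pipeline` import from the earlier snippet, assuming a Node.js environment where `RawImage`'s `save` helper is available:

```javascript
const segmenter = await pipeline('background-removal', 'Xenova/modnet');
const [result] = await segmenter('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/portrait-of-woman_small.jpg');
await result.save('portrait-no-background.png'); // PNG preserves the alpha channel
```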
```diff
@@ -2555,7 +2644,7 @@ export class ZeroShotObjectDetectionPipeline extends (/** @type {new (options: T
         const output = await this.model({ ...text_inputs, pixel_values });
 
         let result;
-        if('post_process_grounded_object_detection' in this.processor) {
+        if ('post_process_grounded_object_detection' in this.processor) {
             // @ts-ignore
             const processed = this.processor.post_process_grounded_object_detection(
                 output,
```
```diff
@@ -3135,6 +3224,16 @@ const SUPPORTED_TASKS = Object.freeze({
         },
         "type": "multimodal",
     },
+    "background-removal": {
+        // no tokenizer
+        "pipeline": BackgroundRemovalPipeline,
+        "model": [AutoModelForImageSegmentation, AutoModelForSemanticSegmentation, AutoModelForUniversalSegmentation],
+        "processor": AutoProcessor,
+        "default": {
+            "model": "Xenova/modnet",
+        },
+        "type": "image",
+    },
 
     "zero-shot-image-classification": {
         "tokenizer": AutoTokenizer,
```

tests/asset_cache.js (+1)
```diff
@@ -24,6 +24,7 @@ const TEST_IMAGES = Object.freeze({
   book_cover: BASE_URL + "book-cover.png",
   corgi: BASE_URL + "corgi.jpg",
   man_on_car: BASE_URL + "young-man-standing-and-leaning-on-car.jpg",
+  portrait_of_woman: BASE_URL + "portrait-of-woman_small.jpg",
 });
 
 const TEST_AUDIOS = {
```
tests/pipelines/test_pipelines_background_removal.js (new file, +70)

```diff
@@ -0,0 +1,70 @@
+import { pipeline, BackgroundRemovalPipeline, RawImage } from "../../src/transformers.js";
+
+import { MAX_MODEL_LOAD_TIME, MAX_TEST_EXECUTION_TIME, MAX_MODEL_DISPOSE_TIME, DEFAULT_MODEL_OPTIONS } from "../init.js";
+import { load_cached_image } from "../asset_cache.js";
+
+const PIPELINE_ID = "background-removal";
+
+export default () => {
+  describe("Background Removal", () => {
+    describe("Portrait Segmentation", () => {
+      const model_id = "Xenova/modnet";
+      /** @type {BackgroundRemovalPipeline} */
+      let pipe;
+      beforeAll(async () => {
+        pipe = await pipeline(PIPELINE_ID, model_id, DEFAULT_MODEL_OPTIONS);
+      }, MAX_MODEL_LOAD_TIME);
+
+      it("should be an instance of BackgroundRemovalPipeline", () => {
+        expect(pipe).toBeInstanceOf(BackgroundRemovalPipeline);
+      });
+
+      it(
+        "single",
+        async () => {
+          const image = await load_cached_image("portrait_of_woman");
+
+          const output = await pipe(image);
+          expect(output).toHaveLength(1);
+          expect(output[0]).toBeInstanceOf(RawImage);
+          expect(output[0].width).toEqual(image.width);
+          expect(output[0].height).toEqual(image.height);
+          expect(output[0].channels).toEqual(4); // With alpha channel
+        },
+        MAX_TEST_EXECUTION_TIME,
+      );
+
+      afterAll(async () => {
+        await pipe.dispose();
+      }, MAX_MODEL_DISPOSE_TIME);
+    });
+
+    describe("Selfie Segmentation", () => {
+      const model_id = "onnx-community/mediapipe_selfie_segmentation";
+      /** @type {BackgroundRemovalPipeline } */
+      let pipe;
+      beforeAll(async () => {
+        pipe = await pipeline(PIPELINE_ID, model_id, DEFAULT_MODEL_OPTIONS);
+      }, MAX_MODEL_LOAD_TIME);
+
+      it(
+        "single",
+        async () => {
+          const image = await load_cached_image("portrait_of_woman");
+
+          const output = await pipe(image);
+          expect(output).toHaveLength(1);
+          expect(output[0]).toBeInstanceOf(RawImage);
+          expect(output[0].width).toEqual(image.width);
+          expect(output[0].height).toEqual(image.height);
+          expect(output[0].channels).toEqual(4); // With alpha channel
+        },
+        MAX_TEST_EXECUTION_TIME,
+      );
+
+      afterAll(async () => {
+        await pipe.dispose();
+      }, MAX_MODEL_DISPOSE_TIME);
+    });
+  });
+};
```
