Skip to content

Commit f112349

Browse files
authored
Object-detection pipeline improvements + better documentation (#189)
* Fix variable name * Add pipeline loading options section * Align object detection pipeline output with python library * Update unit tests * Update batched object detection unit test * Relax object detection unit tests
1 parent 13efa96 commit f112349

File tree

3 files changed

+156
-49
lines changed

3 files changed

+156
-49
lines changed

docs/source/pipelines.mdx

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,33 @@ let result = await transcriber('https://huggingface.co/datasets/Narsil/asr_dummy
6868

6969
## Pipeline options
7070

71+
### Loading
72+
73+
We offer a variety of options to control how models are loaded from the Hugging Face Hub (or locally).
74+
By default, the *quantized* version of the model is used, which is smaller and faster, but usually less accurate.
75+
To override this behaviour (i.e., use the unquantized model), you can use a custom `PretrainedOptions` object
76+
as the third parameter to the `pipeline` function:
77+
78+
```javascript
79+
// Allocation a pipeline for feature extraction, using the unquantized model
80+
const pipe = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2', {
81+
quantized: false,
82+
});
83+
```
84+
85+
You can also specify which revision of the model to use, by passing a `revision` parameter.
86+
Since the Hugging Face Hub uses a git-based versioning system, you can use any valid git revision specifier (e.g., branch name or commit hash)
87+
88+
```javascript
89+
let transcriber = await pipeline('automatic-speech-recognition', 'Xenova/whisper-tiny.en', {
90+
revision: 'output_attentions',
91+
});
92+
```
93+
94+
For the full list of options, check out the [PretrainedOptions](/api/utils/hub#module_utils/hub..PretrainedOptions) documentation.
95+
96+
97+
### Running
7198
Many pipelines have additional options that you can specify. For example, when using a model that does multilingual translation, you can specify the source and target languages like this:
7299

73100
<!-- TODO: Replace 'Xenova/nllb-200-distilled-600M' with 'facebook/nllb-200-distilled-600M' -->

src/pipelines.js

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
* ```javascript
66
* import { pipeline } from '@xenova/transformers';
77
*
8-
* let pipeline = await pipeline('sentiment-analysis');
9-
* let result = await pipeline('I love transformers!');
8+
* let classifier = await pipeline('sentiment-analysis');
9+
* let result = await classifier('I love transformers!');
1010
* // [{'label': 'POSITIVE', 'score': 0.999817686}]
1111
* ```
1212
*
@@ -1317,6 +1317,26 @@ export class ZeroShotImageClassificationPipeline extends Pipeline {
13171317
/**
13181318
* Object detection pipeline using any `AutoModelForObjectDetection`.
13191319
* This pipeline predicts bounding boxes of objects and their classes.
1320+
*
1321+
* **Example:** Run object-detection with `facebook/detr-resnet-50`.
1322+
* ```javascript
1323+
* let img = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg';
1324+
*
1325+
* let detector = await pipeline('object-detection', 'Xenova/detr-resnet-50');
1326+
* let output = await detector(img, { threshold: 0.9 });
1327+
* // [{
1328+
* // "score": 0.9976370930671692,
1329+
* // "label": "remote",
1330+
* // "box": { "xmin": 31, "ymin": 68, "xmax": 190, "ymax": 118 }
1331+
* // },
1332+
* // ...
1333+
* // {
1334+
* // "score": 0.9984092116355896,
1335+
* // "label": "cat",
1336+
* // "box": { "xmin": 331, "ymin": 19, "xmax": 649, "ymax": 371 }
1337+
* // }]
1338+
* ```
1339+
*
13201340
* @extends Pipeline
13211341
*/
13221342
export class ObjectDetectionPipeline extends Pipeline {
@@ -1359,9 +1379,35 @@ export class ObjectDetectionPipeline extends Pipeline {
13591379

13601380
// Add labels
13611381
let id2label = this.model.config.id2label;
1362-
processed.forEach(x => x.labels = x.classes.map(y => id2label[y]));
13631382

1364-
return isBatched ? processed : processed[0];
1383+
// Format output
1384+
const result = processed.map(batch => {
1385+
return batch.boxes.map((box, i) => {
1386+
return {
1387+
score: batch.scores[i],
1388+
label: id2label[batch.classes[i]],
1389+
box: this._get_bounding_box(box, !percentage),
1390+
}
1391+
})
1392+
})
1393+
1394+
return isBatched ? result : result[0];
1395+
}
1396+
1397+
/**
1398+
* Helper function to convert list [xmin, xmax, ymin, ymax] into object { "xmin": xmin, ... }
1399+
* @param {number[]} box The bounding box as a list.
1400+
* @param {boolean} asInteger Whether to cast to integers.
1401+
* @returns {Object} The bounding box as an object.
1402+
* @private
1403+
*/
1404+
_get_bounding_box(box, asInteger) {
1405+
if (asInteger) {
1406+
box = box.map(x => x | 0);
1407+
}
1408+
const [xmin, ymin, xmax, ymax] = box;
1409+
1410+
return { xmin, ymin, xmax, ymax };
13651411
}
13661412
}
13671413

tests/pipelines.test.js

Lines changed: 79 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,67 +1116,101 @@ describe('Pipelines', () => {
11161116
let detector = await pipeline('object-detection', m(models[0]));
11171117

11181118
// TODO add batched test cases when supported
1119-
let url = 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/savanna.jpg';
1120-
let urls = ['https://huggingface.co/datasets/mishig/sample_images/resolve/main/airport.jpg']
1119+
let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg';
1120+
let urls = ['https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/savanna.jpg']
11211121

11221122
// single + threshold
11231123
{
11241124
let output = await detector(url, {
11251125
threshold: 0.9,
11261126
});
11271127

1128-
// let expected = {
1129-
// "boxes": [
1130-
// [352.8210112452507, 247.36732184886932, 390.5271676182747, 318.09066116809845],
1131-
// [111.15852802991867, 235.34255504608154, 224.96717244386673, 325.21119117736816],
1132-
// [13.524770736694336, 146.81672930717468, 207.97560095787048, 278.6452639102936],
1133-
// [187.396682202816, 227.97491312026978, 313.05202156305313, 300.26460886001587],
1134-
// [201.60082161426544, 230.86223602294922, 312.1393972635269, 306.5505266189575],
1135-
// [365.85242718458176, 95.3144109249115, 526.5485098958015, 313.17670941352844]
1136-
// ],
1137-
// "classes": [24, 24, 25, 24, 24, 25],
1138-
// "scores": [0.9989480376243591, 0.9990893006324768, 0.9690554738044739, 0.9274907112121582, 0.9714975953102112, 0.9989491105079651],
1139-
// "labels": ["zebra", "zebra", "giraffe", "zebra", "zebra", "giraffe"]
1140-
// };
1141-
1142-
let num_classes = output.boxes.length;
1143-
expect(num_classes).toBeGreaterThan(1);
1144-
expect(output.classes.length).toEqual(num_classes);
1145-
expect(output.scores.length).toEqual(num_classes);
1146-
expect(output.labels.length).toEqual(num_classes);
1128+
// let expected = [
1129+
// {
1130+
// "score": 0.9977124929428101,
1131+
// "label": "remote",
1132+
// "box": { "xmin": 41, "ymin": 70, "xmax": 176, "ymax": 118 }
1133+
// },
1134+
// {
1135+
// "score": 0.9984639883041382,
1136+
// "label": "remote",
1137+
// "box": { "xmin": 332, "ymin": 73, "xmax": 369, "ymax": 188 }
1138+
// },
1139+
// {
1140+
// "score": 0.9964856505393982,
1141+
// "label": "couch",
1142+
// "box": { "xmin": 0, "ymin": 1, "xmax": 639, "ymax": 474 }
1143+
// },
1144+
// {
1145+
// "score": 0.9988334774971008,
1146+
// "label": "cat",
1147+
// "box": { "xmin": 11, "ymin": 51, "xmax": 314, "ymax": 472 }
1148+
// },
1149+
// {
1150+
// "score": 0.9982513785362244,
1151+
// "label": "cat",
1152+
// "box": { "xmin": 345, "ymin": 22, "xmax": 640, "ymax": 371 }
1153+
// }
1154+
// ]
11471155

1156+
expect(output.length).toBeGreaterThan(0);
1157+
for (let cls of output) {
1158+
expect(typeof cls.score).toBe('number');
1159+
expect(typeof cls.label).toBe('string');
1160+
for (let key of ['xmin', 'ymin', 'xmax', 'ymax']) {
1161+
expect(typeof cls.box[key]).toBe('number');
1162+
}
1163+
}
11481164
}
11491165

1150-
// single + threshold + percentage
1166+
// batched + threshold + percentage
11511167
{
11521168
let output = await detector(urls, {
11531169
threshold: 0.9,
11541170
percentage: true
11551171
});
1172+
// let expected = [[
1173+
// {
1174+
// score: 0.9991137385368347,
1175+
// label: 'zebra',
1176+
// box: { xmin: 0.65165576338768, ymin: 0.685152679681778, xmax: 0.723189502954483, ymax: 0.8801506459712982 }
1177+
// },
1178+
// {
1179+
// score: 0.998811662197113,
1180+
// label: 'zebra',
1181+
// box: { xmin: 0.20797613263130188, ymin: 0.6543092578649521, xmax: 0.4147692620754242, ymax: 0.9040975719690323 }
1182+
// },
1183+
// {
1184+
// score: 0.9707837104797363,
1185+
// label: 'giraffe',
1186+
// box: { xmin: 0.02498096227645874, ymin: 0.40549489855766296, xmax: 0.38669759035110474, ymax: 0.7895723879337311 }
1187+
// },
1188+
// {
1189+
// score: 0.9984336495399475,
1190+
// label: 'zebra',
1191+
// box: { xmin: 0.3540637195110321, ymin: 0.6370827257633209, xmax: 0.5765090882778168, ymax: 0.8480959832668304 }
1192+
// },
1193+
// {
1194+
// score: 0.9986463785171509,
1195+
// label: 'giraffe',
1196+
// box: { xmin: 0.6763969212770462, ymin: 0.25748637318611145, xmax: 0.974339172244072, ymax: 0.8684568107128143 }
1197+
// }
1198+
// ]]
1199+
1200+
expect(output).toHaveLength(urls.length); // Same number of inputs as outputs
1201+
1202+
for (let i = 0; i < output.length; ++i) {
1203+
expect(output[i].length).toBeGreaterThan(0);
1204+
for (let cls of output[i]) {
1205+
expect(typeof cls.score).toBe('number');
1206+
expect(typeof cls.label).toBe('string');
1207+
for (let key of ['xmin', 'ymin', 'xmax', 'ymax']) {
1208+
expect(typeof cls.box[key]).toBe('number');
1209+
}
1210+
}
1211+
}
1212+
11561213

1157-
// let expected = [{
1158-
// "boxes": [
1159-
// [0.7231650948524475, 0.32641804218292236, 0.981127917766571, 0.9918863773345947],
1160-
// [0.7529061436653137, 0.52558633685112, 0.8229959607124329, 0.6482008993625641],
1161-
// [0.5080368518829346, 0.5156279355287552, 0.5494132041931152, 0.5434067696332932],
1162-
// [0.33636586368083954, 0.5217841267585754, 0.3535611182451248, 0.6151944994926453],
1163-
// [0.42090220749378204, 0.4482414871454239, 0.5515891760587692, 0.5207531303167343],
1164-
// [0.1988394856452942, 0.41224047541618347, 0.45213085412979126, 0.5206181704998016],
1165-
// [0.5063001662492752, 0.5170856416225433, 0.5478668659925461, 0.54373899102211],
1166-
// [0.5734506398439407, 0.4508090913295746, 0.7049560993909836, 0.6252130568027496],
1167-
// ],
1168-
// "classes": [6, 1, 8, 1, 5, 5, 3, 6],
1169-
// "scores": [0.9970788359642029, 0.996989905834198, 0.9505048990249634, 0.9984546899795532, 0.9942372441291809, 0.9989550709724426, 0.938920259475708, 0.9992448091506958],
1170-
// "labels": ["bus", "person", "truck", "person", "airplane", "airplane", "car", "bus"]
1171-
// }];
1172-
1173-
expect(output).toHaveLength(urls.length);
1174-
1175-
let num_classes = output[0].boxes.length;
1176-
expect(num_classes).toBeGreaterThan(1);
1177-
expect(output[0].classes.length).toEqual(num_classes);
1178-
expect(output[0].scores.length).toEqual(num_classes);
1179-
expect(output[0].labels.length).toEqual(num_classes);
11801214
}
11811215

11821216
await detector.dispose();

0 commit comments

Comments
 (0)