
Commit 83dfa47

Add depth-estimation w/ DPT and GLPN (#389)
* Add `size` getter to `RawImage`
* Add `DPTFeatureExtractor`
* Add depth-estimation w/ DPT models
* Add GLPN models for depth estimation
* Add missing import in example
* Add `DPTFeatureExtractor` processor test
* Add unit test for GLPN processor
* Add support for `GLPNFeatureExtractor`: uses `size_divisor` to determine resize width and height (see the sketch after the `scripts/supported_models.py` diff below)
* Add `GLPNForDepthEstimation` example code
* Add DPT to list of supported models
* Add GLPN to list of supported models
* Add `DepthEstimationPipeline` (see the usage sketch below)
* Add listed support for depth estimation pipeline
* Add depth estimation pipeline unit tests
* Fix formatting
* Update `pipeline` JSDoc
* Fix typo from merge
1 parent 5ddc472 · commit 83dfa47
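
For orientation before the per-file diffs: a minimal sketch of how the new `DepthEstimationPipeline` is invoked through the `pipeline()` factory. The checkpoint id `Xenova/dpt-hybrid-midas` and the output fields are assumptions based on the JSDoc examples in the `src/models.js` diff below, not part of this commit's text.

```javascript
import { pipeline } from '@xenova/transformers';

// Create a depth-estimation pipeline (checkpoint id assumed to be available)
const depth_estimator = await pipeline('depth-estimation', 'Xenova/dpt-hybrid-midas');

// Run it on an image URL
const url = 'http://images.cocodataset.org/val2017/000000039769.jpg';
const output = await depth_estimator(url);

// Assumed output shape: the raw depth tensor plus a single-channel RawImage
// rescaled for visualization, mirroring the model examples in this commit.
// { predicted_depth: Tensor, depth: RawImage }
```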

File tree

10 files changed (+293 additions, −5 deletions)


README.md
Lines changed: 3 additions & 1 deletion

```diff
@@ -210,7 +210,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 
 | Task | ID | Description | Supported? |
 |--------------------------|----|-------------|------------|
-| [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | |
+| [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.DepthEstimationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=depth-estimation&library=transformers.js) |
 | [Image Classification](https://huggingface.co/tasks/image-classification) | `image-classification` | Assigning a label or class to an entire image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js) |
 | [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | `image-segmentation` | Divides an image into segments where each pixel is mapped to an object. This task has multiple variants such as instance segmentation, panoptic segmentation and semantic segmentation. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageSegmentationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js) |
 | [Image-to-Image](https://huggingface.co/tasks/image-to-image) | `image-to-image` | Transforming a source image to match the characteristics of a target image or a target image domain. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageToImagePipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-to-image&library=transformers.js) |
@@ -277,8 +277,10 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
 1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT.
 1. **[Donut](https://huggingface.co/docs/transformers/model_doc/donut)** (from NAVER), released together with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park.
+1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun.
 1. **[Falcon](https://huggingface.co/docs/transformers/model_doc/falcon)** (from Technology Innovation Institute) by Almazrouei, Ebtesam and Alobeidli, Hamza and Alshamsi, Abdulaziz and Cappelli, Alessandro and Cojocaru, Ruxandra and Debbah, Merouane and Goffinet, Etienne and Heslow, Daniel and Launay, Julien and Malartic, Quentin and Noune, Badreddine and Pannier, Baptiste and Penedo, Guilherme.
 1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei
+1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
 1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy.
 1. **[GPT NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach
 1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
```

docs/snippets/5_supported-tasks.snippet
Lines changed: 1 addition & 1 deletion

```diff
@@ -22,7 +22,7 @@
 
 | Task | ID | Description | Supported? |
 |--------------------------|----|-------------|------------|
-| [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | |
+| [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.DepthEstimationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=depth-estimation&library=transformers.js) |
 | [Image Classification](https://huggingface.co/tasks/image-classification) | `image-classification` | Assigning a label or class to an entire image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js) |
 | [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | `image-segmentation` | Divides an image into segments where each pixel is mapped to an object. This task has multiple variants such as instance segmentation, panoptic segmentation and semantic segmentation. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageSegmentationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js) |
 | [Image-to-Image](https://huggingface.co/tasks/image-to-image) | `image-to-image` | Transforming a source image to match the characteristics of a target image or a target image domain. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageToImagePipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-to-image&library=transformers.js) |
```

docs/snippets/6_supported-models.snippet
Lines changed: 2 additions & 0 deletions

```diff
@@ -18,8 +18,10 @@
 1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
 1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT.
 1. **[Donut](https://huggingface.co/docs/transformers/model_doc/donut)** (from NAVER), released together with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park.
+1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun.
 1. **[Falcon](https://huggingface.co/docs/transformers/model_doc/falcon)** (from Technology Innovation Institute) by Almazrouei, Ebtesam and Alobeidli, Hamza and Alshamsi, Abdulaziz and Cappelli, Alessandro and Cojocaru, Ruxandra and Debbah, Merouane and Goffinet, Etienne and Heslow, Daniel and Launay, Julien and Malartic, Quentin and Noune, Badreddine and Pannier, Baptiste and Penedo, Guilherme.
 1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei
+1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
 1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy.
 1. **[GPT NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach
 1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
```

scripts/supported_models.py
Lines changed: 10 additions & 0 deletions

```diff
@@ -208,11 +208,21 @@
         # Document Question Answering
         'naver-clova-ix/donut-base-finetuned-docvqa',
     ],
+    'dpt': [
+        # Depth estimation
+        'Intel/dpt-hybrid-midas',
+        'Intel/dpt-large',
+    ],
     'falcon': [
         # Text generation
         'Rocketknight1/tiny-random-falcon-7b',
         'fxmarty/really-tiny-falcon-testing',
     ],
+    'glpn': [
+        # Depth estimation
+        'vinvino02/glpn-kitti',
+        'vinvino02/glpn-nyu',
+    ],
     'gpt_neo': [
         # Text generation
         'EleutherAI/gpt-neo-125M',
```
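
The commit message states that `GLPNFeatureExtractor` uses `size_divisor` to determine the resize width and height. A minimal sketch of that rounding logic, assuming each dimension is floored to the nearest multiple of `size_divisor`; the helper name is hypothetical and not part of the diff:

```javascript
// Hypothetical helper illustrating the size_divisor behaviour described in
// the commit message: round each dimension down to a multiple of
// size_divisor so the input matches the encoder's stride.
function constrainToMultiple(width, height, size_divisor = 32) {
    return [
        Math.floor(width / size_divisor) * size_divisor,
        Math.floor(height / size_divisor) * size_divisor,
    ];
}

console.log(constrainToMultiple(640, 480)); // [640, 480] (already multiples of 32)
console.log(constrainToMultiple(650, 487)); // [640, 480]
```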

src/models.js
Lines changed: 106 additions & 0 deletions

````diff
@@ -3371,6 +3371,100 @@ export class Swin2SRModel extends Swin2SRPreTrainedModel { }
 export class Swin2SRForImageSuperResolution extends Swin2SRPreTrainedModel { }
 //////////////////////////////////////////////////
 
+//////////////////////////////////////////////////
+export class DPTPreTrainedModel extends PreTrainedModel { }
+
+/**
+ * The bare DPT Model transformer outputting raw hidden-states without any specific head on top.
+ */
+export class DPTModel extends DPTPreTrainedModel { }
+
+/**
+ * DPT Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
+ *
+ * **Example:** Depth estimation w/ `Xenova/dpt-hybrid-midas`.
+ * ```javascript
+ * import { DPTForDepthEstimation, AutoProcessor, RawImage, interpolate, max } from '@xenova/transformers';
+ *
+ * // Load model and processor
+ * const model_id = 'Xenova/dpt-hybrid-midas';
+ * const model = await DPTForDepthEstimation.from_pretrained(model_id);
+ * const processor = await AutoProcessor.from_pretrained(model_id);
+ *
+ * // Load image from URL
+ * const url = 'http://images.cocodataset.org/val2017/000000039769.jpg';
+ * const image = await RawImage.fromURL(url);
+ *
+ * // Prepare image for the model
+ * const inputs = await processor(image);
+ *
+ * // Run model
+ * const { predicted_depth } = await model(inputs);
+ *
+ * // Interpolate to original size
+ * const prediction = interpolate(predicted_depth, image.size.reverse(), 'bilinear', false);
+ *
+ * // Visualize the prediction
+ * const formatted = prediction.mul_(255 / max(prediction.data)[0]).to('uint8');
+ * const depth = RawImage.fromTensor(formatted);
+ * // RawImage {
+ * //   data: Uint8Array(307200) [ 85, 85, 84, ... ],
+ * //   width: 640,
+ * //   height: 480,
+ * //   channels: 1
+ * // }
+ * ```
+ */
+export class DPTForDepthEstimation extends DPTPreTrainedModel { }
+//////////////////////////////////////////////////
+
+//////////////////////////////////////////////////
+export class GLPNPreTrainedModel extends PreTrainedModel { }
+
+/**
+ * The bare GLPN encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top.
+ */
+export class GLPNModel extends GLPNPreTrainedModel { }
+
+/**
+ * GLPN Model transformer with a lightweight depth estimation head on top e.g. for KITTI, NYUv2.
+ *
+ * **Example:** Depth estimation w/ `Xenova/glpn-kitti`.
+ * ```javascript
+ * import { GLPNForDepthEstimation, AutoProcessor, RawImage, interpolate, max } from '@xenova/transformers';
+ *
+ * // Load model and processor
+ * const model_id = 'Xenova/glpn-kitti';
+ * const model = await GLPNForDepthEstimation.from_pretrained(model_id);
+ * const processor = await AutoProcessor.from_pretrained(model_id);
+ *
+ * // Load image from URL
+ * const url = 'http://images.cocodataset.org/val2017/000000039769.jpg';
+ * const image = await RawImage.fromURL(url);
+ *
+ * // Prepare image for the model
+ * const inputs = await processor(image);
+ *
+ * // Run model
+ * const { predicted_depth } = await model(inputs);
+ *
+ * // Interpolate to original size
+ * const prediction = interpolate(predicted_depth, image.size.reverse(), 'bilinear', false);
+ *
+ * // Visualize the prediction
+ * const formatted = prediction.mul_(255 / max(prediction.data)[0]).to('uint8');
+ * const depth = RawImage.fromTensor(formatted);
+ * // RawImage {
+ * //   data: Uint8Array(307200) [ 207, 169, 154, ... ],
+ * //   width: 640,
+ * //   height: 480,
+ * //   channels: 1
+ * // }
+ * ```
+ */
+export class GLPNForDepthEstimation extends GLPNPreTrainedModel { }
+//////////////////////////////////////////////////
+
 //////////////////////////////////////////////////
 export class DonutSwinPreTrainedModel extends PreTrainedModel { }
 
@@ -4025,6 +4119,8 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
     ['swin2sr', ['Swin2SRModel', Swin2SRModel]],
     ['donut-swin', ['DonutSwinModel', DonutSwinModel]],
     ['yolos', ['YolosModel', YolosModel]],
+    ['dpt', ['DPTModel', DPTModel]],
+    ['glpn', ['GLPNModel', GLPNModel]],
 
     ['hifigan', ['SpeechT5HifiGan', SpeechT5HifiGan]],
 
@@ -4205,6 +4301,11 @@ const MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = new Map([
     ['swin2sr', ['Swin2SRForImageSuperResolution', Swin2SRForImageSuperResolution]],
 ])
 
+const MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = new Map([
+    ['dpt', ['DPTForDepthEstimation', DPTForDepthEstimation]],
+    ['glpn', ['GLPNForDepthEstimation', GLPNForDepthEstimation]],
+])
+
 
 const MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES.EncoderOnly],
@@ -4221,6 +4322,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
@@ -4425,6 +4527,10 @@ export class AutoModelForImageToImage extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES];
 }
 
+export class AutoModelForDepthEstimation extends PretrainedMixin {
+    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES];
+}
+
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
````
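
Because `MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES` routes both `dpt` and `glpn` model types through the same Auto class, either family loads via one entry point. A minimal sketch, assuming `Xenova/glpn-nyu` is an available converted checkpoint; the flow mirrors the JSDoc examples above:

```javascript
import { AutoModelForDepthEstimation, AutoProcessor, RawImage, interpolate } from '@xenova/transformers';

// The Auto class resolves the config's model type ('glpn' here) to
// GLPNForDepthEstimation through the mapping added in this commit.
const model = await AutoModelForDepthEstimation.from_pretrained('Xenova/glpn-nyu');
const processor = await AutoProcessor.from_pretrained('Xenova/glpn-nyu');

// Preprocess, run, and resize the prediction back to the input resolution
const image = await RawImage.fromURL('http://images.cocodataset.org/val2017/000000039769.jpg');
const inputs = await processor(image);
const { predicted_depth } = await model(inputs);
const prediction = interpolate(predicted_depth, image.size.reverse(), 'bilinear', false);
```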
