Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ONNX config for RT-DETR (and RT-DETRv2) #2201

Merged
merged 10 commits into from
Mar 3, 2025

Conversation

qubvel
Copy link
Member

@qubvel qubvel commented Feb 27, 2025

What does this PR do?

Add ONNX config for RT-DETR (and RT-DETRv2), continue of

Fixes: #2176

Also, FP16 ONNX tests are broken for RT-DETR but will be fixed in the coming version of transformers. What would be a better way to skip them?

PR in transformers:

Who can review?

@echarlaix

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Contributor

@xenova xenova left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

I tested the exporter on the set of rtdetr_v2 models and these are the warnings. Most are not issues, but some may be, and could require some modifications to huggingface/transformers#36460 (review).

/usr/local/lib/python3.11/dist-packages/optimum/exporters/onnx/model_configs.py:2668: UserWarning: Exporting model with image `height=64` which is less than minimal 320, setting `height` to 320.
  warnings.warn(
/usr/local/lib/python3.11/dist-packages/optimum/exporters/onnx/model_configs.py:2673: UserWarning: Exporting model with image `width=64` which is less than minimal 320, setting `width` to 320.
  warnings.warn(
/usr/local/lib/python3.11/dist-packages/transformers/models/rt_detr/modeling_rt_detr_resnet.py:107: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if num_channels != self.num_channels:
/usr/local/lib/python3.11/dist-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py:989: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  grid_w = torch.arange(int(width), device=device).to(dtype)
/usr/local/lib/python3.11/dist-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py:990: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  grid_h = torch.arange(int(height), device=device).to(dtype)
/usr/local/lib/python3.11/dist-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py:300: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
/usr/local/lib/python3.11/dist-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py:336: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
/usr/local/lib/python3.11/dist-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py:1747: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
/usr/local/lib/python3.11/dist-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py:1638: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  valid_wh = torch.tensor([width, height], device=device).to(dtype)
/usr/local/lib/python3.11/dist-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py:1647: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  anchors = torch.where(valid_mask, anchors, torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device))
/usr/local/lib/python3.11/dist-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py:191: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
/usr/local/lib/python3.11/dist-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py:212: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if reference_points.shape[-1] == 2:
/usr/local/lib/python3.11/dist-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py:218: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  elif reference_points.shape[-1] == 4:

The if checks should be able to be ignored, since most are error checking. The only possibly concerning ones are:

  1. grid_w = torch.arange(int(width), device=device).to(dtype) - tracing issues with int()
  2. grid_h = torch.arange(int(height), device=device).to(dtype) - same issue as 1
  3. spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device) - since the values use during export will be used in the final model (i.e., not dynamic).
  4. valid_wh = torch.tensor([width, height], device=device).to(dtype) - same issue as 3

),
"rt-detr-v2": supported_tasks_mapping(
"object-detection",
onnx="RTDetrOnnxConfig",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add another RTDetrV2OnnxConfig class in model_configs.py and use that here? (similar to how other configs do it).

class RTDetrV2OnnxConfig(RTDetrOnnxConfig):
    pass

Comment on lines +162 to +163
"rt-detr": "PekingU/rtdetr_r18vd",
"rt-detr-v2": "PekingU/rtdetr_v2_r18vd",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have any tiny-random models, we can add them here. Although, considering that these models are pretty small already, we can probably just use these ones.

Copy link
Collaborator

@echarlaix echarlaix left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for the addition @qubvel

Comment on lines 2667 to 2676
if kwargs["height"] < 320:
warnings.warn(
f"Exporting model with image `height={kwargs['height']}` which is less than minimal 320, setting `height` to 320."
)
kwargs["height"] = 320
if kwargs["width"] < 320:
warnings.warn(
f"Exporting model with image `width={kwargs['width']}` which is less than minimal 320, setting `width` to 320."
)
kwargs["width"] = 320
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is it needed and where does the value 320 comes from ?

Copy link
Member Author

@qubvel qubvel Mar 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default height and width are set to 64, but export does not work for those values, so I override it with the minimum value divisible by 32, which is supported for export (should be greater than 300)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation! In this case shouldn't we check for num_queries in the config directly as this could vary between models ? (can be set to any value by default in cases this value cannot be extracted from the self._config)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I've updated the implementation

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perfect thanks @qubvel !

@echarlaix
Copy link
Collaborator

echarlaix commented Mar 3, 2025

Failing tests unrelated, merging, thanks again @qubvel !

@echarlaix echarlaix merged commit b6c2b5c into huggingface:main Mar 3, 2025
33 of 37 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support for RT-DETR model export to onnx
5 participants