Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,14 @@
"requantize.per_tensor_out(Tensor input, float in_scale, int in_zero_point, float out_scale, "
"int out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)"
)
# Schema for the out-variant of the custom RoI-align box pre-processor op.
# Takes the RoI boxes plus pooling configuration and writes the processed
# result into the caller-supplied `out` tensor (Tensor(a!) marks it mutated).
lib.define(
    "roi_align_box_processor.out(Tensor rois, int output_size_h, int output_size_w, "
    "int sampling_ratio, bool aligned, *, Tensor(a!) out) -> Tensor(a!)"
)
# Functional (allocating) variant of the same op; identical inputs, returns a
# freshly allocated tensor instead of writing into an out argument.
lib.define(
    "roi_align_box_processor(Tensor rois, int output_size_h, int output_size_w, "
    "int sampling_ratio, bool aligned) -> (Tensor out)"
)

# Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
aten_lib = Library("aten", "FRAGMENT")
Expand Down Expand Up @@ -1038,3 +1046,14 @@ def idma_store_impl(
channel: int = 0,
) -> torch.Tensor:
return copy_idma_copy_impl(src, task_num, channel)


@register_fake("cadence::roi_align_box_processor")
def roi_align_box_processor_meta(
    rois: torch.Tensor,
    output_size_h: int,
    output_size_w: int,
    sampling_ratio: int,
    aligned: bool,
) -> torch.Tensor:
    """Fake (meta) kernel for ``cadence::roi_align_box_processor``.

    Only describes the output metadata for shape propagation: a uint8
    tensor with one 80-byte row per input RoI, allocated on the same
    device as ``rois``. The pooling parameters do not affect the
    output shape here.
    """
    # One processed 80-byte record per RoI box.
    num_rois = rois.shape[0]
    return rois.new_empty((num_rois, 80), dtype=torch.uint8)
Loading