Skip to content

Commit aa26498

Browse files
NicolasHugfmassa
andauthored
Add ops-cpp target to torchvision (#3350)
Summary: This diff adds a new target to torchvision which enables users to use torchvision ops from C++. For now, the `cpp_library` is not used by the `python_cpp_library`. We should instead refactor the logic in torchvision to directly use `cpp_library` instead. There is currently an inconsistency between fbcode and OSS users. OSS users can import torchvision via ``` #include <torchvision/vision.h> ``` while fbcode users need to do ``` #include <torchvision/csrc/vision.h> ``` It would be good to fix this discrepancy in the future. I didn't directly use `test_frcnn_tracing.cpp` due to complications for getting the `.pt` file in a way that works for both OSS and fbcode, so instead we added a self-contained test that should validate that the torchvision ops are properly registered and visible to JIT Reviewed By: datumbox Differential Revision: D26225669 fbshipit-source-id: 5dd9fb98dd58e854f95806e4860d02f54fc04ea4 Co-authored-by: Francisco Massa <[email protected]>
1 parent 88d18d6 commit aa26498

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

test/cpp/test_custom_operators.cpp

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright 2004-present Facebook. All Rights Reserved.
2+
3+
#include <gtest/gtest.h>
4+
#include <torch/script.h>
5+
#include <torch/torch.h>
6+
7+
// FIXME: the include path differs from OSS due to the extra csrc
8+
#include <torchvision/csrc/ops/nms.h>
9+
10+
TEST(test_custom_operators, nms) {
11+
// make sure that the torchvision ops are visible to the jit interpreter
12+
auto& ops = torch::jit::getAllOperatorsFor(torch::jit::Symbol::fromQualString("torchvision::nms"));
13+
ASSERT_EQ(ops.size(), 1);
14+
15+
auto& op = ops.front();
16+
ASSERT_EQ(op->schema().name(), "torchvision::nms");
17+
18+
torch::jit::Stack stack;
19+
at::Tensor boxes = at::rand({50, 4}), scores = at::rand({50});
20+
double thresh = 0.7;
21+
22+
torch::jit::push(stack, boxes, scores, thresh);
23+
op->getOperation()(&stack);
24+
at::Tensor output_jit;
25+
torch::jit::pop(stack, output_jit);
26+
27+
at::Tensor output = vision::ops::nms(boxes, scores, thresh);
28+
ASSERT_TRUE(output_jit.allclose(output));
29+
30+
}

0 commit comments

Comments
 (0)