Skip to content

Commit 1bc3571

Browse files
qihqifacebook-github-bot
authored andcommitted
[pytorch][PR] Add ability for a mobile::Module to save as flatbuffer (pytorch#70201)
Summary: Pull Request resolved: pytorch#70201 Included functions: save_mobile_module -> saves a mobile::Module to flatbuffer load_mobile_module_from_file -> loads a flatbuffer into mobile::Module parse_mobile_module -> parses from bytes or deserialized flatbuffer module object Compared to previous attempts, this diff only adds flatbuffer to cmake target and leaves fbcode/xplat ones unchanged. Test Plan: unittest Reviewed By: malfet, gmagogsfm Differential Revision: D33239362 fbshipit-source-id: b9ca36b83d6af2d78cc50b9eb9e2a6fa7fce0763
1 parent 7a93d8b commit 1bc3571

20 files changed

+5132
-2
lines changed

.github/workflows/lint.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ jobs:
4949
- name: Ensure canonical include
5050
if: always()
5151
run: |
52-
(! git --no-pager grep -In $'#include "' -- ./c10 ./aten ./torch/csrc ':(exclude)aten/src/ATen/native/quantized/cpu/qnnpack/**' || (echo "The above lines have include with quotes; please convert them to #include <xxxx>"; false))
52+
(! git --no-pager grep -In $'#include "' -- ./c10 ./aten ./torch/csrc ':(exclude)aten/src/ATen/native/quantized/cpu/qnnpack/**' ':(exclude)torch/csrc/jit/serialization/mobile_bytecode_generated.h'|| (echo "The above lines have include with quotes; please convert them to #include <xxxx>"; false))
5353
- name: Ensure no versionless Python shebangs
5454
if: always()
5555
run: |

.gitmodules

+3
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,6 @@
142142
[submodule "third_party/breakpad"]
143143
path = third_party/breakpad
144144
url = https://github.com/driazati/breakpad.git
145+
[submodule "third_party/flatbuffers"]
146+
path = third_party/flatbuffers
147+
url = https://github.com/google/flatbuffers.git

BUILD.bazel

+3
Original file line numberDiff line numberDiff line change
@@ -1692,6 +1692,7 @@ cc_library(
16921692
":aten_headers",
16931693
":caffe2_headers",
16941694
"//c10:headers",
1695+
"@com_github_google_flatbuffers//:flatbuffers",
16951696
"@local_config_python//:python_headers",
16961697
"@onnx",
16971698
],
@@ -1725,6 +1726,8 @@ cc_library(
17251726
],
17261727
)) + libtorch_core_sources + libtorch_distributed_sources + torch_cpp_srcs + libtorch_extra_sources + jit_core_sources + [
17271728
":cpp_generated_code",
1729+
"torch/csrc/jit/serialization/flatbuffer_serializer.cpp",
1730+
"torch/csrc/jit/mobile/flatbuffer_loader.cpp"
17281731
],
17291732
copts = TORCH_COPTS,
17301733
defines = [

WORKSPACE

+5
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,8 @@ new_local_repository(
197197
build_file = "@//third_party:cudnn.BUILD",
198198
path = "/usr/",
199199
)
200+
201+
local_repository(
202+
name = "com_github_google_flatbuffers",
203+
path = "third_party/flatbuffers",
204+
)

caffe2/CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
560560
${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
561561
${TORCH_SRC_DIR}/csrc/jit/mobile/model_compatibility.cpp
562562
${TORCH_SRC_DIR}/csrc/jit/mobile/module.cpp
563+
${TORCH_SRC_DIR}/csrc/jit/mobile/flatbuffer_loader.cpp
563564
${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
564565
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_bytecode.cpp
565566
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_operators.cpp
@@ -595,6 +596,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
595596
${TORCH_SRC_DIR}/csrc/jit/serialization/export.cpp
596597
${TORCH_SRC_DIR}/csrc/jit/serialization/export_bytecode.cpp
597598
${TORCH_SRC_DIR}/csrc/jit/serialization/export_module.cpp
599+
${TORCH_SRC_DIR}/csrc/jit/serialization/flatbuffer_serializer.cpp
598600
${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp
599601
${TORCH_SRC_DIR}/csrc/jit/api/module_save.cpp
600602
${TORCH_SRC_DIR}/csrc/utils/byte_order.cpp
@@ -1645,6 +1647,9 @@ if(APPLE AND USE_PYTORCH_METAL)
16451647
endif()
16461648
endif()
16471649

1650+
1651+
target_link_libraries(torch_cpu PRIVATE flatbuffers)
1652+
16481653
# Note [Global dependencies]
16491654
# Some libraries (e.g. OpenMPI) like to dlopen plugins after they're initialized,
16501655
# and they assume that all of their symbols will be available in the global namespace.

cmake/Dependencies.cmake

+3
Original file line numberDiff line numberDiff line change
@@ -1996,3 +1996,6 @@ if(USE_KINETO)
19961996
message(STATUS "Configured Kineto")
19971997
endif()
19981998
endif()
1999+
2000+
# Include google/FlatBuffers
2001+
include(${CMAKE_CURRENT_LIST_DIR}/FlatBuffers.cmake)

cmake/FlatBuffers.cmake

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
set(FlatBuffers_Include ${PROJECT_SOURCE_DIR}/third_party/flatbuffers/include)
2+
file(GLOB FlatBuffers_Library_SRCS
3+
${FlatBuffers_Include}/flatbuffers/*.h
4+
)
5+
add_library(flatbuffers INTERFACE)
6+
target_sources(
7+
flatbuffers
8+
INTERFACE ${FlatBuffers_Library_SRCS}
9+
)
10+
target_include_directories(flatbuffers INTERFACE ${FlatBuffers_Include})

test/cpp/jit/CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ set(JIT_TEST_SRCS
8989
${JIT_TEST_ROOT}/test_script_profile.cpp
9090
${JIT_TEST_ROOT}/test_shape_analysis.cpp
9191
${JIT_TEST_ROOT}/test_jit_logging_levels.cpp
92+
${JIT_TEST_ROOT}/test_flatbuffer.cpp
9293
)
9394

9495
if(USE_CUDA)
@@ -101,6 +102,10 @@ add_executable(test_jit
101102
${JIT_TEST_SRCS}
102103
)
103104

105+
target_link_libraries(
106+
test_jit PRIVATE flatbuffers)
107+
108+
104109
# TODO temporary until we can delete the old gtest polyfills.
105110
target_compile_definitions(test_jit PRIVATE USE_GTEST)
106111

0 commit comments

Comments
 (0)