Skip to content

Commit fff4062

Browse files
authored
[ML] Restrict file system access for pytorch models (#2851)
This PR ensures that the PyTorch models are not allowed to access the file system. It accomplishes the goal by inspecting the model's operations and prohibiting the loading of models with operations that read or write files.
1 parent 6e92bb9 commit fff4062

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

bin/pytorch_inference/Main.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,28 @@
4141
#include <optional>
4242
#include <string>
4343

44+
namespace {
45+
// Add more forbidden ops here if needed
46+
const std::unordered_set<std::string_view> FORBIDDEN_OPERATIONS = {"aten::from_file", "aten::save"};
47+
48+
void verifySafeModel(const torch::jit::script::Module& module_) {
49+
try {
50+
const auto method = module_.get_method("forward");
51+
for (const auto graph = method.graph(); const auto& node : graph->nodes()) {
52+
if (const std::string opName = node->kind().toQualString();
53+
FORBIDDEN_OPERATIONS.contains(opName)) {
54+
HANDLE_FATAL(<< "Loading the inference process failed because it contains forbidden operation: "
55+
<< opName);
56+
}
57+
}
58+
} catch (const c10::Error& e) {
59+
LOG_FATAL(<< "Failed to get forward method: " << e.what());
60+
}
61+
62+
LOG_DEBUG(<< "Model verified: no forbidden operations detected.");
63+
}
64+
}
65+
4466
torch::Tensor infer(torch::jit::script::Module& module_,
4567
ml::torch::CCommandParser::SRequest& request) {
4668

@@ -281,6 +303,7 @@ int main(int argc, char** argv) {
281303
return EXIT_FAILURE;
282304
}
283305
module_ = torch::jit::load(std::move(readAdapter));
306+
verifySafeModel(module_);
284307
module_.eval();
285308

286309
LOG_DEBUG(<< "model loaded");

docs/CHANGELOG.asciidoc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@
5555
* Update the PyTorch library to version 2.5.1. (See {ml-pull}2783[#2798], {ml-pull}2799[#2799].)
5656
* Upgrade Boost libraries to version 1.86. (See {ml-pull}2780[#2780], {ml-pull}2779[#2779].)
5757

58+
== {es} version 8.17.7
59+
60+
=== Enhancements
61+
* Restrict file system access for PyTorch models (See {ml-pull}2851[#2851].)
62+
5863
== {es} version 8.16.6
5964

6065
=== Bug Fixes

0 commit comments

Comments
 (0)