
Commit 2d64dbe

Enable model to be a remote argument, bypassing uploading it to DCP's server
1 parent 0f011a4 commit 2d64dbe

8 files changed: +849 -18 lines changed

README.md

+2 -1

@@ -13,14 +13,15 @@ Create a `model.json` file and add your model's information to it, for example
 {
   "name": "mnist-example",
   "version": "0.0.1",
+  "modelDownload": false,
   "model": "MNIST.onnx",
   "preprocess": "preprocess.py",
   "postprocess": "postprocess.py",
   "packages": ["numpy", "opencv-python"]
 }
 ```

-Where `model`, `preprocess`, and `postprocess` are all the paths to the model, preprocess, and postprocess files respectively.
+Where `model`, `preprocess`, and `postprocess` are all the paths to the model, preprocess, and postprocess files respectively. If your `model.json` has `modelDownload` set to a URL instead of `false`, you do not need to upload your model; instead, you must run a web server hosting the model that workers can fetch from. See `example-server.js` for an example server. DistributiveWorkers intended to work on these slices must have the URL added to their `allowOrigins` object.

 Packages are all of the python libraries your pre and post processing scripts require. The list of supported packages is:
 ```
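The worker-side requirement is only described in prose here. As a loosely hedged illustration — the exact shape and key names of the worker's `allowOrigins` configuration are assumptions on our part, not something this commit defines — allowing the hosting origin could look roughly like:

```js
// Purely illustrative fragment: the README above only requires that the hosting
// URL be added to the worker's allowOrigins object. The 'any' key and the overall
// shape of this object are assumptions and may differ between worker versions.
const allowOrigins = {
  any: ['http://localhost:3500'],
};
```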

example/example-server.js

+55

@@ -0,0 +1,55 @@
+/**
+ * @file example-server.js
+ * Example web server for hosting model to be fetched by workers
+ *
+ * @author Ryan Saweczko, [email protected]
+ */
+
+const express = require('express');
+const path = require('path');
+const fs = require('fs');
+const cors = require('cors');
+
+var modelInfo;
+try
+{
+  modelInfo = JSON.parse(fs.readFileSync('model.json', { encoding: 'utf-8' }));
+}
+catch (error)
+{
+  console.error('Unable to find model.json file in directory, exiting');
+  process.exit(1);
+}
+
+const app = express();
+const port = 3500;
+app.use(cors());
+
+const fileDirectory = path.join(__dirname);
+
+app.get('/:filename', (req, res) => {
+  const { filename } = req.params;
+  const file = path.join(fileDirectory, filename);
+  if (filename === modelInfo.model)
+  {
+    const content = fs.readFileSync(file).toString('base64');
+    res.send(content);
+  }
+  else if ([modelInfo.preprocess, modelInfo.postprocess].includes(filename))
+  {
+    res.sendFile(file, (err) => {
+      if (err)
+        res.status(500).send('Error sending the file.');
+    });
+  }
+  else
+  {
+    console.error(`request for invalid file: ${filename}`);
+    res.status(404).send('Invalid file');
+  }
+});
+
+// Start server
+app.listen(port, () => {
+  console.log(`Server is running at http://localhost:${port}`);
+});
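The server above base64-encodes the ONNX file (binary) but streams the Python scripts as plain text, so a consumer has to decode the model before use. A minimal sketch of the fetch-and-decode step, assuming the defaults from `example-server.js` (port 3500, base64 body for the model) and a runtime that provides `fetch` and `atob` (browsers, workers, or recent Node):

```js
// Illustrative only: fetch the base64-encoded model served by example-server.js
// and decode it back into raw bytes. The function name is ours, not the repo's.
async function fetchRemoteModel(modelURL)
{
  const response = await fetch(modelURL);                        // e.g. http://localhost:3500/MNIST.onnx
  if (!response.ok)
    throw new Error(`failed to fetch model: ${response.status}`);
  const base64 = await response.text();                          // the model file is served as base64 text
  return Uint8Array.from(atob(base64), (c) => c.charCodeAt(0));  // back to raw bytes for the ONNX runtime
}
```

The pre- and post-process scripts can be fetched the same way but need no decoding.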

example/model-remote.json

+9

@@ -0,0 +1,9 @@
+{
+  "name": "mnist-example",
+  "version": "0.0.1",
+  "modelDownload": "http://localhost:3500",
+  "model": "MNIST.onnx",
+  "preprocess": "preprocess.py",
+  "postprocess": "postprocess.py",
+  "packages": ["numpy", "opencv-python"]
+}
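With this configuration, `deploy()` in `inference.js` (see the diff below) resolves each filename against the `modelDownload` origin using the WHATWG `URL` constructor, so the files above are expected at paths directly under the server root:

```js
// How the URL arguments built in deploy() resolve for this model-remote.json:
const url = new URL('http://localhost:3500');
console.log(new URL('/MNIST.onnx', url).href);      // http://localhost:3500/MNIST.onnx
console.log(new URL('/preprocess.py', url).href);   // http://localhost:3500/preprocess.py
console.log(new URL('/postprocess.py', url).href);  // http://localhost:3500/postprocess.py
```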

example/model.json

+1

@@ -1,6 +1,7 @@
 {
   "name": "mnist-example",
   "version": "0.0.1",
+  "modelDownload": false,
   "model": "MNIST.onnx",
   "preprocess": "preprocess.py",
   "postprocess": "postprocess.py",

inference.js

+16 -6

@@ -13,18 +13,28 @@ const { workFunction } = require('./workFunction');
 /**
  * Deploy work function for model inferencing via DCP
  */
-async function deploy(inputSet, modelName, computeGroup, output, webgpu)
+async function deploy(inputSet, model, computeGroup, output, webgpu)
 {
   const compute = require('dcp/compute');

-  const labels = { modelName, projectID: Date.now(), debug: false, webgpu };
-  let job = compute.for(inputSet, workFunction, [labels]);
+  const labels = { modelName: model.name, projectID: Date.now(), debug: false, webgpu };
+  const args = [labels];
+  if (model.modelDownload)
+  {
+    const url = new URL(model.modelDownload);
+    args.push(new URL(`/${model.preprocess}`, url));
+    args.push(new URL(`/${model.postprocess}`, url));
+    args.push(model.packages);
+    args.push(new URL(`/${model.model}`, url));
+  }
+  let job = compute.for(inputSet, workFunction, args);

-  job.public.name = `DCP Inferencing: ${modelName}`;
+  job.public.name = `DCP Inferencing: ${model.name}`;
   job.requires('onnxruntime-dcp/dcp-wasm.js');
   job.requires('onnxruntime-dcp/dcp-ort.js');
   job.requires('pyodide-core/pyodide-core.js');
-  job.requires(`${modelName}/module.js`);
+  if (!model.modelDownload)
+    job.requires(`${model.name}/module.js`);

   if (computeGroup)
     job.computeGroups = [computeGroup];
@@ -151,7 +161,7 @@ if (require.main === module)
   }

   require('dcp-client').init().then(() => {
-    deploy(inputSet, modelInfo.name, computeGroup, outputFile, webgpu);
+    deploy(inputSet, modelInfo, computeGroup, outputFile, webgpu);
   });
 }
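The work function itself lives in `./workFunction` and is not part of this diff, so how it consumes the extra arguments is not shown. As a hypothetical sketch only — the signature and behaviour below are assumptions based on the order of the `args.push` calls above, with the slice datum arriving first:

```js
// Hypothetical: mirrors the argument order pushed in deploy() when
// model.modelDownload is set. The real workFunction may differ.
async function workFunction(datum, labels, preprocessURL, postprocessURL, packages, modelURL)
{
  if (modelURL)
  {
    // Remote mode: fetch the base64-encoded model from the hosting server
    // (see the decoding sketch after example-server.js above).
    const base64 = await (await fetch(modelURL)).text();
    const modelBytes = Uint8Array.from(atob(base64), (c) => c.charCodeAt(0));
    // ...load modelBytes into the ONNX runtime, then run preprocess/inference/postprocess
  }
  else
  {
    // Local mode: the uploaded model package pulled in via job.requires()
    // supplies the model, so there is nothing to fetch here.
  }
}
```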
