Skip to content

Commit fba298a

Browse files
committed
init
0 parents  commit fba298a

File tree

7 files changed

+194
-0
lines changed

7 files changed

+194
-0
lines changed

__init__.py

Whitespace-only changes.

onnxtf/__init__.py

Whitespace-only changes.

onnxtf/convert.py

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""Backend for running ONNX on Tensorflow
2+
3+
To run this, you will need to have Tensorflow installed as well.
4+
"""
5+
from __future__ import absolute_import
6+
from __future__ import division
7+
from __future__ import print_function
8+
from __future__ import unicode_literals
9+
10+
import collections
11+
12+
import numpy as np
13+
from onnx import onnx_pb2, checker
14+
from onnx.onnx_pb2 import GraphProto, TensorProto, AttributeProto
15+
import onnx.numpy_helper
16+
import onnx.defs
17+
from onnx.backend.base import Backend, BackendRep, Device, DeviceType, namedtupledict
18+
19+
import numpy as np
20+
from onnx import onnx_pb2, helper
21+
import tensorflow as tf
22+
23+
24+
def get_device_option(device):
25+
m = {DeviceType.CPU: '/cpu:0',
26+
DeviceType.CUDA: '/gpu:0'}
27+
return m[device.type]
28+
29+
def get_type(type):
30+
t = {
31+
""
32+
}
33+
34+
class TensorflowBackend(Backend):
35+
@classmethod
36+
def run_node(cls, node, inputs, device='CPU'):
37+
super(TensorflowBackend, cls).run_node(node, inputs, device)
38+
39+
device_option = get_device_option(Device(device))
40+
input_tensors = []
41+
for i in inputs:
42+
input_tensors.append(tf.constant(i))
43+
44+
if isinstance(inputs, dict):
45+
feed_dict_raw = inputs
46+
else:
47+
assert(len(node.input) == len(inputs))
48+
feed_dict_raw = dict(zip(node.input, inputs))
49+
50+
input_dict = dict(map(lambda x: (x[0], tf.constant(x[1])), feed_dict_raw.items()))
51+
print(input_dict)
52+
outputs = cls._onnx_node_to_tensorflow_op(node, input_dict)
53+
output_dict = {}
54+
with tf.Session() as sess:
55+
with tf.device(device_option):
56+
for key, val in outputs.items():
57+
output_dict[key] = sess.run(val)
58+
return output_dict
59+
60+
@classmethod
61+
def _onnx_node_to_tensorflow_op(cls, node, input_dict):
62+
def _merge_two_dicts(x, y):
63+
z = x.copy()
64+
z.update(y)
65+
return z
66+
output_dict = dict()
67+
method_to_call = getattr(cls, "handle_" + node.op_type.lower())
68+
output_dict = _merge_two_dicts(output_dict, method_to_call(node, input_dict))
69+
return output_dict
70+
71+
@classmethod
72+
def handle_relu(cls, node, input_dict):
73+
output_name = node.output[0]
74+
input_name = node.input[0]
75+
return dict([(output_name, tf.nn.relu(input_dict[input_name]))])
76+
77+
run_node = TensorflowBackend.run_node

scaffold/extract_tf_ops.py

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from __future__ import print_function
2+
import re
3+
import pprint
4+
import urllib2
5+
6+
def get_tf_defs():
7+
tf_op_defs = {}
8+
9+
def clean(name, input):
10+
return input.replace(name + "(\"", "").replace("\")", "").split(":")[0].strip();
11+
12+
url = "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/core/ops/{}.cc"
13+
fnames = ["nn_ops", "math_ops", "array_ops"]
14+
15+
for fname in fnames:
16+
17+
content = urllib2.urlopen(url.format(fname)).read()
18+
content = re.sub("\([\s\n]+\"", "(\"", content)
19+
content = re.sub("\"[\s\n]+\"", "", content)
20+
content = content.split("\n")
21+
content = [x.strip() for x in content]
22+
23+
in_op_def = False
24+
curr_op = {}
25+
for line in content:
26+
if (in_op_def):
27+
if line.startswith(".Input"):
28+
name = clean(".Input", line)
29+
# print("\ti " + name)
30+
curr_op["i"].append(name)
31+
elif line.startswith(".Output"):
32+
name = clean(".Output", line)
33+
# print("\to " + name)
34+
curr_op["o"].append(name)
35+
elif line.startswith(".Attr"):
36+
name = clean(".Attr", line)
37+
# print("\ta " + name)
38+
curr_op["a"].append(name)
39+
else:
40+
in_op_def = False
41+
tf_op_defs[curr_op["n"]] = curr_op
42+
else:
43+
if (line.startswith("REGISTER_OP")):
44+
in_op_def = True
45+
name = clean("REGISTER_OP", line)
46+
curr_op = {
47+
"n": name,
48+
"i": [],
49+
"o": [],
50+
"a": []
51+
}
52+
# print(name)
53+
return tf_op_defs
54+
55+
tf_op_defs = get_tf_defs()
56+
pp = pprint.PrettyPrinter(indent=2)
57+
# pp.pprint(tf_op_defs)
58+
59+
from onnx import onnx_pb2, helper, defs
60+
node_def = helper.make_node("Relu", ["X"], ["Y"])
61+
print(node_def)
62+
63+
all_schemas = defs.get_all_schemas()
64+
# print(all_schemas)
65+
# for name in all_schemas:
66+
# # print(dir(all_schemas[name]))
67+
# print("i", all_schemas[name].input_desc)
68+
# print("o", all_schemas[name].output_desc)
69+
# print("a", all_schemas[name].attributes)
70+
all_schemas = dict(filter(lambda x: len(x[1].input_desc)==1 ,all_schemas.items()))
71+
all_schemas = dict(filter(lambda x: len(x[1].attributes)==0 or len(x[1].attributes)==1,all_schemas.items()))
72+
73+
intersection = (set(tf_op_defs.keys()).intersection(set(all_schemas.keys())))
74+
75+
for key in intersection:
76+
print(key)
77+
pp.pprint(tf_op_defs[key])
78+
print("i", all_schemas[key].input_desc)
79+
print("o", all_schemas[key].output_desc)
80+
print("a", all_schemas[key].attributes)
81+
82+
83+
# print(intersection)
84+
85+
86+

setup.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from setuptools import setup
2+
3+
setup(name='onnxtf',
4+
version='0.1',
5+
description='Tensorflow backend for ONNX',
6+
url='TBD',
7+
author='IBM',
8+
author_email='TBD',
9+
license='TBD',
10+
packages=['onnxtf'],
11+
zip_safe=False)

test/__init__.py

Whitespace-only changes.

test/test_op.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
from __future__ import unicode_literals
5+
6+
import unittest
7+
import numpy as np
8+
from onnxtf.convert import run_node
9+
from onnx import helper
10+
11+
class TestStringMethods(unittest.TestCase):
12+
13+
def test_relu(self):
14+
node_def = helper.make_node("Relu", ["X"], ["Y"])
15+
input = np.random.uniform(-1,1,1000)
16+
output = run_node(node_def, [input])
17+
np.testing.assert_almost_equal(output["Y"], np.clip(input, 0, 1))
18+
19+
if __name__ == '__main__':
20+
unittest.main()

0 commit comments

Comments
 (0)