Skip to content

Commit 2b7af06

Browse files
samgizAlfie-Edwards
authored andcommitted
Add allowed values for While and ReadVariableOp parameter types.
Summary: This is necessary for keras fp8 layers to work properly. Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, alfiee Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, alfiee Maniphest Tasks: T65638 Differential Revision: https://phabricator.sourcevertex.net/D78710
1 parent d21d1bf commit 2b7af06

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

tensorflow/compiler/plugin/poplar/driver/xla_ipu_common.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,14 @@ bool OpFilter(KernelDef* kdef) {
7979
AddDtypeToKernelDefConstraint("U", DT_UINT8, kdef);
8080
}
8181

82+
if (kdef->op() == "While") {
83+
AddDtypeToKernelDefConstraint("T", DT_UINT8, kdef);
84+
}
85+
86+
if (kdef->op() == "ReadVariableOp") {
87+
AddDtypeToKernelDefConstraint("dtype", DT_UINT8, kdef);
88+
}
89+
8290
if (kdef->op() == "Const") {
8391
AddDtypeToKernelDefConstraint("dtype", DT_STRING, kdef);
8492
AddDtypeToKernelDefConstraint("dtype", DT_UINT8, kdef);

tensorflow/python/ipu/ops/f8_ops.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,11 @@ def __init__(self, data, metadata):
6262
of a call to `create_metadata`.
6363
"""
6464
self.data = self._maybe_get_tf_variable(data)
65-
self.metadata = metadata
65+
self.metadata = self._maybe_get_tf_variable(metadata)
6666

6767
def _maybe_get_tf_variable(self, data):
68-
if isinstance(data, np.ndarray):
69-
data = tf.Variable(data)
68+
if isinstance(data, np.ndarray) or isinstance(data, int):
69+
data = ops.convert_to_tensor(data, dtype=dtypes.uint8)
7070
if data.dtype != "uint8":
7171
raise TypeError(
7272
"Trying to set/update QuarterTensor data with a tensor of type "
@@ -77,7 +77,7 @@ def _maybe_get_tf_variable(self, data):
7777
def numpy(self):
7878
"""Returns a numpy representation of the tensor
7979
"""
80-
return [self.data.numpy(), self.metadata]
80+
return [self.data.numpy(), self.metadata.numpy()]
8181

8282
def assign(self, new_values, **kwargs):
8383
"""Assigns new values to the tensor
@@ -87,7 +87,10 @@ def assign(self, new_values, **kwargs):
8787
should be the output of `QuarterTensor.numpy`.
8888
"""
8989
self.data = self._maybe_get_tf_variable(new_values[0])
90-
self.metadata = new_values[1]
90+
self.metadata = self._maybe_get_tf_variable(new_values[1])
91+
92+
def __repr__(self):
93+
return f"QuarterTensor(data={self.data}, metadata={self.metadata})"
9194

9295
def __eq__(self, other):
9396
return self.numpy() == other.numpy()
@@ -127,8 +130,7 @@ def convert_to_f8(values, metadata, name=None):
127130
if not (values.dtype is dtypes.float16):
128131
values = cast(values, dtypes.float16)
129132

130-
if isinstance(metadata, np.ndarray) or isinstance(metadata, int):
131-
metadata = ops.convert_to_tensor(metadata, dtype=dtypes.uint8)
133+
metadata = ops.convert_to_tensor(metadata, dtype=dtypes.uint8)
132134
v, m = gen_popops_ops.ipu_convert_to_f8(values, metadata, name=name)
133135
return QuarterTensor(v, m)
134136

@@ -147,8 +149,8 @@ def convert_from_f8(packed_input, dtype=dtypes.half, name=None):
147149
Tensor with type dtype with unpacked f8 values.
148150
"""
149151
values, metadata = packed_input
150-
if isinstance(metadata, np.ndarray) or isinstance(metadata, int):
151-
metadata = ops.convert_to_tensor(metadata, dtype=dtypes.uint8)
152+
153+
metadata = ops.convert_to_tensor(metadata, dtype=dtypes.uint8)
152154
values = gen_popops_ops.ipu_convert_from_f8(values, metadata, name=name)
153155
if values.dtype != dtype:
154156
values = cast(values, dtype=dtype)

0 commit comments

Comments
 (0)