|
4 | 4 |
|
5 | 5 | from keras_core.api_export import keras_core_export |
6 | 6 | from keras_core.backend import KerasTensor |
| 7 | +from keras_core.backend.config import backend |
7 | 8 | from keras_core.ops.operation import Operation |
8 | 9 | from keras_core.utils.nest import pack_sequence_as |
9 | 10 |
|
@@ -46,10 +47,21 @@ class Function(Operation): |
46 | 47 | def __init__(self, inputs, outputs, name=None): |
47 | 48 | super().__init__(name=name) |
48 | 49 |
|
| 50 | + if backend() == "tensorflow": |
| 51 | + # Temporary work around for |
| 52 | + # https://github.com/keras-team/keras-core/issues/931 |
| 53 | + # This stop tensorflow from wrapping tf.function output in a |
| 54 | + # _DictWrapper object. |
| 55 | + _self_setattr_tracking = getattr( |
| 56 | + self, "_self_setattr_tracking", True |
| 57 | + ) |
| 58 | + self._self_setattr_tracking = False |
49 | 59 | self._inputs_struct = tree.map_structure(lambda x: x, inputs) |
50 | 60 | self._outputs_struct = tree.map_structure(lambda x: x, outputs) |
51 | 61 | self._inputs = tree.flatten(inputs) |
52 | 62 | self._outputs = tree.flatten(outputs) |
| 63 | + if backend() == "tensorflow": |
| 64 | + self._self_setattr_tracking = _self_setattr_tracking |
53 | 65 |
|
54 | 66 | (nodes, nodes_by_depth, operations, operations_by_depth) = map_graph( |
55 | 67 | self._inputs, self._outputs |
|
0 commit comments