Skip to content

Commit 9a8da03

Browse files
committed
Fix numpy backend
1 parent 18a174b commit 9a8da03

File tree

3 files changed

+23
-1
lines changed

3 files changed

+23
-1
lines changed

examples/demo_functional.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from keras_core import losses
66
from keras_core import metrics
77
from keras_core import optimizers
8+
import keras_core as keras
9+
10+
keras.config.disable_traceback_filtering()
811

912
inputs = layers.Input((100,))
1013
x = layers.Dense(512, activation="relu")(inputs)

keras_core/backend/exports.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from keras_core import backend
2+
from keras_core.api_export import keras_core_export
3+
4+
if backend.backend() == "tensorflow":
5+
BackendVariable = backend.tensorflow.core.Variable
6+
elif backend.backend() == "jax":
7+
BackendVariable = backend.jax.core.Variable
8+
elif backend.backend() == "torch":
9+
BackendVariable = backend.torch.core.Variable
10+
elif backend.backend() == "numpy":
11+
from keras_core.backend.numpy.core import Variable as NumpyVariable
12+
13+
BackendVariable = NumpyVariable
14+
else:
15+
raise RuntimeError("Invalid backend.")
16+
17+
18+
@keras_core_export("keras_core.backend.Variable")
19+
class Variable(BackendVariable):
20+
pass

keras_core/legacy/saving/json_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,6 @@ def get_json_type(obj):
184184
# return obj.__wrapped__
185185

186186
if tf.available and isinstance(obj, tf.TypeSpec):
187-
188187
from tensorflow.python.framework import type_spec_registry
189188

190189
try:

0 commit comments

Comments
 (0)