Commit 44c85da

keras 3 Adam optimizer (#2712)
1 parent d8acfa8
File tree: 4 files changed, +276 −1 lines

itex/python/experimental_ops_override_k3.py

Lines changed: 3 additions & 0 deletions

@@ -30,6 +30,7 @@
 from intel_extension_for_tensorflow.python.ops.layer_norm_k3 import _layer_norm
 from intel_extension_for_tensorflow.python.ops.group_norm_k3 import GroupNormalization
+from intel_extension_for_tensorflow.python.ops.optimizers_k3 import Adam

 format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
 logging.basicConfig(level=logging.INFO, format=format_str)

@@ -371,6 +372,7 @@ def itex_var(x, axis=None, keepdims=False):
     keras.layers.BatchNormalization.build = itex_batch_norm_build
     keras.layers.GroupNormalization.call = GroupNormalization.call
     keras.layers.GroupNormalization.build = GroupNormalization.build
+    keras.optimizers.Adam.update_step = Adam.update_step

 except BaseException:  # pylint: disable=broad-except
     logger.error("Cannot override itex ops.")

@@ -384,6 +386,7 @@ def itex_var(x, axis=None, keepdims=False):
     keras.src.layers.normalization.group_normalization.GroupNormalization.build = GroupNormalization.build
     keras.src.backend.numpy.mean = itex_mean
     keras.src.backend.numpy.var = itex_var
+    keras.src.optimizers.adam.Adam.update_step = Adam.update_step
     logger.info("itex experimental ops override is enabled.")
 except BaseException:  # pylint: disable=broad-except
     logger.warning(
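
Once the override is registered, plain tf.keras code picks up the patched update_step transparently. A minimal usage sketch (the model, layer size, and hyperparameters below are illustrative, not part of this commit):

import intel_extension_for_tensorflow as itex
import tensorflow as tf

itex.experimental_ops_override()  # patches keras.optimizers.Adam.update_step

model = tf.keras.Sequential([tf.keras.layers.Dense(8)])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss="mse")
# model.fit(...) now routes Adam's update_step through the ITEX implementation,
# which uses the fused XPU kernel when an XPU device is present and the
# pure Keras-ops path otherwise.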

itex/python/ops/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -19,7 +19,6 @@
 from intel_extension_for_tensorflow.python.ops.activations import gelu
 from intel_extension_for_tensorflow.python.ops.rotary_embedding import qk_rotary_positional_embedding
 from intel_extension_for_tensorflow.python.ops import ops_grad as _ops_grad
-from intel_extension_for_tensorflow.python.ops.optimizers import AdamWithWeightDecayOptimizer, AdamWithWeightDecayLegacyOptimizer, LAMBOptimizer

 from intel_extension_for_tensorflow.python.ops.multi_head_attention import scaled_dot_product_attention

@@ -29,8 +28,10 @@
     from intel_extension_for_tensorflow.python.ops.mlp import FusedDenseBiasAddGelu
     from intel_extension_for_tensorflow.python.ops.rms_norm import RMSNormalization
     from intel_extension_for_tensorflow.python.ops.recurrent import ItexLSTM
+    from intel_extension_for_tensorflow.python.ops.optimizers import AdamWithWeightDecayOptimizer, AdamWithWeightDecayLegacyOptimizer, LAMBOptimizer
 else:
     from intel_extension_for_tensorflow.python.ops.layer_norm_k3 import LayerNormalization
     from intel_extension_for_tensorflow.python.ops.group_norm_k3 import GroupNormalization
     from intel_extension_for_tensorflow.python.ops.mlp_k3 import Dense as FusedDenseBiasAddGelu
+    from intel_extension_for_tensorflow.python.ops.optimizers_k3 import Adam
     from intel_extension_for_tensorflow.python.ops.rms_norm_k3 import RMSNormalization
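
With the Keras 3 branch now importing optimizers_k3.Adam, the optimizer can also be constructed directly rather than through the monkey-patch. A small sketch, assuming a Keras 3 environment with ITEX installed (hyperparameter values are illustrative):

from intel_extension_for_tensorflow.python.ops.optimizers_k3 import Adam

# The class is registered via @register_keras_serializable(package="Itex"),
# so models saved with it should deserialize back to this implementation.
opt = Adam(learning_rate=1e-3, beta_1=0.9, beta_2=0.999, amsgrad=False)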

itex/python/ops/optimizers_k3.py (new file)

Lines changed: 206 additions & 0 deletions

# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adam for TensorFlow."""


from intel_extension_for_tensorflow.python.ops.load_ops_library import load_ops_library
from keras import ops
from keras.src.optimizers import optimizer
from keras.src.saving import object_registration
from keras.src.backend.common import KerasVariable

import tensorflow as tf


@object_registration.register_keras_serializable(package="Itex")
class Adam(optimizer.Optimizer):
    """Optimizer that implements the Adam algorithm.

    Adam optimization is a stochastic gradient descent method that is based on
    adaptive estimation of first-order and second-order moments.

    According to
    [Kingma et al., 2014](http://arxiv.org/abs/1412.6980),
    the method is "*computationally
    efficient, has little memory requirement, invariant to diagonal rescaling of
    gradients, and is well suited for problems that are large in terms of
    data/parameters*".

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        beta_1: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 1st moment estimates. Defaults to
            `0.9`.
        beta_2: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 2nd moment estimates. Defaults to
            `0.999`.
        epsilon: A small constant for numerical stability. This epsilon is
            "epsilon hat" in the Kingma and Ba paper (in the formula just before
            Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults
            to `1e-7`.
        amsgrad: Boolean. Whether to apply the AMSGrad variant of this algorithm
            from the paper "On the Convergence of Adam and Beyond". Defaults
            to `False`.
        {{base_optimizer_keyword_args}}
    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="adam",
        **kwargs,
    ):
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.amsgrad = amsgrad

    def build(self, var_list):
        """Initialize optimizer variables.

        Adam optimizer has 3 types of variables: momentums, velocities and
        velocity_hat (only set when amsgrad is applied).

        Args:
            var_list: list of model variables to build Adam variables on.
        """
        if self.built:
            return
        super().build(var_list)
        self._momentums = []
        self._velocities = []
        for var in var_list:
            self._momentums.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="momentum"
                )
            )
            self._velocities.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="velocity"
                )
            )
        if self.amsgrad:
            self._velocity_hats = []
            for var in var_list:
                self._velocity_hats.append(
                    self.add_variable_from_reference(
                        reference_variable=var, name="velocity_hat"
                    )
                )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)
        local_step = ops.cast(self.iterations + 1, variable.dtype)
        beta_1_power = ops.power(
            ops.cast(self.beta_1, variable.dtype), local_step
        )
        beta_2_power = ops.power(
            ops.cast(self.beta_2, variable.dtype), local_step
        )

        m = self._momentums[self._get_variable_index(variable)]
        v = self._velocities[self._get_variable_index(variable)]

        # Fast path: when an XPU device is visible and the Keras variables are
        # backed by tf.Variable resources, dispatch to the fused ITEX kernel.
        if len(tf.config.list_physical_devices("XPU")) > 0 and isinstance(m, KerasVariable) and isinstance(v, KerasVariable) and isinstance(variable, tf.Variable):
            if isinstance(m.value, tf.Variable) and isinstance(v.value, tf.Variable) and isinstance(gradient, tf.Tensor):
                if self.amsgrad:
                    v_hat = self._velocity_hats[self._get_variable_index(variable)]
                else:
                    v_hat = v  # just a placeholder
                return load_ops_library.itex_resource_apply_adam_with_weight_decay(
                    variable.handle,
                    m.value.handle,
                    v.value.handle,
                    beta_1_power,
                    beta_2_power,
                    lr,
                    ops.cast(self.beta_1, variable.dtype),
                    ops.cast(self.beta_2, variable.dtype),
                    ops.cast(self.epsilon, variable.dtype),
                    ops.cast(0.0, variable.dtype),
                    v_hat.value.handle,
                    gradient,
                    use_locking=False,
                    use_amsgrad=self.amsgrad)

        # Fallback: the standard Adam update expressed in pure Keras ops.
        alpha = lr * ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)

        self.assign_add(
            m, ops.multiply(ops.subtract(gradient, m), 1 - self.beta_1)
        )
        self.assign_add(
            v,
            ops.multiply(
                ops.subtract(ops.square(gradient), v), 1 - self.beta_2
            ),
        )
        if self.amsgrad:
            v_hat = self._velocity_hats[self._get_variable_index(variable)]
            self.assign(v_hat, ops.maximum(v_hat, v))
            v = v_hat
        self.assign_sub(
            variable,
            ops.divide(
                ops.multiply(m, alpha), ops.add(ops.sqrt(v), self.epsilon)
            ),
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
                "epsilon": self.epsilon,
                "amsgrad": self.amsgrad,
            }
        )
        return config
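
For reference, the fused itex_resource_apply_adam_with_weight_decay call and the pure-Keras fallback in update_step are both meant to implement the same standard Adam step. A minimal NumPy sketch of one such step, written to mirror the fallback branch (no AMSGrad, no weight decay); it is an illustration, not part of the commit:

import numpy as np

def adam_step(var, grad, m, v, t, lr=0.001, beta_1=0.9, beta_2=0.999, eps=1e-7):
    """One Adam update for step t >= 1, mirroring update_step's fallback branch."""
    # Bias-corrected step size, matching `alpha` above.
    alpha = lr * np.sqrt(1.0 - beta_2 ** t) / (1.0 - beta_1 ** t)
    m = m + (grad - m) * (1.0 - beta_1)         # first-moment (momentum) estimate
    v = v + (grad ** 2 - v) * (1.0 - beta_2)    # second-moment (velocity) estimate
    var = var - m * alpha / (np.sqrt(v) + eps)  # parameter update
    return var, m, v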
Fourth changed file (path not shown): unit test for the overridden Adam

Lines changed: 65 additions & 0 deletions

# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
os.environ['TF_USE_LEGACY_KERAS'] = '0'
os.environ["ITEX_DISABLE_XLA"] = "1"

import intel_extension_for_tensorflow as itex
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import constant_op

from intel_extension_for_tensorflow.python.test_func import test_util
from intel_extension_for_tensorflow.python.test_func import test


class AdamTest(test_util.TensorFlowTestCase):
    """Test the AdamW (Adam with weight decay) op."""

    def testAdam(self):
        # Initialize random numpy values and create TensorFlow variables from them.
        size = [2, 4, 3]
        dtype = dtypes.float32
        var0 = np.random.normal(size=size)
        var1 = np.random.normal(size=size)
        grads0 = np.random.normal(size=size)
        grads1 = np.random.normal(size=size)
        # Stock tf.keras Adam
        tf_var0 = tf.Variable(var0, dtype=dtype)
        tf_var1 = tf.Variable(var1, dtype=dtype)
        tf_grads0 = constant_op.constant(grads0, dtype=dtype)
        tf_grads1 = constant_op.constant(grads1, dtype=dtype)
        tf_adam = tf.keras.optimizers.Adam(weight_decay=0.04, learning_rate=0.01)
        for _ in range(3):  # Run 3 steps of the optimizer
            tf_adam.apply_gradients(
                zip([tf_grads0, tf_grads1], [tf_var0, tf_var1])
            )
        # ITEX-overridden Adam
        itex.experimental_ops_override()
        itex_var0 = tf.Variable(var0, dtype=dtype)
        itex_var1 = tf.Variable(var1, dtype=dtype)
        itex_grads0 = constant_op.constant(grads0, dtype=dtype)
        itex_grads1 = constant_op.constant(grads1, dtype=dtype)
        itex_adam = tf.keras.optimizers.Adam(weight_decay=0.04, learning_rate=0.01)
        for _ in range(3):  # Run 3 steps of the optimizer
            itex_adam.apply_gradients(
                zip([itex_grads0, itex_grads1], [itex_var0, itex_var1])
            )
        # Validate that both optimizers produce the same updated parameters
        self.assertAllClose(tf_var0, itex_var0)
        self.assertAllClose(tf_var1, itex_var1)


if __name__ == "__main__":
    test.main()
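
The fused kernel inside update_step is only taken when an XPU device is visible to TensorFlow; a small check to see which path a given run will exercise (illustrative, not part of the commit):

import tensorflow as tf

if len(tf.config.list_physical_devices("XPU")) > 0:
    print("XPU device found: the override can dispatch to the fused ITEX Adam kernel.")
else:
    print("No XPU device: the override falls back to the pure Keras-ops Adam update.")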
