# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adam for TensorFlow."""


from intel_extension_for_tensorflow.python.ops.load_ops_library import load_ops_library
from keras import ops
from keras.src.optimizers import optimizer
from keras.src.saving import object_registration
from keras.src.backend.common import KerasVariable

import tensorflow as tf


@object_registration.register_keras_serializable(package="Itex")
class Adam(optimizer.Optimizer):
    """Optimizer that implements the Adam algorithm.

    Adam optimization is a stochastic gradient descent method that is based on
    adaptive estimation of first-order and second-order moments.

    According to
    [Kingma et al., 2014](http://arxiv.org/abs/1412.6980),
    the method is "*computationally
    efficient, has little memory requirement, invariant to diagonal rescaling of
    gradients, and is well suited for problems that are large in terms of
    data/parameters*".

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        beta_1: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 1st moment estimates. Defaults to
            `0.9`.
        beta_2: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 2nd moment estimates. Defaults to
            `0.999`.
        epsilon: A small constant for numerical stability. This epsilon is
            "epsilon hat" in the Kingma and Ba paper (in the formula just before
            Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults
            to `1e-7`.
        amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm
            from the paper "On the Convergence of Adam and beyond". Defaults
            to `False`.
        {{base_optimizer_keyword_args}}
    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="adam",
        **kwargs,
    ):
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.amsgrad = amsgrad

    def build(self, var_list):
        """Initialize optimizer variables.

        The Adam optimizer has 3 types of variables: momentums, velocities,
        and velocity_hat (only set when amsgrad is applied).

        Args:
            var_list: list of model variables to build Adam variables on.
        """
        if self.built:
            return
        super().build(var_list)
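        # The slots created below are named "momentum" (first moment m) and
        # "velocity" (second moment v); a "velocity_hat" slot is added only
        # when AMSGrad is enabled. Assumption worth noting: in Keras 3,
        # `add_variable_from_reference` creates a variable with the same
        # shape/dtype as the reference and a zeros initializer by default.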
        self._momentums = []
        self._velocities = []
        for var in var_list:
            self._momentums.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="momentum"
                )
            )
            self._velocities.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="velocity"
                )
            )
        if self.amsgrad:
            self._velocity_hats = []
            for var in var_list:
                self._velocity_hats.append(
                    self.add_variable_from_reference(
                        reference_variable=var, name="velocity_hat"
                    )
                )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)
        local_step = ops.cast(self.iterations + 1, variable.dtype)
        beta_1_power = ops.power(
            ops.cast(self.beta_1, variable.dtype), local_step
        )
        beta_2_power = ops.power(
            ops.cast(self.beta_2, variable.dtype), local_step
        )

        m = self._momentums[self._get_variable_index(variable)]
        v = self._velocities[self._get_variable_index(variable)]

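        # Fast path: when an Intel XPU device is visible and the slots and the
        # trainable variable are backed by concrete `tf.Variable`s with a dense
        # `tf.Tensor` gradient (sparse `IndexedSlices` gradients fall through),
        # dispatch to the ITEX fused kernel below instead of the generic
        # Keras ops sequence that follows it.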
        if (
            len(tf.config.list_physical_devices("XPU")) > 0
            and isinstance(m, KerasVariable)
            and isinstance(v, KerasVariable)
            and isinstance(variable, tf.Variable)
        ):
            if (
                isinstance(m.value, tf.Variable)
                and isinstance(v.value, tf.Variable)
                and isinstance(gradient, tf.Tensor)
            ):
                if self.amsgrad:
                    v_hat = self._velocity_hats[
                        self._get_variable_index(variable)
                    ]
                else:
                    v_hat = v  # just a placeholder
                return load_ops_library.itex_resource_apply_adam_with_weight_decay(
                    variable.handle,
                    m.value.handle,
                    v.value.handle,
                    beta_1_power,
                    beta_2_power,
                    lr,
                    ops.cast(self.beta_1, variable.dtype),
                    ops.cast(self.beta_2, variable.dtype),
                    ops.cast(self.epsilon, variable.dtype),
                    ops.cast(0.0, variable.dtype),
                    v_hat.value.handle,
                    gradient,
                    use_locking=False,
                    use_amsgrad=self.amsgrad,
                )

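        # Generic fallback (any backend/device): the bias corrections
        # 1 / (1 - beta_1^t) and sqrt(1 - beta_2^t) are folded into the step
        # size `alpha`, so the moment updates below are the uncorrected
        #   m <- m + (g - m) * (1 - beta_1)
        #   v <- v + (g^2 - v) * (1 - beta_2)
        # followed by variable <- variable - alpha * m / (sqrt(v) + epsilon),
        # with v replaced by the running maximum v_hat when AMSGrad is on.
        # This mirrors the upstream Keras 3 Adam implementation.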
        alpha = lr * ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)

        self.assign_add(
            m, ops.multiply(ops.subtract(gradient, m), 1 - self.beta_1)
        )
        self.assign_add(
            v,
            ops.multiply(
                ops.subtract(ops.square(gradient), v), 1 - self.beta_2
            ),
        )
        if self.amsgrad:
            v_hat = self._velocity_hats[self._get_variable_index(variable)]
            self.assign(v_hat, ops.maximum(v_hat, v))
            v = v_hat
        self.assign_sub(
            variable,
            ops.divide(
                ops.multiply(m, alpha), ops.add(ops.sqrt(v), self.epsilon)
            ),
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
                "epsilon": self.epsilon,
                "amsgrad": self.amsgrad,
            }
        )
        return config
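

# ------------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API): assumes Keras 3 on the
# TensorFlow backend with intel_extension_for_tensorflow installed. It trains
# a tiny model for one step so `update_step` runs (the ITEX fused kernel on an
# XPU device, otherwise the generic fallback), then round-trips the optimizer
# through `get_config` / `from_config` to check serialization.
if __name__ == "__main__":
    import numpy as np
    import keras

    inputs = keras.Input(shape=(4,))
    hidden = keras.layers.Dense(8, activation="relu")(inputs)
    outputs = keras.layers.Dense(1)(hidden)
    model = keras.Model(inputs, outputs)

    optimizer_instance = Adam(learning_rate=1e-3, amsgrad=True)
    model.compile(optimizer=optimizer_instance, loss="mse")

    x = np.random.rand(32, 4).astype("float32")
    y = np.random.rand(32, 1).astype("float32")
    model.fit(x, y, batch_size=32, epochs=1, verbose=0)

    # Serialization round trip: `get_config` adds beta_1/beta_2/epsilon/amsgrad
    # on top of the base optimizer arguments.
    restored = Adam.from_config(optimizer_instance.get_config())
    print("restored amsgrad:", restored.amsgrad)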