Description
I have a deterministic program that uses jax and is heavy on linear algebra operations.
I ran this code on CPU on three different machines: two macOS systems (one running Sequoia on an M1 Pro, the other Sonoma on an M2) and one Linux system.
All three systems produce different results for the same input, although each individual system produces its result deterministically across repeated runs.
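For completeness, here is a small sketch (standard JAX calls only, not part of the reproducer) that can be run on each machine to confirm that all runs use the same backend, dtype, and JAX version:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

print("jax version:   ", jax.__version__)
print("backend:       ", jax.default_backend())  # "cpu" on all three machines
print("devices:       ", jax.devices())
print("default dtype: ", jnp.array(1.0).dtype)   # float64 once x64 mode is enabled
```

In my case, all three runs were on the CPU backend with x64 enabled.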
Minimal reproducible example
```python
import jax
import optax
import flax.linen as nn
import jax.numpy as jnp
import jax.scipy.linalg

jax.config.update("jax_enable_x64", True)  # run everything in float64

variables = jnp.array([0.1, -3 * jnp.pi / 2])


# Small GRU + Dense model; its 2-dimensional output parameterizes the loss.
class RNN(nn.Module):
    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, input, hidden_state):
        gru_cell = nn.GRUCell(features=self.hidden_size)
        new_hidden_state, _ = gru_cell(hidden_state, input)
        output = nn.Dense(features=self.output_size)(new_hidden_state)
        return output, new_hidden_state


def _optimize(
    loss_fn,
    init_params,
    max_iter,
    learning_rate,
):
    # Plain Adam loop: `step` is jitted, the Python loop just iterates it.
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(init_params)

    @jax.jit
    def step(params, state):
        grads = jax.grad(loss_fn)(params)
        updates, new_state = optimizer.update(grads, state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_state

    params = init_params
    for iter_idx in range(max_iter):
        params, opt_state = step(params, opt_state)
    return params, iter_idx + 1


def fun(gamma, delta):
    # 2x2 complex matrix built from the two network outputs, combined via matrix exponentials.
    op = jnp.array([[0, -1j], [1j, 0]])
    angle = (gamma * op) + delta / 2
    return (jax.scipy.linalg.expm(1j * angle) + jax.scipy.linalg.expm(-1j * angle)) / 2


def loss(params):
    # Real part of the trace of fun(...) evaluated at the network output.
    rnn = RNN(hidden_size=10, output_size=2)
    input = variables
    hidden_state = jnp.zeros((10,))
    output, _ = rnn.apply({'params': params}, input, hidden_state)
    params_out = output
    return jnp.real(jnp.trace(fun(params_out[0], params_out[1])))


if __name__ == "__main__":
    rng = jax.random.PRNGKey(0)
    rnn = RNN(hidden_size=10, output_size=2)
    input = variables
    hidden_state = jnp.zeros((10,))
    params = rnn.init(rng, input, hidden_state)['params']

    max_iter = 100
    learning_rate = 0.01
    convergence_threshold = 1e-6

    optimized_params, num_iterations = _optimize(
        loss,
        params,
        max_iter,
        learning_rate,
    )
    final_loss = loss(optimized_params)
    print("Final Loss:", final_loss)
```
This outputs, on a macOS system:

```
-1.9979573829398634
```

and on a Linux system:

```
-1.9979573808129485
```
The two results differ in the last 8 digits, and I suppose that for a much larger and more complicated system this difference could become quite large.
In machine learning applications, such differences in numerical output can lead to convergence at different minima when the optimization is not so straightforward.
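For a rough sense of the size of the discrepancy, here is a quick calculation on the two values printed above (plain Python; the variable names are only for illustration):

```python
import math

macos_loss = -1.9979573829398634
linux_loss = -1.9979573808129485

abs_diff = abs(macos_loss - linux_loss)   # ~2.1e-09
rel_diff = abs_diff / abs(macos_loss)     # ~1.1e-09
ulps = abs_diff / math.ulp(macos_loss)    # how many float64 ULPs apart the two results are

print(f"absolute difference: {abs_diff:.3e}")
print(f"relative difference: {rel_diff:.3e}")
print(f"difference in ULPs:  {ulps:.0f}")
```

So after 100 Adam steps the two platforms still agree to roughly 9 significant digits, but the gap is on the order of millions of float64 ULPs, i.e. far more than a single rounding difference.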
Checklist
- I've included a minimal example to reproduce the issue
- I'd be willing to make a PR to solve this issue