Skip to content

LAPACK Inconsistent across multiple different operating systems and devices #1137

Open
@YousefElbrolosy

Description

@YousefElbrolosy

Description
I have a deterministic program that uses JAX and is heavy on linear algebra operations.

I ran this code on CPU on three different machines: two macOS systems (one on Sequoia with an M1 Pro, the other on Sonoma with an M2) and one Linux system.

All three systems produce different results for the same input; however, each system reproduces its own result deterministically.

Minimal Reproducible example

import jax
import optax
import flax.linen as nn
import jax.numpy as jnp

# Switch JAX to 64-bit dtypes so the reproducer runs in double precision.
jax.config.update("jax_enable_x64", True)

# Fixed two-element input vector passed to the RNN in `loss` and in the script
# entry point below.
variables = jnp.array([0.1, -3 * jnp.pi / 2])


class RNN(nn.Module):
    """One recurrent step: a GRU cell followed by a dense read-out layer."""

    # Feature count passed to the GRU cell.
    hidden_size: int
    # Feature count passed to the dense read-out layer.
    output_size: int

    @nn.compact
    def __call__(self, input, hidden_state):
        """Advance the recurrent state by one step.

        Args:
            input: Input for this step.
            hidden_state: Recurrent state carried in from the previous step.

        Returns:
            A pair ``(output, new_hidden_state)`` where ``output`` is the
            dense layer applied to the updated hidden state.
        """
        cell = nn.GRUCell(features=self.hidden_size)
        # The cell returns a pair; only its first element (the updated state)
        # is used, the second is discarded.
        next_hidden, _ = cell(hidden_state, input)
        readout = nn.Dense(features=self.output_size)
        return readout(next_hidden), next_hidden

def _optimize(
    loss_fn,
    init_params,
    max_iter,
    learning_rate,
):
    """Run a fixed number of Adam steps on ``loss_fn``.

    Args:
        loss_fn: Scalar-valued function of the parameters; differentiated
            with ``jax.grad``.
        init_params: Starting parameters (also used to initialise the
            optimizer state).
        max_iter: Number of update steps to perform. A value ``<= 0``
            performs no steps.
        learning_rate: Learning rate handed to ``optax.adam``.

    Returns:
        A pair ``(params, num_steps)``: the parameters after the last update
        and the number of steps actually executed. With ``max_iter <= 0``
        this is ``(init_params, 0)``.
    """
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(init_params)

    @jax.jit
    def step(params, state):
        grads = jax.grad(loss_fn)(params)
        updates, new_state = optimizer.update(grads, state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_state

    params = init_params
    # Count steps explicitly instead of reading the loop variable after the
    # loop: with an empty range the loop variable is never bound and the
    # original `iter_idx + 1` raised UnboundLocalError.
    num_steps = 0
    for _ in range(max_iter):
        params, opt_state = step(params, opt_state)
        num_steps += 1
    return params, num_steps

def fun(gamma, delta):
    """Return the average of ``expm(1j * A)`` and ``expm(-1j * A)``.

    ``A`` is ``gamma`` times the 2x2 Pauli-Y matrix plus the scalar
    ``delta / 2``. The scalar is added by broadcasting, i.e. to every entry
    of the matrix, not only to the diagonal.

    Args:
        gamma: Scalar multiplying the Pauli-Y matrix.
        delta: Scalar whose half is added to each matrix entry.

    Returns:
        The 2x2 complex matrix ``(expm(1j * A) + expm(-1j * A)) / 2``.
    """
    pauli_y = jnp.array([[0, -1j], [1j, 0]])
    generator = (gamma * pauli_y) + delta / 2
    forward = jax.scipy.linalg.expm(1j * generator)
    backward = jax.scipy.linalg.expm(-1j * generator)
    return (forward + backward) / 2

def loss(params):
    """Scalar objective: real part of the trace of ``fun`` at the RNN output.

    The RNN (hidden size 10, output size 2) is applied once to the
    module-level ``variables`` with an all-zero hidden state; its two outputs
    are passed to ``fun`` as ``gamma`` and ``delta``.

    Args:
        params: Parameter tree for ``RNN``, passed to ``apply`` under the
            ``'params'`` key.

    Returns:
        ``jnp.real(jnp.trace(fun(out[0], out[1])))``.
    """
    model = RNN(hidden_size=10, output_size=2)
    initial_hidden = jnp.zeros((10,))
    prediction, _ = model.apply({'params': params}, variables, initial_hidden)
    gamma, delta = prediction[0], prediction[1]
    return jnp.real(jnp.trace(fun(gamma, delta)))

if __name__ == "__main__":
    # Initialise the RNN parameters from a fixed PRNG key.
    rng = jax.random.PRNGKey(0)
    rnn = RNN(hidden_size=10, output_size=2)
    hidden_state = jnp.zeros((10,))
    # Pass `variables` directly rather than binding it to a module-level name
    # `input`, which would shadow the builtin.
    params = rnn.init(rng, variables, hidden_state)['params']

    max_iter = 100
    learning_rate = 0.01
    # The step count returned by `_optimize` is not needed here; the unused
    # `convergence_threshold` constant from the original was never read by
    # anything and has been dropped.
    optimized_params, _ = _optimize(
        loss,
        params,
        max_iter,
        learning_rate,
    )
    final_loss = loss(optimized_params)
    print("Final Loss:", final_loss)

On a macOS system this outputs:

-1.9979573829398634

and on a linux system:

-1.9979573808129485

The two results differ in the last 8 digits, and I suppose that in a much larger, more complicated system this difference could become quite large.

In machine learning applications, such differences in numerical output can lead to convergence to different minima when the optimization landscape is not straightforward.

Checklist

  • I've included a minimal example to reproduce the issue
  • I'd be willing to make a PR to solve this issue

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions