# Copyright (c) 2024 Graphcore Ltd. All rights reserved.

import numpy as np
import jax
import jax.numpy as jnp
import ml_dtypes
import gfloat
from gfloat.formats import format_info_ocp_e5m2
from timeit import Timer

# Enable 64-bit types in JAX. By default JAX uses 32-bit types, so float64
# NumPy inputs converted with jnp.array would otherwise become float32.
jax.config.update("jax_enable_x64", True)

Timing tests

The gfloat library is designed for readability over performance, and the reference code for computations is the (slow) scalar code, e.g. `round_float`.

There are also vectorized implementations (e.g. `round_ndarray`), and when combined with JAX these can run reasonably fast.

Let’s see how long it takes to round some values to FP8 (OCP E5M2)…

# NBVAL_IGNORE_OUTPUT

# Number of random inputs used by the vectorized benchmarks.
N = 1_000_000
# Uniform random float64 values in [0, 1) to be rounded.
a = np.random.rand(N)

# JIT-compiled vectorized rounding to OCP FP8 E5M2.
jax_round_jit = jax.jit(lambda x: gfloat.round_ndarray(format_info_ocp_e5m2, x))
# JAX copy of `a`; it stays float64 because jax_enable_x64 is set above.
ja = jnp.array(a)
# Cache compilation: the first call traces and compiles for this input shape/dtype,
# so later timed calls measure only execution.
jax_round_jit(ja)  # Cache compilation


def slow_round_ndarray(fi, a):
    """Round each element of ``a`` to format ``fi`` with the scalar reference code.

    ``gfloat.round_float`` is called once per element, which makes this the
    slow baseline against which the vectorized implementations are compared.

    Returns:
        A NumPy array containing the rounded values, in input order.
    """
    rounded = []
    for value in a:
        rounded.append(gfloat.round_float(fi, value))
    return np.array(rounded)


# Approximate target duration, in seconds, of each timed repeat in `time`.
# (Calibration via Timer.autorange additionally takes at least 0.2 sec.)
ACCURACY = 1.0


def time(f, problem_size=1.0, accuracy=None):
    """Benchmark ``f`` and report the best per-element time in nanoseconds.

    ``f`` is called once up front, then ``timeit.Timer.autorange`` calibrates
    how many calls fit in roughly ``accuracy`` seconds. Three repeats of that
    many calls are timed, and the fastest is reported, normalized per call and
    per element.

    Args:
        f: Zero-argument callable to benchmark.
        problem_size: Number of elements processed by one call of ``f``; the
            reported time is divided by this to give a per-element cost.
        accuracy: Target seconds per timed repeat. Defaults to the
            module-level ``ACCURACY``.

    Returns:
        A string such as ``"    3.79 nsec (500 runs at size 1000000)"``.
    """
    if accuracy is None:
        accuracy = ACCURACY
    nsec_per_sec = 1e9
    timer = Timer(f)
    f()  # pre-run, so one-off first-call costs do not skew calibration
    calib_number, calib_seconds = timer.autorange()
    # Scale by the *measured* calibration time: autorange stops at >= 0.2 s
    # but can overshoot considerably for slow f. Always run at least once so
    # the per-run division below cannot divide by zero.
    n = max(1, int(calib_number * accuracy / calib_seconds))
    totals = timer.repeat(repeat=3, number=n)  # best of 3
    per_element = [total / n / problem_size * nsec_per_sec for total in totals]
    return f"{min(per_element):8.2f} nsec ({n} runs at size {problem_size})"


# fmt: off
# The scalar reference code is very slow, so it is timed on 1% of the data;
# `time` normalizes per element, so all rows are directly comparable.
print("GFloat scalar                  :", time(lambda: slow_round_ndarray(format_info_ocp_e5m2, a[: N // 100]), N // 100))
print("GFloat vectorized, numpy arrays:", time(lambda: gfloat.round_ndarray(format_info_ocp_e5m2, a), N))
# JAX dispatches work asynchronously: block until the result is ready so the
# timing covers the computation itself, not just its dispatch.
print("GFloat vectorized, JAX JIT     :", time(lambda: jax_round_jit(ja).block_until_ready(), N))
print("ML_dtypes                      :", time(lambda: a.astype(ml_dtypes.float8_e5m2), N))
WARNING:2025-08-20 15:40:01,949:jax._src.xla_bridge:872: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
GFloat scalar                  :  2605.38 nsec (50 runs at size 10000)
GFloat vectorized, numpy arrays:    50.20 nsec (25 runs at size 1000000)
GFloat vectorized, JAX JIT     :     3.79 nsec (500 runs at size 1000000)
ML_dtypes                      :     2.60 nsec (500 runs at size 1000000)

On one CPU platform the timings were:

GFloat scalar                  :  6996.75 nsec (50 runs at size 10000)
GFloat vectorized, numpy arrays:    75.04 nsec (50 runs at size 1000000)
GFloat vectorized, JAX JIT     :     3.18 nsec (1000 runs at size 1000000)
ML_dtypes                      :     3.13 nsec (1000 runs at size 1000000)

So the JAX JIT code is roughly 1000x faster than the scalar code (about 690x and 2200x in the two runs above), and comparable to ml_dtypes’ C++ CPU implementation.