# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
import numpy as np
import jax
import jax.numpy as jnp
import ml_dtypes
import gfloat
from gfloat.formats import format_info_ocp_e5m2
from timeit import Timer
# Turn on JAX's "jax_enable_x64" flag before any arrays are created.
# NOTE(review): presumably so that jnp.array(a) keeps the float64 data produced by NumPy
# instead of being narrowed to 32-bit — confirm against the JAX configuration docs.
jax.config.update("jax_enable_x64", True)
# Timing tests
The gfloat library is designed for readability over performance, and the reference code for computations is the (slow) scalar code, e.g. `round_float`.
There are also vectorized implementations (e.g. `round_ndarray`), and when combined with JAX, these can go reasonably fast.
Let’s see how long it takes to encode some values to FP8…
# NBVAL_IGNORE_OUTPUT
# Benchmark input: N random samples from np.random.rand, shared by every timed implementation.
N = 1_000_000
a = np.random.rand(N)
# JIT-compiled wrapper that rounds an array to the OCP E5M2 format via gfloat.round_ndarray.
jax_round_jit = jax.jit(lambda x: gfloat.round_ndarray(format_info_ocp_e5m2, x))
# Convert the input to a JAX array once, up front; the timed lambda receives `ja` directly.
ja = jnp.array(a)
# Call once before timing so that JIT compilation is cached and not counted in the benchmark.
jax_round_jit(ja) # Cache compilation
def slow_round_ndarray(fi, a):
    """Round each element of ``a`` with the scalar ``gfloat.round_float``.

    This is the element-by-element baseline for the timing comparison.

    Args:
        fi: Format description passed through to ``gfloat.round_float``.
        a: Iterable of values to round.

    Returns:
        A NumPy array built from the individually rounded values.
    """
    rounded = []
    for value in a:
        rounded.append(gfloat.round_float(fi, value))
    return np.array(rounded)
# About how many seconds each timed repeat should run for.
# Timer.autorange() picks a loop count that takes at least 0.2 sec; time() scales
# that loop count by ACCURACY / 0.2 to get the number of runs per repeat.
ACCURACY = 1.0
def time(f, problem_size=1.0):
    """Benchmark ``f`` and return its best per-element time as a formatted string.

    Args:
        f: Zero-argument callable to benchmark.
        problem_size: Number of elements handled by one call of ``f``. The
            per-call time is divided by this to give a per-element figure.

    Returns:
        A string ``"<t> nsec (<n> runs at size <problem_size>)"`` where ``<t>``
        is the minimum over three repeats, in nanoseconds per element, and
        ``<n>`` is the number of calls made in each repeat.
    """
    units = 1e9  # seconds -> nanoseconds
    timer = Timer(f)
    f()  # pre-run

    # autorange() returns (loop_count, elapsed) with elapsed >= 0.2 sec; scale
    # the loop count so that one repeat runs for roughly ACCURACY seconds.
    loops, _elapsed = timer.autorange()
    # Clamp to at least one run: for a small ACCURACY the scaled count truncates
    # to 0, which would time nothing and then divide by zero below.
    n = max(1, int(loops * ACCURACY / 0.2))

    totals = timer.repeat(repeat=3, number=n)  # best of 3
    per_element = [((total / n) / problem_size) * units for total in totals]  # per run
    return f"{min(per_element):8.2f} nsec ({n} runs at size {problem_size})"
# fmt: off
# Each row prints the best per-element time for one implementation of rounding to E5M2.
# The scalar reference is timed on only the first N // 100 elements because it is far slower;
# problem_size is passed as N // 100 to match, so the figure is still per element.
print("GFloat scalar :", time(lambda: slow_round_ndarray(format_info_ocp_e5m2, a[: N // 100]), N // 100))
# Vectorized gfloat implementation applied to the full NumPy array.
print("GFloat vectorized, numpy arrays:", time(lambda: gfloat.round_ndarray(format_info_ocp_e5m2, a), N))
# JAX dispatches work asynchronously, so wait for the result with block_until_ready();
# otherwise the timer may measure only the dispatch, not the computation.
print("GFloat vectorized, JAX JIT :", time(lambda: jax_round_jit(ja).block_until_ready(), N))
# Baseline: NumPy astype to the ml_dtypes float8_e5m2 dtype.
print("ML_dtypes :", time(lambda: a.astype(ml_dtypes.float8_e5m2), N))
WARNING:2025-08-20 15:40:01,949:jax._src.xla_bridge:872: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
GFloat scalar : 2605.38 nsec (50 runs at size 10000)
GFloat vectorized, numpy arrays: 50.20 nsec (25 runs at size 1000000)
GFloat vectorized, JAX JIT : 3.79 nsec (500 runs at size 1000000)
ML_dtypes : 2.60 nsec (500 runs at size 1000000)
On one CPU platform the timings were:
GFloat scalar : 6996.75 nsec (50 runs at size 10000)
GFloat vectorized, numpy arrays: 75.04 nsec (50 runs at size 1000000)
GFloat vectorized, JAX JIT : 3.18 nsec (1000 runs at size 1000000)
ML_dtypes : 3.13 nsec (1000 runs at size 1000000)
So the JAX JIT code is roughly 1000x faster than the scalar code (about 700x and 2200x in the two sets of timings above), and comparable to the C++ CPU implementation in ml_dtypes.