{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Copyright (c) 2024 Graphcore Ltd. All rights reserved.\n", "\n", "import numpy as np\n", "import jax\n", "import jax.numpy as jnp\n", "import ml_dtypes\n", "import gfloat\n", "from gfloat.formats import format_info_ocp_e5m2\n", "from timeit import Timer\n", "\n", "jax.config.update(\"jax_enable_x64\", True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Timing tests\n", "\n", "The `gfloat` library is designed for readability over performance, and the reference code for computations is the (slow) scalar code e.g. `round_float`.\n", "\n", "There are vectorized implementations (e.g. `round_ndarray`), and when combined with JAX, these can go reasonably fast.\n", "\n", "Let's see how long it takes to encode some values to FP8..." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:2025-08-20 15:40:01,949:jax._src.xla_bridge:872: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. 
Falling back to cpu.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "GFloat scalar : 2605.38 nsec (50 runs at size 10000)\n", "GFloat vectorized, numpy arrays: 50.20 nsec (25 runs at size 1000000)\n", "GFloat vectorized, JAX JIT : 3.79 nsec (500 runs at size 1000000)\n", "ML_dtypes : 2.60 nsec (500 runs at size 1000000)\n" ] } ], "source": [ "# NBVAL_IGNORE_OUTPUT\n", "\n", "N = 1_000_000\n", "# Seeded generator so that every run benchmarks the same input values.\n", "a = np.random.default_rng(0).random(N)\n", "\n", "jax_round_jit = jax.jit(lambda x: gfloat.round_ndarray(format_info_ocp_e5m2, x))\n", "ja = jnp.array(a)\n", "# Cache compilation. JAX dispatches work asynchronously, so wait for the result;\n", "# otherwise the warm-up could still be running when the timed runs start.\n", "jax.block_until_ready(jax_round_jit(ja))\n", "\n", "\n", "def slow_round_ndarray(fi, a):\n", "    \"\"\"Round every element of `a` one at a time with the scalar `gfloat.round_float`.\"\"\"\n", "    return np.array([gfloat.round_float(fi, x) for x in a])\n", "\n", "\n", "# About how many seconds to run for (autorange will take at least .2 sec)\n", "ACCURACY = 1.0\n", "\n", "\n", "def time(f, problem_size=1.0):\n", "    \"\"\"Benchmark `f` and report the best time per element, in nanoseconds.\n", "\n", "    f: zero-argument callable to time.\n", "    problem_size: number of elements processed by one call of `f`.\n", "    Returns a formatted string: best time per element, loop count, problem size.\n", "    \"\"\"\n", "    units = 1e9  # nsec\n", "    timer = Timer(f)\n", "    f()  # pre-run\n", "    # autorange() picks a loop count that runs for at least 0.2 sec; rescale it to\n", "    # run for about ACCURACY sec, but never fewer than one loop (a loop count of\n", "    # zero would time nothing and then divide by zero below).\n", "    n = max(1, int(timer.autorange()[0] * ACCURACY / 0.2))\n", "    totals = timer.repeat(repeat=3, number=n)  # best of 3\n", "    per_element = [((total / n) / problem_size) * units for total in totals]  # per run\n", "    return f\"{min(per_element):8.2f} nsec ({n} runs at size {problem_size})\"\n", "\n", "\n", "# fmt: off\n", "print(\"GFloat scalar :\", time(lambda: slow_round_ndarray(format_info_ocp_e5m2, a[: N // 100]), N // 100))\n", "print(\"GFloat vectorized, numpy arrays:\", time(lambda: gfloat.round_ndarray(format_info_ocp_e5m2, a), N))\n", "# Block on the result so the timing covers the computation, not just the async dispatch.\n", "print(\"GFloat vectorized, JAX JIT :\", time(lambda: jax.block_until_ready(jax_round_jit(ja)), N))\n", "print(\"ML_dtypes :\", time(lambda: a.astype(ml_dtypes.float8_e5m2), N))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On one CPU platform the timings were:\n", "```\n", "GFloat scalar : 6996.75 nsec (50 runs at size 10000)\n", "GFloat vectorized, numpy arrays: 75.04 nsec (50 runs at size 1000000)\n", "GFloat vectorized, JAX JIT : 3.18 nsec (1000 runs at size 1000000)\n", "ML_dtypes : 3.13 nsec (1000 runs at size 1000000)\n", "```\n", "So the JAX JIT code is ~1000x faster 
than the scalar code, and comparable to `ml_dtypes`'s C++ CPU implementation." ] } ], "metadata": { "kernelspec": { "display_name": "gfloat", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 2 }