# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
from .block import BlockFormatInfo
from .types import FormatInfo, Domain, Signedness
import math
def _ieee_style_format(name: str, k: int, precision: int) -> FormatInfo:
    """Build the FormatInfo for an IEEE-754-style binary floating-point format.

    ``precision`` counts the implicit leading significand bit, so the
    exponent field is ``k - precision`` bits wide and the bias is
    ``2**(exponent_bits - 1) - 1``.  ``num_high_nans`` is
    ``2**(precision - 1) - 1``, the number of non-zero trailing-significand
    patterns (the IEEE-754 NaN encodings per sign).  Every other field is
    fixed: ``has_nz=True``, ``Domain.Extended``, subnormals present, signed,
    not two's complement.

    :param name: short name stored in the FormatInfo
    :param k: total bit width of the encoding
    :param precision: significand precision, including the implicit bit
    """
    exponent_bits = k - precision
    return FormatInfo(
        name=name,
        k=k,
        precision=precision,
        bias=2 ** (exponent_bits - 1) - 1,
        has_nz=True,
        domain=Domain.Extended,
        num_high_nans=2 ** (precision - 1) - 1,
        has_subnormals=True,
        is_signed=True,
        is_twos_complement=False,
    )


#: FormatInfo for IEEE-754 Binary64 format
format_info_binary64 = _ieee_style_format("binary64", k=64, precision=53)

#: FormatInfo for IEEE-754 Binary32 format
format_info_binary32 = _ieee_style_format("binary32", k=32, precision=24)

#: FormatInfo for IEEE-754 Binary16 format
format_info_binary16 = _ieee_style_format("binary16", k=16, precision=11)

#: FormatInfo for Google BFloat16 format
format_info_bfloat16 = _ieee_style_format("bfloat16", k=16, precision=8)
def _ocp_minifloat_format(
    name: str,
    k: int,
    precision: int,
    domain: Domain = Domain.Finite,
    num_high_nans: int = 0,
) -> FormatInfo:
    """Build the FormatInfo for an OCP FP8/FP6/FP4 element format.

    All of these formats share ``has_nz=True``, subnormals, a sign bit and
    non-two's-complement encoding.  The bias is ``2**(k - precision - 1) - 1``,
    where ``precision`` includes the implicit bit.  Only the width,
    precision, domain and NaN count vary between them.

    :param name: short name stored in the FormatInfo
    :param k: total bit width of the encoding
    :param precision: significand precision, including the implicit bit
    :param domain: value domain; defaults to ``Domain.Finite``
    :param num_high_nans: NaN count passed through to FormatInfo; defaults to 0
    """
    return FormatInfo(
        name=name,
        k=k,
        precision=precision,
        bias=2 ** (k - precision - 1) - 1,
        has_nz=True,
        domain=domain,
        num_high_nans=num_high_nans,
        has_subnormals=True,
        is_signed=True,
        is_twos_complement=False,
    )


#: FormatInfo for OCP E5M2 format
format_info_ocp_e5m2 = _ocp_minifloat_format(
    "ocp_e5m2", k=8, precision=3, domain=Domain.Extended, num_high_nans=2**2 - 1
)

#: FormatInfo for OCP E4M3 format
format_info_ocp_e4m3 = _ocp_minifloat_format(
    "ocp_e4m3", k=8, precision=4, num_high_nans=1
)

#: FormatInfo for OCP MX E2M3 format
format_info_ocp_e2m3 = _ocp_minifloat_format("ocp_e2m3", k=6, precision=4)

#: FormatInfo for OCP MX E3M2 format
format_info_ocp_e3m2 = _ocp_minifloat_format("ocp_e3m2", k=6, precision=3)

#: FormatInfo for OCP MX E2M1 format
format_info_ocp_e2m1 = _ocp_minifloat_format("ocp_e2m1", k=4, precision=2)
#: FormatInfo for OCP MX E8M0 format
format_info_ocp_e8m0 = FormatInfo(
    name="ocp_e8m0",
    k=8,
    # Encoding: unsigned, with no negative zero.
    is_signed=False,
    is_twos_complement=False,
    has_nz=False,
    # precision=1 means only the implicit bit, i.e. no stored significand
    # bits; bias is 2**7 - 1 == 127 and there are no subnormals.
    precision=1,
    bias=2**7 - 1,
    has_subnormals=False,
    # Special values: finite domain with a single high NaN.
    domain=Domain.Finite,
    num_high_nans=1,
)

#: FormatInfo for OCP MX INT8 format
format_info_ocp_int8 = FormatInfo(
    name="ocp_int8",
    k=8,
    # Encoding: signed two's-complement integer, so no negative zero.
    is_signed=True,
    is_twos_complement=True,
    has_nz=False,
    # precision == k and bias 0: all eight bits carry the value.
    precision=8,
    bias=0,
    has_subnormals=True,
    # Special values: finite domain with no NaNs.
    domain=Domain.Finite,
    num_high_nans=0,
)
# Collections of formats

# NOTE(review): format_info_p3109 is neither imported nor defined in the
# visible part of this module (only .block and .types are imported above);
# confirm it is in scope before this point, or importing the module will
# raise NameError.

#: Sub-byte formats (3, 4 and 6 bits), OCP MX and P3109 interleaved by width
_tiny_formats = [
    # 3-bit
    format_info_p3109(3, 2, Signedness.Signed, Domain.Finite),
    # 4-bit
    format_info_ocp_e2m1,
    format_info_p3109(4, 2, Signedness.Signed, Domain.Finite),
    # 6-bit
    format_info_ocp_e2m3,
    format_info_ocp_e3m2,
    format_info_p3109(6, 3, Signedness.Signed, Domain.Finite),
    format_info_p3109(6, 4, Signedness.Signed, Domain.Finite),
]
#: Selected P3109 8-bit formats, built from (precision, signedness, domain)
#: triples in a fixed order.
p3109_binary8_formats = [
    format_info_p3109(8, precision, signedness, domain)
    for precision, signedness, domain in [
        (1, Signedness.Signed, Domain.Extended),
        (1, Signedness.Unsigned, Domain.Extended),
        # Every signedness/domain combination for precisions 3 and 4.
        *[
            (p, s, d)
            for p in (3, 4)
            for s in (Signedness.Signed, Signedness.Unsigned)
            for d in (Domain.Extended, Domain.Finite)
        ],
        (7, Signedness.Signed, Domain.Finite),
        (8, Signedness.Unsigned, Domain.Finite),
    ]
]
#: 8-bit float formats: the two OCP FP8 formats followed by the P3109 set
_fp8_formats = [format_info_ocp_e4m3, format_info_ocp_e5m2] + p3109_binary8_formats

#: 16-bit float formats
_fp16_formats = [format_info_binary16, format_info_bfloat16]

#: Sample of formats from the sub-byte ones up to binary64, followed by the
#: OCP E8M0 scale and INT8 formats
sample_formats = (
    _tiny_formats
    + _fp8_formats
    + _fp16_formats
    + [
        format_info_binary32,
        format_info_binary64,
        format_info_ocp_e8m0,
        format_info_ocp_int8,
    ]
)
# ------
# Block formats
#
# Each OCP MX block format pairs an element format with a shared
# format_info_ocp_e8m0 scale. The positional 32 is the MX block size
# (k = 32 in the OCP Microscaling Formats v1.0 specification).

#: BlockFormatInfo for OCP MXFP8 with E5M2 elements
format_info_mxfp8_e5m2 = BlockFormatInfo(
    "mxfp8_e5m2", format_info_ocp_e5m2, 32, format_info_ocp_e8m0
)

#: BlockFormatInfo for OCP MXFP8 with E4M3 elements
format_info_mxfp8_e4m3 = BlockFormatInfo(
    "mxfp8_e4m3", format_info_ocp_e4m3, 32, format_info_ocp_e8m0
)

#: BlockFormatInfo for OCP MXFP6 with E3M2 elements
format_info_mxfp6_e3m2 = BlockFormatInfo(
    "mxfp6_e3m2", format_info_ocp_e3m2, 32, format_info_ocp_e8m0
)

#: BlockFormatInfo for OCP MXFP6 with E2M3 elements
format_info_mxfp6_e2m3 = BlockFormatInfo(
    "mxfp6_e2m3", format_info_ocp_e2m3, 32, format_info_ocp_e8m0
)

# Defined exactly once; it was previously duplicated, and the second
# assignment silently rebound the name to an identical object.
#: BlockFormatInfo for OCP MXFP4 with E2M1 elements
format_info_mxfp4_e2m1 = BlockFormatInfo(
    "mxfp4_e2m1", format_info_ocp_e2m1, 32, format_info_ocp_e8m0
)

#: BlockFormatInfo for OCP MXINT8
format_info_mxint8 = BlockFormatInfo(
    "mxint8", format_info_ocp_int8, 32, format_info_ocp_e8m0
)

#: All block formats defined in this module
all_block_formats = [
    format_info_mxfp8_e5m2,
    format_info_mxfp8_e4m3,
    format_info_mxfp6_e3m2,
    format_info_mxfp6_e2m3,
    format_info_mxfp4_e2m1,
    format_info_mxint8,
]