# mlx/python/tests/test_bf16.py
# Copyright © 2023 Apple Inc.
import math
import unittest
from itertools import permutations
import mlx.core as mx
import mlx_tests
import numpy as np
try:
import torch
has_torch = True
except ImportError:
has_torch = False
class TestBF16(mlx_tests.MLXTestCase):
def __test_ops(
self,
ref_op, # Function that outputs array_like
mlx_op, # Function that outputs array_like
np_args, # Numpy arguments
ref_transform=lambda x: x,
mlx_transform=lambda x: mx.array(x),
atol=1e-5,
):
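        # Map both argument sets through their transforms, run each op, and
        # require the results to agree within `atol`.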
ref_args = map(ref_transform, np_args)
mlx_args = map(mlx_transform, np_args)
r_ref = ref_op(*ref_args)
r_mlx = mlx_op(*mlx_args)
self.assertTrue(np.allclose(r_mlx, r_ref, atol=atol))
def __default_test(
self,
op,
np_args,
simple_transform=lambda x: x,
atol_np=1e-3,
atol_torch=1e-5,
np_kwargs=dict(),
mlx_kwargs=dict(),
torch_kwargs=dict(),
torch_op=None,
):
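        # Compute `op` with MLX in bfloat16 and compare it against a NumPy
        # float32 reference (on bfloat16-rounded inputs) and, when PyTorch is
        # installed, against PyTorch's own bfloat16 implementation.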
with self.subTest(reference="numpy"):
def np_transform(x):
x_mx_bf16 = mx.array(x).astype(mx.bfloat16)
x_mx_fp32 = x_mx_bf16.astype(mx.float32)
return np.asarray(x_mx_fp32)
def mlx_fn(*args):
out_bf16 = getattr(mx, op)(*args, **mlx_kwargs)
return np.asarray(out_bf16.astype(mx.float32))
def np_fn(*args):
out_fp32 = getattr(np, op)(*args, **np_kwargs)
return np_transform(out_fp32)
ref_op = np_fn
mlx_op = mlx_fn
ref_transform = lambda x: simple_transform(np_transform(x))
mlx_transform = lambda x: simple_transform(mx.array(x).astype(mx.bfloat16))
self.__test_ops(
ref_op,
mlx_op,
np_args,
ref_transform=ref_transform,
mlx_transform=mlx_transform,
atol=atol_np,
)
if has_torch:
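            # Second reference: run the same op natively in PyTorch bfloat16.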
with self.subTest(reference="torch"):
torch_op = op if torch_op is None else torch_op
def torch_fn(*args):
out_bf16 = getattr(torch, torch_op)(*args, **torch_kwargs)
return out_bf16.to(torch.float32).numpy()
ref_op = torch_fn
ref_transform = lambda x: simple_transform(
torch.from_numpy(x).to(torch.bfloat16)
)
self.__test_ops(
ref_op,
mlx_op,
np_args,
ref_transform=ref_transform,
mlx_transform=mlx_transform,
atol=atol_torch,
)
def test_unary_ops(self):
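        # Elementwise unary ops on random inputs in [0, 1), keeping log and
        # sqrt well defined.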
x = np.random.rand(18, 28, 38)
for op in ["abs", "exp", "log", "square", "sqrt"]:
with self.subTest(op=op):
np_args = (x.astype(np.float32),)
self.__default_test(op, np_args)
def test_binary_ops(self):
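        # Elementwise binary ops in bfloat16, compared against both references.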
x = np.random.rand(18, 28, 38)
y = np.random.rand(18, 28, 38)
for op in ["add", "subtract", "multiply", "divide", "maximum", "minimum"]:
with self.subTest(op=op):
np_args = (
x.astype(np.float32),
y.astype(np.float32),
)
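                # Check the full arrays, then row and column slices to cover
                # non-contiguous inputs.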
self.__default_test(op, np_args, simple_transform=lambda x: x)
self.__default_test(op, np_args, simple_transform=lambda x: x[:1])
self.__default_test(op, np_args, simple_transform=lambda x: x[:, :1])
def test_reduction_ops(self):
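        # min/max reductions over every axis combination; the PyTorch reference
        # uses amin/amax, which accept a `dim` argument.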
x = np.random.rand(18, 28, 38).astype(np.float32)
for op in ("min", "max"):
with self.subTest(op=op):
for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)):
with self.subTest(axes=axes):
np_args = (x.astype(np.float32),)
self.__default_test(
op,
np_args,
np_kwargs={"axis": axes},
mlx_kwargs={"axis": axes},
torch_kwargs={"dim": axes},
torch_op="a" + op,
)
def test_arg_reduction_ops(self):
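        # argmin/argmax return indices, so compare exact values against NumPy
        # computed on the same bfloat16-rounded data.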
data = np.random.rand(10, 12, 13).astype(np.float32)
x = mx.array(data).astype(mx.bfloat16)
data = np.asarray(x.astype(mx.float32))
for op in ["argmin", "argmax"]:
for axis in range(3):
for kd in [True, False]:
a = getattr(mx, op)(x, axis, kd)
b = getattr(np, op)(data, axis, keepdims=kd)
a = a.astype(mx.float32)
self.assertEqual(a.tolist(), b.tolist())
for op in ["argmin", "argmax"]:
a = getattr(mx, op)(x, keepdims=True)
b = getattr(np, op)(data, keepdims=True)
a = a.astype(mx.float32)
self.assertEqual(a.tolist(), b.tolist())
a = getattr(mx, op)(x)
b = getattr(np, op)(data)
a = a.astype(mx.float32)
self.assertEqual(a.item(), b)
def test_blas_ops(self):
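        # bfloat16 matmul over a few shapes, compared against the NumPy and
        # PyTorch references.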
        if mx.default_device() != mx.gpu:
            self.skipTest("bfloat16 matmul tests only run on the GPU")
def test_blas(shape_x, shape_y):
np.random.seed(42)
with self.subTest(shape_x=shape_x, shape_y=shape_y):
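                # Scale inputs by 1/K (the inner dimension) so matmul outputs
                # stay small and the absolute tolerance is meaningful.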
x = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_x)
y = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_y)
np_args = (
x.astype(np.float32),
y.astype(np.float32),
)
op = "matmul"
self.__default_test(op, np_args, atol_np=1e-3, atol_torch=1e-3)
for shape_x, shape_y in [
[(32, 32), (32, 32)],
[(23, 57), (57, 1)],
[(1, 3), (3, 128)],
[(8, 128, 768), (768, 16)],
]:
test_blas(shape_x, shape_y)
@unittest.skipIf(not has_torch, "requires PyTorch")
def test_conversion(self):
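        # Converting a bfloat16 torch tensor should yield a bfloat16 mx.array
        # with the same values.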
a_torch = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16)
a_mx = mx.array(a_torch)
expected = mx.array([1.0, 2.0, 3.0], mx.bfloat16)
self.assertEqual(a_mx.dtype, mx.bfloat16)
self.assertTrue(mx.array_equal(a_mx, expected))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()