From 76e63212fff7c78c191c106ffcfaa4f6dcc1abe2 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 8 Apr 2024 12:29:19 -0700
Subject: [PATCH] Enable bfloat scan (#974)

* enable bfloat scan

* fix tests
---
 mlx/backend/metal/kernels/scan.metal | 8 ++++----
 python/tests/test_ops.py | 9 +++++++++
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/mlx/backend/metal/kernels/scan.metal b/mlx/backend/metal/kernels/scan.metal
index e692449c7..13b7e7adc 100644
--- a/mlx/backend/metal/kernels/scan.metal
+++ b/mlx/backend/metal/kernels/scan.metal
@@ -451,7 +451,7 @@ instantiate_scan_helper(sum_int32_int32, int32_t, int32_t, CumSu
 //instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
 instantiate_scan_helper(sum_float16_float16, half, half, CumSum, 4)
 instantiate_scan_helper(sum_float32_float32, float, float, CumSum, 4)
-//instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
+instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
 //instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum)
 //instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
 instantiate_scan_helper(prod_uint8_uint8, uint8_t, uint8_t, CumProd, 4)
@@ -464,7 +464,7 @@ instantiate_scan_helper(prod_int32_int32, int32_t, int32_t, CumP
 //instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
 instantiate_scan_helper(prod_float16_float16, half, half, CumProd, 4)
 instantiate_scan_helper(prod_float32_float32, float, float, CumProd, 4)
-//instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
+instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
 //instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd)
 //instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
 instantiate_scan_helper(max_uint8_uint8, uint8_t, uint8_t, CumMax, 4)
@@ -477,7 +477,7 @@ instantiate_scan_helper(max_int32_int32, int32_t, int32_t, CumMa
 //instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
 instantiate_scan_helper(max_float16_float16, half, half, CumMax, 4)
 instantiate_scan_helper(max_float32_float32, float, float, CumMax, 4)
-//instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
+instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
 //instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax)
 //instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
 instantiate_scan_helper(min_uint8_uint8, uint8_t, uint8_t, CumMin, 4)
@@ -490,5 +490,5 @@ instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMi
 //instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
 instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
 instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
-//instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
+instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
 //instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin)
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 7f2e98e70..e2bd749bc 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1564,6 +1564,15 @@ class TestOps(mlx_tests.MLXTestCase):
             c2 = mxop(a_mlx, axis=0, inclusive=False, reverse=True)[:-1, :, :]
             self.assertTrue(mx.array_equal(c1, c2))
 
+        a = mx.random.uniform(shape=(8, 32))
+        mat = mx.tri(32)
+        for t in [mx.float16, mx.bfloat16]:
+            a_t = a.astype(t)
+            mat_t = mat.astype(t)
+            out = mx.cumsum(a_t, axis=-1)
+            expected = (mat_t * a_t[:, None, :]).sum(axis=-1)
+            self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3))
+
     def test_squeeze_expand(self):
         a = mx.zeros((2, 1, 2, 1))
         self.assertEqual(mx.squeeze(a).shape, (2, 2))