Enable bfloat scan (#974)

* enable bfloat scan
* fix tests
This commit is contained in:
Awni Hannun
2024-04-08 12:29:19 -07:00
committed by GitHub
parent aac2f9fb61
commit 76e63212ff
2 changed files with 13 additions and 4 deletions

View File

@@ -1564,6 +1564,15 @@ class TestOps(mlx_tests.MLXTestCase):
c2 = mxop(a_mlx, axis=0, inclusive=False, reverse=True)[:-1, :, :]
self.assertTrue(mx.array_equal(c1, c2))
a = mx.random.uniform(shape=(8, 32))
mat = mx.tri(32)
for t in [mx.float16, mx.bfloat16]:
a_t = a.astype(t)
mat_t = mat.astype(t)
out = mx.cumsum(a_t, axis=-1)
expected = (mat_t * a_t[:, None, :]).sum(axis=-1)
self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3))
def test_squeeze_expand(self):
a = mx.zeros((2, 1, 2, 1))
self.assertEqual(mx.squeeze(a).shape, (2, 2))