mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-19 07:31:26 +08:00)
parent aac2f9fb61
commit 76e63212ff
@@ -451,7 +451,7 @@ instantiate_scan_helper(sum_int32_int32, int32_t, int32_t, CumSu
 //instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
 instantiate_scan_helper(sum_float16_float16, half, half, CumSum, 4)
 instantiate_scan_helper(sum_float32_float32, float, float, CumSum, 4)
-//instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
+instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
 //instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum)
 //instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
 instantiate_scan_helper(prod_uint8_uint8, uint8_t, uint8_t, CumProd, 4)
@@ -464,7 +464,7 @@ instantiate_scan_helper(prod_int32_int32, int32_t, int32_t, CumP
 //instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
 instantiate_scan_helper(prod_float16_float16, half, half, CumProd, 4)
 instantiate_scan_helper(prod_float32_float32, float, float, CumProd, 4)
-//instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
+instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
 //instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd)
 //instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
 instantiate_scan_helper(max_uint8_uint8, uint8_t, uint8_t, CumMax, 4)
@@ -477,7 +477,7 @@ instantiate_scan_helper(max_int32_int32, int32_t, int32_t, CumMa
 //instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
 instantiate_scan_helper(max_float16_float16, half, half, CumMax, 4)
 instantiate_scan_helper(max_float32_float32, float, float, CumMax, 4)
-//instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
+instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
 //instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax)
 //instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
 instantiate_scan_helper(min_uint8_uint8, uint8_t, uint8_t, CumMin, 4)
@@ -490,5 +490,5 @@ instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMi
 //instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
 instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
 instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
-//instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
+instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
 //instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin)
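
Taken together, the hunks above un-comment the bfloat16_t instantiations of the scan kernels, so cumsum, cumprod, cummax and cummin each get a dedicated bfloat16 Metal kernel alongside the existing float16 and float32 ones. A minimal way to exercise the new kernels from Python (my own sketch, not part of the commit; the loose tolerances are an assumption based on bfloat16's 8-bit mantissa):

import mlx.core as mx

# Round the inputs to bfloat16 once so both paths see identical values.
a = mx.random.uniform(shape=(8, 32)).astype(mx.bfloat16)

for name, op in [("cumsum", mx.cumsum), ("cumprod", mx.cumprod),
                 ("cummax", mx.cummax), ("cummin", mx.cummin)]:
    out = op(a, axis=-1).astype(mx.float32)    # runs the bfloat16 scan kernel
    ref = op(a.astype(mx.float32), axis=-1)    # float32 reference scan
    print(name, mx.allclose(out, ref, rtol=5e-2, atol=1e-3).item())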

@@ -1564,6 +1564,15 @@ class TestOps(mlx_tests.MLXTestCase):
             c2 = mxop(a_mlx, axis=0, inclusive=False, reverse=True)[:-1, :, :]
             self.assertTrue(mx.array_equal(c1, c2))
 
+        a = mx.random.uniform(shape=(8, 32))
+        mat = mx.tri(32)
+        for t in [mx.float16, mx.bfloat16]:
+            a_t = a.astype(t)
+            mat_t = mat.astype(t)
+            out = mx.cumsum(a_t, axis=-1)
+            expected = (mat_t * a_t[:, None, :]).sum(axis=-1)
+            self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3))
+
     def test_squeeze_expand(self):
         a = mx.zeros((2, 1, 2, 1))
         self.assertEqual(mx.squeeze(a).shape, (2, 2))
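
The reference in the new test deliberately avoids the scan kernel it is checking: mx.tri(32) is a lower-triangular matrix of ones, so (mat_t * a_t[:, None, :]).sum(axis=-1) sums, for each output position j, the inputs up to and including j, which is exactly an inclusive cumsum computed through the ordinary multiply-and-reduce path. A small sketch of that identity (my own toy example, not from the commit):

import mlx.core as mx

a = mx.array([[1.0, 2.0, 3.0, 4.0]])  # (1, 4) toy input
mat = mx.tri(4)                        # (4, 4) lower-triangular ones

# Broadcasting gives shape (1, 4, 4); row j keeps a[:, :j + 1] and zeroes the rest.
masked = mat * a[:, None, :]
print(masked.sum(axis=-1))      # [[1, 3, 6, 10]]
print(mx.cumsum(a, axis=-1))    # the same inclusive cumulative sums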