Fix nd ternary on GPU (#1746)

This commit is contained in:
Awni Hannun 2025-01-03 11:52:17 -08:00 committed by GitHub
parent c9d30aa6ac
commit 259025100e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 1 deletions

View File

@ -11,7 +11,7 @@
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, int) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, int) \

View File

@ -1755,6 +1755,14 @@ class TestOps(mlx_tests.MLXTestCase):
np.where,
)
# Check non-contiguous input with several dimensions
shape = [1, 2, 2, 3, 3, 1]
strides = [16, 4, 1, 4, 1, 1]
x = mx.ones(shape=(1, 4, 4, 1))
x = mx.as_strided(x, shape, strides)
out = mx.where(mx.isnan(x), mx.nan, x)
self.assertTrue(mx.allclose(out, mx.ones_like(out)))
def test_nan_to_num(self):
a = mx.array([6, float("inf"), 2, 0])
out_mx = mx.nan_to_num(a)