diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal
index 761faa9b8..cceb53061 100644
--- a/mlx/backend/metal/kernels/ternary.metal
+++ b/mlx/backend/metal/kernels/ternary.metal
@@ -11,7 +11,7 @@
 #define instantiate_ternary_all(op, tname, type)                      \
   instantiate_kernel("v_" #op #tname, ternary_v, type, op)            \
   instantiate_kernel("v2_" #op #tname, ternary_v2, type, op)          \
-  instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, int)  \
+  instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int)  \
   instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int)  \
   instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int)  \
   instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, int)  \
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 6d01ee5a2..12be3d5fd 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1755,6 +1755,14 @@ class TestOps(mlx_tests.MLXTestCase):
             np.where,
         )
 
+        # Check non-contiguous input with several dimensions
+        shape = [1, 2, 2, 3, 3, 1]
+        strides = [16, 4, 1, 4, 1, 1]
+        x = mx.ones(shape=(1, 4, 4, 1))
+        x = mx.as_strided(x, shape, strides)
+        out = mx.where(mx.isnan(x), mx.nan, x)
+        self.assertTrue(mx.allclose(out, mx.ones_like(out)))
+
     def test_nan_to_num(self):
         a = mx.array([6, float("inf"), 2, 0])
         out_mx = mx.nan_to_num(a)