diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 79eebf4f9..52be03d5c 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -84,9 +84,9 @@ void binary_op( } // Launch up to 3D grid of threads - int dim0 = ndim > 0 ? shape[ndim - 1] : 1; - int dim1 = ndim > 1 ? shape[ndim - 2] : 1; - int rest = out.size() / (dim0 * dim1); + size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; + size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; + size_t rest = out.size() / (dim0 * dim1); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); if (thread_group_size != 1024) { throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py index 75e9aac60..106e11cbd 100644 --- a/python/mlx/nn/layers/activations.py +++ b/python/mlx/nn/layers/activations.py @@ -22,7 +22,6 @@ def sigmoid(x): \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)} """ return mx.sigmoid(x) - def relu(x): @@ -89,10 +88,12 @@ def gelu_fast_approx(x): """ return x * mx.sigmoid(1.773 * x) + @_make_activation_module class Sigmoid(Module): pass + @_make_activation_module(relu) class ReLU(Module): pass diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index d7ae6228c..65b9daed5 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1305,6 +1305,11 @@ class TestOps(mlx_tests.MLXTestCase): d_np = np.take(b_mx, np.arange(kth), axis=axis) self.assertTrue(np.all(d_np <= c_mx)) + def test_large_binary(self): + a = mx.ones([1000, 2147484], mx.int8) + b = mx.ones([2147484], mx.int8) + self.assertEqual((a + b)[0, 0].item(), 2) + if __name__ == "__main__": unittest.main()