diff --git a/benchmarks/python/compile_bench.py b/benchmarks/python/compile_bench.py
index 0d5d9f61d..a204ee439 100644
--- a/benchmarks/python/compile_bench.py
+++ b/benchmarks/python/compile_bench.py
@@ -9,7 +9,6 @@ from time_utils import time_fn
 
 
 def bench_gelu():
-
     def gelu(x):
         return x * (1 + mx.erf(x / math.sqrt(2))) / 2
 
@@ -51,7 +50,6 @@ def bench_gelu():
 
 
 def bench_layernorm():
-
     weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
     bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
     mx.eval(weight, bias)
diff --git a/benchmarks/python/conv_bench.py b/benchmarks/python/conv_bench.py
index 7ca4d8a7c..f421963ad 100644
--- a/benchmarks/python/conv_bench.py
+++ b/benchmarks/python/conv_bench.py
@@ -54,7 +54,6 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
 
 
 def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
-
     scale = 1.0 / math.sqrt(kH * kH * C)
     a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
     b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
diff --git a/benchmarks/python/fft_bench.py b/benchmarks/python/fft_bench.py
index 391a28aec..865d6408f 100644
--- a/benchmarks/python/fft_bench.py
+++ b/benchmarks/python/fft_bench.py
@@ -35,7 +35,6 @@ def run_bench(system_size):
 
 
 def time_fft():
-
     with mx.stream(mx.cpu):
         cpu_bandwidths = run_bench(system_size=int(2**22))
 
diff --git a/benchmarks/python/sdpa_bench.py b/benchmarks/python/sdpa_bench.py
index fcaad7b6a..0e81b6dfa 100644
--- a/benchmarks/python/sdpa_bench.py
+++ b/benchmarks/python/sdpa_bench.py
@@ -10,7 +10,6 @@ SEQ_INCREMENT = 50
 
 
 def time_self_attention_primitives():
-
     mx.random.seed(3)
     B = 2
     H = 38
@@ -32,7 +31,6 @@ def time_self_attention_primitives():
 
 
 def time_self_attention_sdpa():
-
     mx.random.seed(3)
     B = 2
     H = 38
diff --git a/docs/src/python/nn/functions.rst b/docs/src/python/nn/functions.rst
index db276afdf..f1077776b 100644
--- a/docs/src/python/nn/functions.rst
+++ b/docs/src/python/nn/functions.rst
@@ -17,6 +17,8 @@ simple functions.
    gelu_approx
    gelu_fast_approx
    glu
+   hard_shrink
+   hard_tanh
    hardswish
    leaky_relu
    log_sigmoid
@@ -29,6 +31,7 @@ simple functions.
    sigmoid
    silu
    softmax
+   softmin
    softplus
    softshrink
    step
diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index cbbbb5c3b..d06e009e8 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -21,10 +21,15 @@ Layers
    Dropout3d
    Embedding
    GELU
+   GLU
    GroupNorm
    GRU
+   HardShrink
+   HardTanh
+   Hardswish
    InstanceNorm
    LayerNorm
+   LeakyReLU
    Linear
    LSTM
    MaxPool1d
@@ -36,13 +41,19 @@ Layers
    QuantizedLinear
    RMSNorm
    ReLU
+   ReLU6
    RNN
    RoPE
    SELU
    Sequential
    SiLU
    SinusoidalPositionalEncoding
+   Softmin
    Softshrink
+   Softsign
+   Softmax
+   Softplus
    Step
+   Tanh
    Transformer
    Upsample
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index be4935626..f528c9908 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -6,7 +6,9 @@ from mlx.nn.layers.activations import (
     GELU,
     GLU,
     SELU,
+    HardShrink,
     Hardswish,
+    HardTanh,
     LeakyReLU,
     LogSigmoid,
     LogSoftmax,
@@ -17,6 +19,7 @@ from mlx.nn.layers.activations import (
     Sigmoid,
     SiLU,
     Softmax,
+    Softmin,
     Softplus,
     Softshrink,
     Softsign,
@@ -28,6 +31,8 @@ from mlx.nn.layers.activations import (
     gelu,
     gelu_approx,
     gelu_fast_approx,
     glu,
+    hard_shrink,
+    hard_tanh,
     hardswish,
     leaky_relu,
     log_sigmoid,
@@ -40,6 +45,7 @@ from mlx.nn.layers.activations import (
     sigmoid,
     silu,
     softmax,
+    softmin,
     softplus,
     softshrink,
     softsign,
diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index 88ff3476a..ff03a29ed 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -286,6 +286,38 @@ def hardswish(x):
     return x * mx.minimum(max_x_3, 6) / 6
 
 
+@partial(mx.compile, shapeless=True)
+def hard_tanh(x, min_val=-1.0, max_val=1.0):
+    r"""Applies the HardTanh function.
+
+    Applies :math:`\max(\min(x, \text{max\_val}), \text{min\_val})` element-wise.
+    """
+    return mx.minimum(mx.maximum(x, min_val), max_val)
+
+
+@partial(mx.compile, shapeless=True)
+def hard_shrink(x, lambd=0.5):
+    r"""Applies the HardShrink activation function.
+
+    .. math::
+        \text{hardshrink}(x) = \begin{cases}
+        x & \text{if } x > \lambda \\
+        x & \text{if } x < -\lambda \\
+        0 & \text{otherwise}
+        \end{cases}
+    """
+    return mx.where(mx.abs(x) > lambd, x, 0)
+
+
+@partial(mx.compile, shapeless=True)
+def softmin(x, axis=-1):
+    r"""Applies the Softmin function.
+
+    Applies :math:`\frac{e^{-x_i}}{\sum_j e^{-x_j}}` element-wise.
+    """
+    return mx.softmax(-x, axis=axis)
+
+
 def tanh(x):
     """Applies the hyperbolic tangent function.
 
@@ -579,3 +611,30 @@ class SELU(Module):
 
     See :func:`selu` for the functional equivalent.
     """
+
+
+@_make_activation_module(hard_tanh)
+class HardTanh(Module):
+    r"""Applies the HardTanh function.
+
+    See :func:`hard_tanh` for the functional equivalent.
+    """
+
+
+@_make_activation_module(hard_shrink)
+class HardShrink(Module):
+    r"""Applies the HardShrink function.
+
+    See :func:`hard_shrink` for the functional equivalent.
+
+    Args:
+        lambd: the :math:`\lambda` value for Hardshrink. Default: ``0.5``
+    """
+
+
+@_make_activation_module(softmin)
+class Softmin(Module):
+    r"""Applies the Softmin function.
+
+    See :func:`softmin` for the functional equivalent.
+    """
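
For context, a minimal usage sketch of the three new functional activations defined above (not part of the diff itself). It assumes they are re-exported at the `mlx.nn` level the same way the existing activations are once this change lands; the commented outputs follow directly from the definitions above.

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([-2.0, -0.25, 0.0, 0.75, 3.0])

    # hard_tanh clamps every element into [min_val, max_val] (defaults -1.0 and 1.0).
    print(nn.hard_tanh(x))               # [-1.0, -0.25, 0.0, 0.75, 1.0]

    # hard_shrink zeroes elements with |x| <= lambd and passes the rest through.
    print(nn.hard_shrink(x))             # [-2.0, 0.0, 0.0, 0.75, 3.0]
    print(nn.hard_shrink(x, lambd=1.0))  # [-2.0, 0.0, 0.0, 0.0, 3.0]

    # softmin(x) is softmax(-x): each row sums to one and weights the smallest entries most.
    print(nn.softmin(mx.array([1.0, 2.0, 3.0])))  # ~[0.665, 0.245, 0.090]
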
+ """ diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 20df4341e..b645011b2 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -950,7 +950,6 @@ class TestBlas(mlx_tests.MLXTestCase): lhs_indices=lhs_indices, rhs_indices=rhs_indices, ): - a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype) b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype) @@ -1066,7 +1065,6 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5)) def test_gather_matmul_grad(self): - lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32) rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32) diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 5018e47f4..fb058203e 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -382,7 +382,6 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2)) def test_compile_kwargs(self): - @mx.compile def fun(x, y, z): return x + y + z @@ -479,7 +478,6 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertTrue(mx.array_equal(fun(x2), cfun(x2))) def test_compile_with_constant(self): - # Test float @partial(mx.compile) def fun(x, y): @@ -582,7 +580,6 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertEqual(counter[0], 2) def test_compile_inf(self): - @mx.compile def fun(x): return mx.isinf(x + 2) @@ -591,7 +588,6 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertEqual(out.item(), False) def test_unsupported_input_types(self): - class MyClass: value = 1 diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 754f1727e..b54619671 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -672,7 +672,6 @@ class TestConv(mlx_tests.MLXTestCase): np_dtype=np.float32, atol=1e-5, ): - with self.subTest( in_shape=in_shape, wt_shape=wt_shape, @@ -684,7 +683,6 @@ class TestConv(mlx_tests.MLXTestCase): flip=flip, np_dtype=np_dtype, ): - scale = 1.0 / math.sqrt(np.prod(wt_shape[1:])) in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype) wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype) @@ -710,7 +708,6 @@ class TestConv(mlx_tests.MLXTestCase): def conv_general_pt( inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip ): - C = inp.size()[1] ndim = inp.ndim - 2 map_ints = lambda x: [x] * ndim if isinstance(x, int) else x diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 51b2c047c..4baed56d7 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -15,7 +15,6 @@ def mlx_primitives_sdpa(q, k, v, scale): # SDPA for GQA (n_heads > n_kv_heads, n_kv_heads > 1, n_heads % n_kv_heads == 0) def mlx_primitives_sdpa_with_gqa(q, k, v, scale): - n_repeats = q.shape[1] // k.shape[1] # borrowing kv cache tiling from mlx-examples/llms/mistral/mistral.py @@ -34,7 +33,6 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale): class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase): def test_fast_sdpa(self): - # Not yet supported: # * K pre-transposed in kernel, V pre-transposed in kernel np.random.seed(0) diff --git a/python/tests/test_metal.py b/python/tests/test_metal.py index bd091f8e2..be63b118e 100644 --- a/python/tests/test_metal.py +++ b/python/tests/test_metal.py @@ -7,7 +7,6 @@ import mlx_tests class TestMetal(mlx_tests.MLXTestCase): - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_memory_info(self): old_limit = 
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 25540d8d1..27537cac4 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -55,9 +55,7 @@ class TestBase(mlx_tests.MLXTestCase):
         m.apply_to_modules(assert_training)
 
     def test_module_attributes(self):
-
         class Model(nn.Module):
-
             def __init__(self):
                 super().__init__()
                 self.val = None
@@ -806,6 +804,15 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_softmin(self):
+        x = mx.array([1.0, 2.0, 3.0])
+        y = nn.softmin(x)
+        epsilon = 1e-4
+        expected_y = mx.array([0.6652, 0.2447, 0.0900])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, (3,))
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_softplus(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.softplus(x)
@@ -899,6 +906,28 @@ class TestLayers(mlx_tests.MLXTestCase):
         out = nn.glu(x)
         self.assertEqualArray(out, y)
 
+    def test_hard_tanh(self):
+        x = mx.array([1.0, -2.0, 0.0, 0.5, 2.0])
+        y = nn.hard_tanh(x)
+        expected_y = mx.array([1.0, -1.0, 0.0, 0.5, 1.0])
+        self.assertTrue(mx.array_equal(y, expected_y))
+        self.assertEqual(y.shape, (5,))
+        self.assertEqual(y.dtype, mx.float32)
+
+    def test_hard_shrink(self):
+        x = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
+        y = nn.hard_shrink(x)
+        expected_y = mx.array([1.0, 0.0, 0.0, 0.0, -1.5])
+        self.assertTrue(mx.array_equal(y, expected_y))
+        self.assertEqual(y.shape, (5,))
+        self.assertEqual(y.dtype, mx.float32)
+
+        y = nn.hard_shrink(x, lambd=1.0)
+        expected_y = mx.array([0.0, 0.0, 0.0, 0.0, -1.5])
+        self.assertTrue(mx.array_equal(y, expected_y))
+        self.assertEqual(y.shape, (5,))
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_rope(self):
         for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
             rope = nn.RoPE(4, **kwargs)
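
For context, a hypothetical end-to-end sketch (not part of the diff) of the new module wrappers registered above. The layer sizes and the `model` name are illustrative only; it assumes the classes are exposed as `mlx.nn.HardTanh` and `mlx.nn.Softmin` through the `__init__.py` changes in this patch.

    import mlx.core as mx
    import mlx.nn as nn

    # Toy network: HardTanh bounds the hidden activations, Softmin normalizes the output.
    model = nn.Sequential(
        nn.Linear(8, 16),
        nn.HardTanh(),
        nn.Linear(16, 4),
        nn.Softmin(),
    )

    x = mx.random.normal((2, 8))
    y = model(x)
    print(y.shape)             # (2, 4)
    print(mx.sum(y, axis=-1))  # each row sums to ~1 because softmin normalizes over the last axis
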