Add softmin, hardshrink, hardtanh (#1180)

---------

Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
Author: Nikhil Mehta
Date: 2024-06-04 15:48:18 -07:00 (committed by GitHub)
parent 83b11bc58d
commit 0b7d71fd2f
14 changed files with 110 additions and 20 deletions

View File

@@ -9,7 +9,6 @@ from time_utils import time_fn
def bench_gelu():
    def gelu(x):
        return x * (1 + mx.erf(x / math.sqrt(2))) / 2
@@ -51,7 +50,6 @@ def bench_gelu():
def bench_layernorm():
    weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    mx.eval(weight, bias)

View File

@@ -54,7 +54,6 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(

View File

@@ -35,7 +35,6 @@ def run_bench(system_size):
def time_fft():
    with mx.stream(mx.cpu):
        cpu_bandwidths = run_bench(system_size=int(2**22))

View File

@@ -10,7 +10,6 @@ SEQ_INCREMENT = 50
def time_self_attention_primitives():
    mx.random.seed(3)
    B = 2
    H = 38
@@ -32,7 +31,6 @@ def time_self_attention_primitives():
def time_self_attention_sdpa():
    mx.random.seed(3)
    B = 2
    H = 38

View File

@@ -17,6 +17,8 @@ simple functions.
gelu_approx
gelu_fast_approx
glu
hard_shrink
hard_tanh
hardswish
leaky_relu
log_sigmoid
@@ -29,6 +31,7 @@ simple functions.
sigmoid
silu
softmax
softmin
softplus
softshrink
step

View File

@@ -21,10 +21,15 @@ Layers
Dropout3d
Embedding
GELU
GLU
GroupNorm
GRU
HardShrink
HardTanh
Hardswish
InstanceNorm
LayerNorm
LeakyReLU
Linear
LSTM
MaxPool1d
@@ -36,13 +41,19 @@ Layers
QuantizedLinear
RMSNorm
ReLU
ReLU6
RNN
RoPE
SELU
Sequential
SiLU
SinusoidalPositionalEncoding
Softmin
Softshrink
Softsign
Softmax
Softplus
Step
Tanh
Transformer
Upsample
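
The HardShrink, HardTanh, and Softmin entries added to this list are module-style activations that slot in like any other layer. A minimal sketch of how they are used, assuming the usual mlx.core / mlx.nn imports (the model itself is hypothetical, not part of this change):

import mlx.core as mx
import mlx.nn as nn

model = nn.Sequential(
    nn.Linear(8, 16),
    nn.HardTanh(),    # clamps hidden activations to [-1, 1]
    nn.Linear(16, 4),
    nn.Softmin(),     # normalizes outputs, weighting smaller logits more heavily
)
y = model(mx.random.normal((2, 8)))  # y has shape (2, 4); each row sums to 1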

View File

@@ -6,7 +6,9 @@ from mlx.nn.layers.activations import (
    GELU,
    GLU,
    SELU,
    HardShrink,
    Hardswish,
    HardTanh,
    LeakyReLU,
    LogSigmoid,
    LogSoftmax,
@@ -17,6 +19,7 @@ from mlx.nn.layers.activations import (
    Sigmoid,
    SiLU,
    Softmax,
    Softmin,
    Softplus,
    Softshrink,
    Softsign,
@@ -28,6 +31,8 @@ from mlx.nn.layers.activations import (
    gelu_approx,
    gelu_fast_approx,
    glu,
    hard_shrink,
    hard_tanh,
    hardswish,
    leaky_relu,
    log_sigmoid,
@@ -40,6 +45,7 @@ from mlx.nn.layers.activations import (
    sigmoid,
    silu,
    softmax,
    softmin,
    softplus,
    softshrink,
    softsign,

View File

@@ -286,6 +286,38 @@ def hardswish(x):
    return x * mx.minimum(max_x_3, 6) / 6


@partial(mx.compile, shapeless=True)
def hard_tanh(x, min_val=-1.0, max_val=1.0):
    r"""Applies the HardTanh function.

    Applies :math:`\max(\min(x, \text{max\_val}), \text{min\_val})` element-wise.
    """
    return mx.minimum(mx.maximum(x, min_val), max_val)


@partial(mx.compile, shapeless=True)
def hard_shrink(x, lambd=0.5):
    r"""Applies the HardShrink activation function.

    .. math::
        \text{hardshrink}(x) = \begin{cases}
        x & \text{if } x > \lambda \\
        x & \text{if } x < -\lambda \\
        0 & \text{otherwise}
        \end{cases}
    """
    return mx.where(mx.abs(x) > lambd, x, 0)


@partial(mx.compile, shapeless=True)
def softmin(x, axis=-1):
    r"""Applies the Softmin function.

    Applies :math:`\frac{e^{-x_i}}{\sum_j e^{-x_j}}` element-wise.
    """
    return mx.softmax(-x, axis=axis)


def tanh(x):
    """Applies the hyperbolic tangent function.
@@ -579,3 +611,30 @@ class SELU(Module):

    See :func:`selu` for the functional equivalent.
    """


@_make_activation_module(hard_tanh)
class HardTanh(Module):
    r"""Applies the HardTanh function.

    See :func:`hard_tanh` for the functional equivalent.
    """


@_make_activation_module(hard_shrink)
class HardShrink(Module):
    r"""Applies the HardShrink function.

    See :func:`hard_shrink` for the functional equivalent.

    Args:
        lambd: the :math:`\lambda` value for Hardshrink. Default: ``0.5``
    """


@_make_activation_module(softmin)
class Softmin(Module):
    r"""Applies the Softmin function.

    See :func:`softmin` for the functional equivalent.
    """

View File

@@ -950,7 +950,6 @@ class TestBlas(mlx_tests.MLXTestCase):
lhs_indices=lhs_indices,
rhs_indices=rhs_indices,
):
a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)
b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)
@@ -1066,7 +1065,6 @@
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
def test_gather_matmul_grad(self):
lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)

View File

@@ -382,7 +382,6 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
def test_compile_kwargs(self):
@mx.compile
def fun(x, y, z):
return x + y + z
@@ -479,7 +478,6 @@
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
def test_compile_with_constant(self):
# Test float
@partial(mx.compile)
def fun(x, y):
@@ -582,7 +580,6 @@
self.assertEqual(counter[0], 2)
def test_compile_inf(self):
@mx.compile
def fun(x):
return mx.isinf(x + 2)
@@ -591,7 +588,6 @@
self.assertEqual(out.item(), False)
def test_unsupported_input_types(self):
class MyClass:
value = 1

View File

@@ -672,7 +672,6 @@ class TestConv(mlx_tests.MLXTestCase):
np_dtype=np.float32,
atol=1e-5,
):
with self.subTest(
in_shape=in_shape,
wt_shape=wt_shape,
@@ -684,7 +683,6 @@
flip=flip,
np_dtype=np_dtype,
):
scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)
@@ -710,7 +708,6 @@
def conv_general_pt(
inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip
):
C = inp.size()[1]
ndim = inp.ndim - 2
map_ints = lambda x: [x] * ndim if isinstance(x, int) else x

View File

@@ -15,7 +15,6 @@ def mlx_primitives_sdpa(q, k, v, scale):
# SDPA for GQA (n_heads > n_kv_heads, n_kv_heads > 1, n_heads % n_kv_heads == 0)
def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
n_repeats = q.shape[1] // k.shape[1]
# borrowing kv cache tiling from mlx-examples/llms/mistral/mistral.py
@@ -34,7 +33,6 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
def test_fast_sdpa(self):
# Not yet supported:
# * K pre-transposed in kernel, V pre-transposed in kernel
np.random.seed(0)

View File

@@ -7,7 +7,6 @@ import mlx_tests
class TestMetal(mlx_tests.MLXTestCase):
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_memory_info(self):
old_limit = mx.metal.set_cache_limit(0)

View File

@@ -55,9 +55,7 @@ class TestBase(mlx_tests.MLXTestCase):
        m.apply_to_modules(assert_training)

    def test_module_attributes(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.val = None
@@ -806,6 +804,15 @@
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)

    def test_softmin(self):
        x = mx.array([1.0, 2.0, 3.0])
        y = nn.softmin(x)
        epsilon = 1e-4
        expected_y = mx.array([0.6652, 0.2447, 0.0900])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)

    def test_softplus(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.softplus(x)
@@ -899,6 +906,28 @@
        out = nn.glu(x)
        self.assertEqualArray(out, y)

    def test_hard_tanh(self):
        x = mx.array([1.0, -2.0, 0.0, 0.5, 2.0])
        y = nn.hard_tanh(x)
        expected_y = mx.array([1.0, -1.0, 0.0, 0.5, 1.0])
        self.assertTrue(mx.array_equal(y, expected_y))
        self.assertEqual(y.shape, (5,))
        self.assertEqual(y.dtype, mx.float32)

    def test_hard_shrink(self):
        x = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
        y = nn.hard_shrink(x)
        expected_y = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
        self.assertTrue(mx.array_equal(y, expected_y))
        self.assertEqual(y.shape, (5,))
        self.assertEqual(y.dtype, mx.float32)

        y = nn.hard_shrink(x, lambd=1.0)
        expected_y = mx.array([1.0, 0.0, 0.0, 0.0, -1.5])
        self.assertTrue(mx.array_equal(y, expected_y))
        self.assertEqual(y.shape, (5,))
        self.assertEqual(y.dtype, mx.float32)

    def test_rope(self):
        for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
            rope = nn.RoPE(4, **kwargs)
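
As a sanity check outside the test suite, the relationships the new code relies on can be exercised directly; a sketch under standard mlx imports (not part of the change):

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((3, 5))

# softmin is implemented as softmax of the negated input
assert mx.allclose(nn.softmin(x), mx.softmax(-x, axis=-1))
# each row of softmin is a probability distribution
assert mx.allclose(nn.softmin(x).sum(axis=-1), mx.ones(3))
# the module wrappers defer to the functional forms
assert mx.array_equal(nn.HardTanh()(x), nn.hard_tanh(x))
assert mx.array_equal(nn.HardShrink()(x), nn.hard_shrink(x))
assert mx.array_equal(nn.Softmin()(x), nn.softmin(x))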