Add softmin, hardshrink, hardtanh (#1180)

---------

Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
Authored by Nikhil Mehta on 2024-06-04 15:48:18 -07:00, committed by GitHub
parent 83b11bc58d
commit 0b7d71fd2f
14 changed files with 110 additions and 20 deletions


@@ -9,7 +9,6 @@ from time_utils import time_fn
def bench_gelu():
    def gelu(x):
        return x * (1 + mx.erf(x / math.sqrt(2))) / 2
@@ -51,7 +50,6 @@ def bench_gelu():
def bench_layernorm():
    weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    mx.eval(weight, bias)


@@ -54,7 +54,6 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(


@@ -35,7 +35,6 @@ def run_bench(system_size):
def time_fft():
    with mx.stream(mx.cpu):
        cpu_bandwidths = run_bench(system_size=int(2**22))


@@ -10,7 +10,6 @@ SEQ_INCREMENT = 50
def time_self_attention_primitives():
    mx.random.seed(3)
    B = 2
    H = 38
@@ -32,7 +31,6 @@ def time_self_attention_primitives():
def time_self_attention_sdpa():
    mx.random.seed(3)
    B = 2
    H = 38


@@ -17,6 +17,8 @@ simple functions.
   gelu_approx
   gelu_fast_approx
   glu
   hard_shrink
   hard_tanh
   hardswish
   leaky_relu
   log_sigmoid
@@ -29,6 +31,7 @@ simple functions.
   sigmoid
   silu
   softmax
   softmin
   softplus
   softshrink
   step


@@ -21,10 +21,15 @@ Layers
   Dropout3d
   Embedding
   GELU
   GLU
   GroupNorm
   GRU
   HardShrink
   HardTanh
   Hardswish
   InstanceNorm
   LayerNorm
   LeakyReLU
   Linear
   LSTM
   MaxPool1d
@@ -36,13 +41,19 @@ Layers
   QuantizedLinear
   RMSNorm
   ReLU
   ReLU6
   RNN
   RoPE
   SELU
   Sequential
   SiLU
   SinusoidalPositionalEncoding
   Softmin
   Softshrink
   Softsign
   Softmax
   Softplus
   Step
   Tanh
   Transformer
   Upsample


@@ -6,7 +6,9 @@ from mlx.nn.layers.activations import (
    GELU,
    GLU,
    SELU,
    HardShrink,
    Hardswish,
    HardTanh,
    LeakyReLU,
    LogSigmoid,
    LogSoftmax,
@@ -17,6 +19,7 @@ from mlx.nn.layers.activations import (
    Sigmoid,
    SiLU,
    Softmax,
    Softmin,
    Softplus,
    Softshrink,
    Softsign,
@@ -28,6 +31,8 @@ from mlx.nn.layers.activations import (
    gelu_approx,
    gelu_fast_approx,
    glu,
    hard_shrink,
    hard_tanh,
    hardswish,
    leaky_relu,
    log_sigmoid,
@@ -40,6 +45,7 @@ from mlx.nn.layers.activations import (
    sigmoid,
    silu,
    softmax,
    softmin,
    softplus,
    softshrink,
    softsign,


@@ -286,6 +286,38 @@ def hardswish(x):
    return x * mx.minimum(max_x_3, 6) / 6


@partial(mx.compile, shapeless=True)
def hard_tanh(x, min_val=-1.0, max_val=1.0):
    r"""Applies the HardTanh function.

    Applies :math:`\max(\min(x, \text{max\_val}), \text{min\_val})` element-wise.
    """
    return mx.minimum(mx.maximum(x, min_val), max_val)


@partial(mx.compile, shapeless=True)
def hard_shrink(x, lambd=0.5):
    r"""Applies the HardShrink activation function.

    .. math::
        \text{hardshrink}(x) = \begin{cases}
        x & \text{if } x > \lambda \\
        x & \text{if } x < -\lambda \\
        0 & \text{otherwise}
        \end{cases}
    """
    return mx.where(mx.abs(x) > lambd, x, 0)


@partial(mx.compile, shapeless=True)
def softmin(x, axis=-1):
    r"""Applies the Softmin function.

    Applies :math:`\frac{e^{-x_i}}{\sum_j e^{-x_j}}` element-wise.
    """
    return mx.softmax(-x, axis=axis)


def tanh(x):
    """Applies the hyperbolic tangent function.
@@ -579,3 +611,30 @@ class SELU(Module):
    See :func:`selu` for the functional equivalent.
    """


@_make_activation_module(hard_tanh)
class HardTanh(Module):
    r"""Applies the HardTanh function.

    See :func:`hard_tanh` for the functional equivalent.
    """


@_make_activation_module(hard_shrink)
class HardShrink(Module):
    r"""Applies the HardShrink function.

    See :func:`hard_shrink` for the functional equivalent.

    Args:
        lambd: the :math:`\lambda` value for Hardshrink. Default: ``0.5``
    """


@_make_activation_module(softmin)
class Softmin(Module):
    r"""Applies the Softmin function.

    See :func:`softmin` for the functional equivalent.
    """


@@ -950,7 +950,6 @@ class TestBlas(mlx_tests.MLXTestCase):
            lhs_indices=lhs_indices,
            rhs_indices=rhs_indices,
        ):
            a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)
            b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)
@@ -1066,7 +1065,6 @@ class TestBlas(mlx_tests.MLXTestCase):
        self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))

    def test_gather_matmul_grad(self):
        lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
        rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)


@@ -382,7 +382,6 @@ class TestCompile(mlx_tests.MLXTestCase):
        self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))

    def test_compile_kwargs(self):
        @mx.compile
        def fun(x, y, z):
            return x + y + z
@@ -479,7 +478,6 @@ class TestCompile(mlx_tests.MLXTestCase):
        self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))

    def test_compile_with_constant(self):
        # Test float
        @partial(mx.compile)
        def fun(x, y):
@@ -582,7 +580,6 @@ class TestCompile(mlx_tests.MLXTestCase):
        self.assertEqual(counter[0], 2)

    def test_compile_inf(self):
        @mx.compile
        def fun(x):
            return mx.isinf(x + 2)
@@ -591,7 +588,6 @@ class TestCompile(mlx_tests.MLXTestCase):
        self.assertEqual(out.item(), False)

    def test_unsupported_input_types(self):
        class MyClass:
            value = 1


@@ -672,7 +672,6 @@ class TestConv(mlx_tests.MLXTestCase):
            np_dtype=np.float32,
            atol=1e-5,
        ):
            with self.subTest(
                in_shape=in_shape,
                wt_shape=wt_shape,
@@ -684,7 +683,6 @@ class TestConv(mlx_tests.MLXTestCase):
                flip=flip,
                np_dtype=np_dtype,
            ):
                scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
                in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
                wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)
@@ -710,7 +708,6 @@ class TestConv(mlx_tests.MLXTestCase):
        def conv_general_pt(
            inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip
        ):
            C = inp.size()[1]
            ndim = inp.ndim - 2
            map_ints = lambda x: [x] * ndim if isinstance(x, int) else x


@@ -15,7 +15,6 @@ def mlx_primitives_sdpa(q, k, v, scale):
# SDPA for GQA (n_heads > n_kv_heads, n_kv_heads > 1, n_heads % n_kv_heads == 0)
def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
    n_repeats = q.shape[1] // k.shape[1]
    # borrowing kv cache tiling from mlx-examples/llms/mistral/mistral.py
@@ -34,7 +33,6 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
    def test_fast_sdpa(self):
        # Not yet supported:
        # * K pre-transposed in kernel, V pre-transposed in kernel
        np.random.seed(0)


@@ -7,7 +7,6 @@ import mlx_tests
class TestMetal(mlx_tests.MLXTestCase):
    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_memory_info(self):
        old_limit = mx.metal.set_cache_limit(0)


@@ -55,9 +55,7 @@ class TestBase(mlx_tests.MLXTestCase):
        m.apply_to_modules(assert_training)

    def test_module_attributes(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.val = None
@@ -806,6 +804,15 @@ class TestLayers(mlx_tests.MLXTestCase):
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)

    def test_softmin(self):
        x = mx.array([1.0, 2.0, 3.0])
        y = nn.softmin(x)
        epsilon = 1e-4
        expected_y = mx.array([0.6652, 0.2447, 0.0900])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)

    def test_softplus(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.softplus(x)
@@ -899,6 +906,28 @@ class TestLayers(mlx_tests.MLXTestCase):
        out = nn.glu(x)
        self.assertEqualArray(out, y)
    def test_hard_tanh(self):
        x = mx.array([1.0, -2.0, 0.0, 0.5, 2.0])
        y = nn.hard_tanh(x)
        expected_y = mx.array([1.0, -1.0, 0.0, 0.5, 1.0])
        self.assertTrue(mx.array_equal(y, expected_y))
        self.assertEqual(y.shape, (5,))
        self.assertEqual(y.dtype, mx.float32)
    def test_hard_shrink(self):
        # hard_shrink uses a strict threshold, so |x| == lambd is zeroed.
        x = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
        y = nn.hard_shrink(x)
        expected_y = mx.array([1.0, 0.0, 0.0, 0.0, -1.5])
        self.assertTrue(mx.array_equal(y, expected_y))
        self.assertEqual(y.shape, (5,))
        self.assertEqual(y.dtype, mx.float32)

        y = nn.hard_shrink(x, lambd=1.0)
        expected_y = mx.array([0.0, 0.0, 0.0, 0.0, -1.5])
        self.assertTrue(mx.array_equal(y, expected_y))
        self.assertEqual(y.shape, (5,))
        self.assertEqual(y.dtype, mx.float32)
    def test_rope(self):
        for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
            rope = nn.RoPE(4, **kwargs)