commit 0b7d71fd2f (parent 83b11bc58d)
Add softmin, hardshrink, hardtanh (#1180)

Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
@@ -9,7 +9,6 @@ from time_utils import time_fn
 
 
 def bench_gelu():
-
     def gelu(x):
         return x * (1 + mx.erf(x / math.sqrt(2))) / 2
 
@@ -51,7 +50,6 @@ def bench_gelu():
 
 
 def bench_layernorm():
-
     weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
     bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
     mx.eval(weight, bias)
@@ -54,7 +54,6 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
 
 
 def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
-
     scale = 1.0 / math.sqrt(kH * kH * C)
     a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
     b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
@@ -35,7 +35,6 @@ def run_bench(system_size):
 
 
 def time_fft():
-
     with mx.stream(mx.cpu):
         cpu_bandwidths = run_bench(system_size=int(2**22))
 
@@ -10,7 +10,6 @@ SEQ_INCREMENT = 50
 
 
 def time_self_attention_primitives():
-
     mx.random.seed(3)
     B = 2
     H = 38
@@ -32,7 +31,6 @@ def time_self_attention_primitives():
 
 
 def time_self_attention_sdpa():
-
     mx.random.seed(3)
     B = 2
     H = 38
@@ -17,6 +17,8 @@ simple functions.
    gelu_approx
    gelu_fast_approx
    glu
+   hard_shrink
+   hard_tanh
    hardswish
    leaky_relu
    log_sigmoid
@@ -29,6 +31,7 @@ simple functions.
    sigmoid
    silu
    softmax
+   softmin
    softplus
    softshrink
    step
@@ -21,10 +21,15 @@ Layers
    Dropout3d
    Embedding
    GELU
+   GLU
    GroupNorm
    GRU
+   HardShrink
+   HardTanh
+   Hardswish
    InstanceNorm
    LayerNorm
+   LeakyReLU
    Linear
    LSTM
    MaxPool1d
@@ -36,13 +41,19 @@ Layers
    QuantizedLinear
    RMSNorm
    ReLU
+   ReLU6
    RNN
    RoPE
    SELU
    Sequential
    SiLU
    SinusoidalPositionalEncoding
+   Softmin
    Softshrink
+   Softsign
+   Softmax
+   Softplus
    Step
+   Tanh
    Transformer
    Upsample
@@ -6,7 +6,9 @@ from mlx.nn.layers.activations import (
     GELU,
     GLU,
     SELU,
+    HardShrink,
     Hardswish,
+    HardTanh,
     LeakyReLU,
     LogSigmoid,
     LogSoftmax,
@@ -17,6 +19,7 @@ from mlx.nn.layers.activations import (
     Sigmoid,
     SiLU,
     Softmax,
+    Softmin,
     Softplus,
     Softshrink,
     Softsign,
@@ -28,6 +31,8 @@ from mlx.nn.layers.activations import (
     gelu_approx,
     gelu_fast_approx,
     glu,
+    hard_shrink,
+    hard_tanh,
     hardswish,
     leaky_relu,
     log_sigmoid,
@@ -40,6 +45,7 @@ from mlx.nn.layers.activations import (
     sigmoid,
     silu,
     softmax,
+    softmin,
     softplus,
     softshrink,
     softsign,
@@ -286,6 +286,38 @@ def hardswish(x):
     return x * mx.minimum(max_x_3, 6) / 6
 
 
+@partial(mx.compile, shapeless=True)
+def hard_tanh(x, min_val=-1.0, max_val=1.0):
+    r"""Applies the HardTanh function.
+
+    Applies :math:`\max(\min(x, \text{max\_val}), \text{min\_val})` element-wise.
+    """
+    return mx.minimum(mx.maximum(x, min_val), max_val)
+
+
+@partial(mx.compile, shapeless=True)
+def hard_shrink(x, lambd=0.5):
+    r"""Applies the HardShrink activation function.
+
+    .. math::
+        \text{hardshrink}(x) = \begin{cases}
+        x & \text{if } x > \lambda \\
+        x & \text{if } x < -\lambda \\
+        0 & \text{otherwise}
+        \end{cases}
+    """
+    return mx.where(mx.abs(x) > lambd, x, 0)
+
+
+@partial(mx.compile, shapeless=True)
+def softmin(x, axis=-1):
+    r"""Applies the Softmin function.
+
+    Applies :math:`\frac{e^{-x_i}}{\sum_j e^{-x_j}}` element-wise.
+    """
+    return mx.softmax(-x, axis=axis)
+
+
 def tanh(x):
     """Applies the hyperbolic tangent function.
 
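For orientation, a small usage sketch of the three new functional activations (not part of the commit; it assumes they are re-exported from mlx.nn, as the `from mlx.nn.layers.activations import (...)` hunks above indicate):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([1.0, -2.0, 0.0, 0.5, 2.0])
    nn.hard_tanh(x)                              # clamp to [-1, 1] -> [1.0, -1.0, 0.0, 0.5, 1.0]
    nn.hard_shrink(mx.array([2.0, -0.2, 0.3]))   # zero where |x| <= 0.5 -> [2.0, 0.0, 0.0]
    nn.softmin(mx.array([1.0, 2.0, 3.0]))        # softmax(-x) -> approx. [0.665, 0.245, 0.090]

Since each function is wrapped in `@partial(mx.compile, shapeless=True)`, the compiled graph is not specialized on the input shape and is reused across shapes.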
@@ -579,3 +611,30 @@ class SELU(Module):
 
     See :func:`selu` for the functional equivalent.
     """
+
+
+@_make_activation_module(hard_tanh)
+class HardTanh(Module):
+    r"""Applies the HardTanh function.
+
+    See :func:`hard_tanh` for the functional equivalent.
+    """
+
+
+@_make_activation_module(hard_shrink)
+class HardShrink(Module):
+    r"""Applies the HardShrink function.
+
+    See :func:`hard_shrink` for the functional equivalent.
+
+    Args:
+        lambd: the :math:`\lambda` value for Hardshrink. Default: ``0.5``
+    """
+
+
+@_make_activation_module(softmin)
+class Softmin(Module):
+    r"""Applies the Softmin function.
+
+    See :func:`softmin` for the functional equivalent.
+    """
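The class forms are thin Module wrappers produced by `_make_activation_module`, so an instance simply applies the corresponding function and can sit in a layer stack like any other activation. A minimal sketch (illustrative only; the layer sizes are arbitrary):

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(8, 16), nn.HardTanh(), nn.Linear(16, 4), nn.Softmin())
    probs = model(mx.random.normal((2, 8)))   # each row sums to 1 along the last axis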
@@ -950,7 +950,6 @@ class TestBlas(mlx_tests.MLXTestCase):
                 lhs_indices=lhs_indices,
                 rhs_indices=rhs_indices,
             ):
-
                 a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)
                 b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)
 
@@ -1066,7 +1065,6 @@ class TestBlas(mlx_tests.MLXTestCase):
                 self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
 
     def test_gather_matmul_grad(self):
-
         lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
         rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)
 
@@ -382,7 +382,6 @@ class TestCompile(mlx_tests.MLXTestCase):
         self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
 
     def test_compile_kwargs(self):
-
         @mx.compile
         def fun(x, y, z):
             return x + y + z
@@ -479,7 +478,6 @@ class TestCompile(mlx_tests.MLXTestCase):
         self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
 
     def test_compile_with_constant(self):
-
         # Test float
         @partial(mx.compile)
         def fun(x, y):
@@ -582,7 +580,6 @@ class TestCompile(mlx_tests.MLXTestCase):
         self.assertEqual(counter[0], 2)
 
     def test_compile_inf(self):
-
         @mx.compile
         def fun(x):
             return mx.isinf(x + 2)
@@ -591,7 +588,6 @@ class TestCompile(mlx_tests.MLXTestCase):
         self.assertEqual(out.item(), False)
 
     def test_unsupported_input_types(self):
-
         class MyClass:
             value = 1
 
@@ -672,7 +672,6 @@ class TestConv(mlx_tests.MLXTestCase):
            np_dtype=np.float32,
            atol=1e-5,
        ):
-
            with self.subTest(
                in_shape=in_shape,
                wt_shape=wt_shape,
@@ -684,7 +683,6 @@ class TestConv(mlx_tests.MLXTestCase):
                flip=flip,
                np_dtype=np_dtype,
            ):
-
                scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
                in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
                wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)
@@ -710,7 +708,6 @@ class TestConv(mlx_tests.MLXTestCase):
        def conv_general_pt(
            inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip
        ):
-
            C = inp.size()[1]
            ndim = inp.ndim - 2
            map_ints = lambda x: [x] * ndim if isinstance(x, int) else x
@@ -15,7 +15,6 @@ def mlx_primitives_sdpa(q, k, v, scale):
 
 # SDPA for GQA (n_heads > n_kv_heads, n_kv_heads > 1, n_heads % n_kv_heads == 0)
 def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
-
     n_repeats = q.shape[1] // k.shape[1]
 
     # borrowing kv cache tiling from mlx-examples/llms/mistral/mistral.py
@@ -34,7 +33,6 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
 
 class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
     def test_fast_sdpa(self):
-
         # Not yet supported:
         # * K pre-transposed in kernel, V pre-transposed in kernel
         np.random.seed(0)
@@ -7,7 +7,6 @@ import mlx_tests
 
 
 class TestMetal(mlx_tests.MLXTestCase):
-
     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_memory_info(self):
         old_limit = mx.metal.set_cache_limit(0)
@@ -55,9 +55,7 @@ class TestBase(mlx_tests.MLXTestCase):
         m.apply_to_modules(assert_training)
 
     def test_module_attributes(self):
-
         class Model(nn.Module):
-
             def __init__(self):
                 super().__init__()
                 self.val = None
@@ -806,6 +804,15 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_softmin(self):
+        x = mx.array([1.0, 2.0, 3.0])
+        y = nn.softmin(x)
+        epsilon = 1e-4
+        expected_y = mx.array([0.6652, 0.2447, 0.0900])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, (3,))
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_softplus(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.softplus(x)
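The expected values in test_softmin follow directly from softmin(x) = softmax(-x); a quick hand check in plain Python (illustrative only):

    import math

    x = [1.0, 2.0, 3.0]
    exps = [math.exp(-v) for v in x]   # [0.3679, 0.1353, 0.0498]
    total = sum(exps)                  # 0.5530
    [e / total for e in exps]          # [0.6652, 0.2447, 0.0900], as asserted above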
@@ -899,6 +906,28 @@ class TestLayers(mlx_tests.MLXTestCase):
         out = nn.glu(x)
         self.assertEqualArray(out, y)
 
+    def test_hard_tanh(self):
+        x = mx.array([1.0, -2.0, 0.0, 0.5, 2.0])
+        y = nn.hard_tanh(x)
+        expected_y = mx.array([1.0, -1.0, 0.0, 0.5, 1.0])
+        self.assertTrue(mx.array_equal(y, expected_y))
+        self.assertEqual(y.shape, (5,))
+        self.assertEqual(y.dtype, mx.float32)
+
+    def test_hard_shrink(self):
+        x = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
+        y = nn.hard_shrink(x)
+        expected_y = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
+        self.assertTrue(mx.array_equal(y, expected_y))
+        self.assertEqual(y.shape, (5,))
+        self.assertEqual(y.dtype, mx.float32)
+
+        y = nn.hard_shrink(x, lambd=1.0)
+        expected_y = mx.array([1.0, 0.0, 0.0, 0.0, -1.5])
+        self.assertTrue(mx.array_equal(y, expected_y))
+        self.assertEqual(y.shape, (5,))
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_rope(self):
         for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
             rope = nn.RoPE(4, **kwargs)
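As with softmin, the test_hard_tanh expectations can be checked by hand: clamping [1.0, -2.0, 0.0, 0.5, 2.0] to the default range [-1.0, 1.0] leaves the in-range entries untouched and saturates the rest, giving [1.0, -1.0, 0.0, 0.5, 1.0].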