Add softmin, hardshrink, hardtanh (#1180)
Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>

Parent: 83b11bc58d
Commit: 0b7d71fd2f
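For orientation, a minimal usage sketch of the three activations this commit adds (functional forms shown; the module classes HardTanh, HardShrink, and Softmin added below wrap the same functions, and the example values are chosen away from the shrink threshold):

import mlx.core as mx
import mlx.nn as nn

x = mx.array([1.0, -2.0, 0.0, 0.25, 2.0])

print(nn.hard_tanh(x))               # clamp to [-1, 1]: [1.0, -1.0, 0.0, 0.25, 1.0]
print(nn.hard_shrink(x, lambd=0.5))  # zero small-magnitude entries: [1.0, -2.0, 0.0, 0.0, 2.0]
print(nn.softmin(mx.array([1.0, 2.0, 3.0])))  # softmax of -x: ~[0.665, 0.245, 0.090]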
@@ -9,7 +9,6 @@ from time_utils import time_fn

def bench_gelu():

    def gelu(x):
        return x * (1 + mx.erf(x / math.sqrt(2))) / 2

@@ -51,7 +50,6 @@ def bench_gelu():

def bench_layernorm():

    weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    mx.eval(weight, bias)
@@ -54,7 +54,6 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):

def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):

    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
@@ -35,7 +35,6 @@ def run_bench(system_size):

def time_fft():

    with mx.stream(mx.cpu):
        cpu_bandwidths = run_bench(system_size=int(2**22))
@@ -10,7 +10,6 @@ SEQ_INCREMENT = 50

def time_self_attention_primitives():

    mx.random.seed(3)
    B = 2
    H = 38

@@ -32,7 +31,6 @@ def time_self_attention_primitives():

def time_self_attention_sdpa():

    mx.random.seed(3)
    B = 2
    H = 38
@@ -17,6 +17,8 @@ simple functions.

   gelu_approx
   gelu_fast_approx
   glu
   hard_shrink
   hard_tanh
   hardswish
   leaky_relu
   log_sigmoid

@@ -29,6 +31,7 @@ simple functions.

   sigmoid
   silu
   softmax
   softmin
   softplus
   softshrink
   step
@@ -21,10 +21,15 @@ Layers

   Dropout3d
   Embedding
   GELU
   GLU
   GroupNorm
   GRU
   HardShrink
   HardTanh
   Hardswish
   InstanceNorm
   LayerNorm
   LeakyReLU
   Linear
   LSTM
   MaxPool1d

@@ -36,13 +41,19 @@ Layers

   QuantizedLinear
   RMSNorm
   ReLU
   ReLU6
   RNN
   RoPE
   SELU
   Sequential
   SiLU
   SinusoidalPositionalEncoding
   Softmin
   Softshrink
   Softsign
   Softmax
   Softplus
   Step
   Tanh
   Transformer
   Upsample
@@ -6,7 +6,9 @@ from mlx.nn.layers.activations import (
    GELU,
    GLU,
    SELU,
    HardShrink,
    Hardswish,
    HardTanh,
    LeakyReLU,
    LogSigmoid,
    LogSoftmax,

@@ -17,6 +19,7 @@ from mlx.nn.layers.activations import (
    Sigmoid,
    SiLU,
    Softmax,
    Softmin,
    Softplus,
    Softshrink,
    Softsign,

@@ -28,6 +31,8 @@ from mlx.nn.layers.activations import (
    gelu_approx,
    gelu_fast_approx,
    glu,
    hard_shrink,
    hard_tanh,
    hardswish,
    leaky_relu,
    log_sigmoid,

@@ -40,6 +45,7 @@ from mlx.nn.layers.activations import (
    sigmoid,
    silu,
    softmax,
    softmin,
    softplus,
    softshrink,
    softsign,
@@ -286,6 +286,38 @@ def hardswish(x):
    return x * mx.minimum(max_x_3, 6) / 6


@partial(mx.compile, shapeless=True)
def hard_tanh(x, min_val=-1.0, max_val=1.0):
    r"""Applies the HardTanh function.

    Applies :math:`\max(\min(x, \text{max\_val}), \text{min\_val})` element-wise.
    """
    return mx.minimum(mx.maximum(x, min_val), max_val)


@partial(mx.compile, shapeless=True)
def hard_shrink(x, lambd=0.5):
    r"""Applies the HardShrink activation function.

    .. math::
        \text{hardshrink}(x) = \begin{cases}
        x & \text{if } x > \lambda \\
        x & \text{if } x < -\lambda \\
        0 & \text{otherwise}
        \end{cases}
    """
    return mx.where(mx.abs(x) > lambd, x, 0)


@partial(mx.compile, shapeless=True)
def softmin(x, axis=-1):
    r"""Applies the Softmin function.

    Applies :math:`\frac{e^{-x_i}}{\sum_j e^{-x_j}}` element-wise.
    """
    return mx.softmax(-x, axis=axis)


def tanh(x):
    """Applies the hyperbolic tangent function.
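A short sketch, assuming only the functions defined in this hunk, illustrating the element-wise formulas (custom bounds for hard_tanh, the default lambd=0.5 for hard_shrink; inputs are arbitrary example values):

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-3.0, -0.2, 0.0, 0.2, 3.0])

# max(min(x, max_val), min_val) with non-default bounds
print(nn.hard_tanh(x, min_val=-2.0, max_val=2.0))  # [-2.0, -0.2, 0.0, 0.2, 2.0]

# entries with magnitude below the threshold are zeroed
print(nn.hard_shrink(x))  # [-3.0, 0.0, 0.0, 0.0, 3.0]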
@@ -579,3 +611,30 @@ class SELU(Module):

    See :func:`selu` for the functional equivalent.
    """


@_make_activation_module(hard_tanh)
class HardTanh(Module):
    r"""Applies the HardTanh function.

    See :func:`hard_tanh` for the functional equivalent.
    """


@_make_activation_module(hard_shrink)
class HardShrink(Module):
    r"""Applies the HardShrink function.

    See :func:`hard_shrink` for the functional equivalent.

    Args:
        lambd: the :math:`\lambda` value for Hardshrink. Default: ``0.5``
    """


@_make_activation_module(softmin)
class Softmin(Module):
    r"""Applies the Softmin function.

    See :func:`softmin` for the functional equivalent.
    """
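A hedged sketch of how the new module classes compose with existing layers (nn.Sequential and nn.Linear already exist in MLX; the layer sizes and batch shape here are arbitrary):

import mlx.core as mx
import mlx.nn as nn

# The module classes wrap the functional forms via _make_activation_module,
# so they drop into a model like any other activation layer.
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.HardTanh(),
    nn.Linear(16, 4),
    nn.Softmin(),
)
y = model(mx.random.uniform(shape=(2, 8)))
print(y.shape)  # (2, 4)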
@@ -950,7 +950,6 @@ class TestBlas(mlx_tests.MLXTestCase):
                lhs_indices=lhs_indices,
                rhs_indices=rhs_indices,
            ):
                a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)
                b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)

@@ -1066,7 +1065,6 @@ class TestBlas(mlx_tests.MLXTestCase):
        self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))

    def test_gather_matmul_grad(self):
        lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
        rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)
@@ -382,7 +382,6 @@ class TestCompile(mlx_tests.MLXTestCase):
        self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))

    def test_compile_kwargs(self):
        @mx.compile
        def fun(x, y, z):
            return x + y + z

@@ -479,7 +478,6 @@ class TestCompile(mlx_tests.MLXTestCase):
        self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))

    def test_compile_with_constant(self):
        # Test float
        @partial(mx.compile)
        def fun(x, y):

@@ -582,7 +580,6 @@ class TestCompile(mlx_tests.MLXTestCase):
        self.assertEqual(counter[0], 2)

    def test_compile_inf(self):
        @mx.compile
        def fun(x):
            return mx.isinf(x + 2)

@@ -591,7 +588,6 @@ class TestCompile(mlx_tests.MLXTestCase):
        self.assertEqual(out.item(), False)

    def test_unsupported_input_types(self):
        class MyClass:
            value = 1
@@ -672,7 +672,6 @@ class TestConv(mlx_tests.MLXTestCase):
            np_dtype=np.float32,
            atol=1e-5,
        ):
            with self.subTest(
                in_shape=in_shape,
                wt_shape=wt_shape,

@@ -684,7 +683,6 @@ class TestConv(mlx_tests.MLXTestCase):
                flip=flip,
                np_dtype=np_dtype,
            ):
                scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
                in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
                wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)

@@ -710,7 +708,6 @@ class TestConv(mlx_tests.MLXTestCase):
        def conv_general_pt(
            inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip
        ):
            C = inp.size()[1]
            ndim = inp.ndim - 2
            map_ints = lambda x: [x] * ndim if isinstance(x, int) else x
@@ -15,7 +15,6 @@ def mlx_primitives_sdpa(q, k, v, scale):

# SDPA for GQA (n_heads > n_kv_heads, n_kv_heads > 1, n_heads % n_kv_heads == 0)
def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
    n_repeats = q.shape[1] // k.shape[1]

    # borrowing kv cache tiling from mlx-examples/llms/mistral/mistral.py

@@ -34,7 +33,6 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale):

class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
    def test_fast_sdpa(self):
        # Not yet supported:
        # * K pre-transposed in kernel, V pre-transposed in kernel
        np.random.seed(0)
@@ -7,7 +7,6 @@ import mlx_tests


class TestMetal(mlx_tests.MLXTestCase):
    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_memory_info(self):
        old_limit = mx.metal.set_cache_limit(0)
@@ -55,9 +55,7 @@ class TestBase(mlx_tests.MLXTestCase):
        m.apply_to_modules(assert_training)

    def test_module_attributes(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.val = None
@@ -806,6 +804,15 @@ class TestLayers(mlx_tests.MLXTestCase):
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)

    def test_softmin(self):
        x = mx.array([1.0, 2.0, 3.0])
        y = nn.softmin(x)
        epsilon = 1e-4
        expected_y = mx.array([0.6652, 0.2447, 0.0900])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)

    def test_softplus(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.softplus(x)
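The expected constants in test_softmin follow directly from the identity softmin(x) = softmax(-x); a quick check using only the standard library:

import math

exps = [math.exp(-v) for v in [1.0, 2.0, 3.0]]
total = sum(exps)
print([round(e / total, 4) for e in exps])  # [0.6652, 0.2447, 0.09]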
@@ -899,6 +906,28 @@ class TestLayers(mlx_tests.MLXTestCase):
        out = nn.glu(x)
        self.assertEqualArray(out, y)

    def test_hard_tanh(self):
        x = mx.array([1.0, -2.0, 0.0, 0.5, 2.0])
        y = nn.hard_tanh(x)
        expected_y = mx.array([1.0, -1.0, 0.0, 0.5, 1.0])
        self.assertTrue(mx.array_equal(y, expected_y))
        self.assertEqual(y.shape, (5,))
        self.assertEqual(y.dtype, mx.float32)

    def test_hard_shrink(self):
        x = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
        y = nn.hard_shrink(x)
        expected_y = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
        self.assertTrue(mx.array_equal(y, expected_y))
        self.assertEqual(y.shape, (5,))
        self.assertEqual(y.dtype, mx.float32)

        y = nn.hard_shrink(x, lambd=1.0)
        expected_y = mx.array([1.0, 0.0, 0.0, 0.0, -1.5])
        self.assertTrue(mx.array_equal(y, expected_y))
        self.assertEqual(y.shape, (5,))
        self.assertEqual(y.dtype, mx.float32)

    def test_rope(self):
        for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
            rope = nn.RoPE(4, **kwargs)
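As a further illustration of the lambd parameter exercised in test_hard_shrink, a small sketch with inputs chosen well away from the threshold (example values, not from the test suite):

import mlx.core as mx
import mlx.nn as nn

x = mx.array([2.0, -0.3, 0.7])
print(nn.hard_shrink(x))             # default lambd=0.5 -> [2.0, 0.0, 0.7]
print(nn.hard_shrink(x, lambd=1.0))  # larger threshold  -> [2.0, 0.0, 0.0]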