Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 01:17:26 +08:00)

Commit 7170e5f40b: Merge b3c1aaafd2 into cad5c0241c

@@ -224,6 +224,13 @@ def relu6(x):
     mx.eval(y)
 
 
+def relu_squared(x):
+    y = x
+    for i in range(100):
+        y = nn.relu_squared(y)
+    mx.eval(y)
+
+
 def softplus(x):
     y = x
     for i in range(100):

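For orientation, the MLX benchmark above chains 100 applications of the new activation and then forces MLX's lazy graph with mx.eval. A minimal standalone timing sketch (the script's own bench helper, argument parsing, and input setup are not reproduced here; the 32x16x1024 shape is taken from the comparison script further down):

    import time

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.random.uniform(shape=(32, 16, 1024))
    mx.eval(x)  # materialize the input before timing

    tic = time.perf_counter()
    y = x
    for _ in range(100):
        y = nn.relu_squared(y)
    mx.eval(y)  # force the lazy computation graph to run
    toc = time.perf_counter()
    print(f"relu_squared x100: {1e3 * (toc - tic):.3f} ms")
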
@@ -458,6 +465,9 @@ if __name__ == "__main__":
     elif args.benchmark == "relu6":
         print(bench(relu6, x))
 
+    elif args.benchmark == "relu_squared":
+        print(bench(relu_squared, x))
+
     elif args.benchmark == "celu":
         print(bench(celu, x))
 

@@ -157,6 +157,15 @@ def relu6(x):
     sync_if_needed(x)
 
 
+@torch.no_grad()
+def relu_squared(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.relu(y)
+        y = torch.square(y)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def softplus(x):
     y = x

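PyTorch has no built-in relu_squared, so the baseline above composes relu and square inside the loop. A quick standalone sanity check (not part of the diff) that this composition really is max(x, 0)^2:

    import torch

    x = torch.tensor([-1.0, 0.0, 1.0, 2.0, 3.0])
    composed = torch.square(torch.nn.functional.relu(x))
    direct = torch.clamp(x, min=0.0) ** 2
    assert torch.equal(composed, direct)  # both compute max(x, 0)^2
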
@@ -407,6 +416,9 @@ if __name__ == "__main__":
     elif args.benchmark == "relu6":
         print(bench(relu6, x))
 
+    elif args.benchmark == "relu_squared":
+        print(bench(relu_squared, x))
+
     elif args.benchmark == "softplus":
         print(bench(softplus, x))
 

@@ -207,6 +207,8 @@ if __name__ == "__main__":
     compare_filtered("elu --size 32x16x1024 --cpu")
     compare_filtered("relu6 --size 32x16x1024")
     compare_filtered("relu6 --size 32x16x1024 --cpu")
+    compare_filtered("relu_squared --size 32x16x1024")
+    compare_filtered("relu_squared --size 32x16x1024 --cpu")
     compare_filtered("softplus --size 32x16x1024")
     compare_filtered("softplus --size 32x16x1024 --cpu")
     compare_filtered("celu --size 32x16x1024")

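The filter strings above are forwarded to the MLX and PyTorch benchmark scripts, so the new case can presumably also be run directly against either backend; the exact script names and argument handling below are an assumption based on the existing entries:

    python bench_mlx.py relu_squared --size 32x16x1024
    python bench_torch.py relu_squared --size 32x16x1024 --cpu
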
@@ -28,6 +28,7 @@ simple functions.
    prelu
    relu
    relu6
+   relu_squared
    selu
    sigmoid
    silu

@@ -51,6 +51,7 @@ Layers
    RMSNorm
    ReLU
    ReLU6
+   ReLUSquared
    RNN
    RoPE
    SELU

@@ -16,6 +16,7 @@ from mlx.nn.layers.activations import (
     PReLU,
     ReLU,
     ReLU6,
+    ReLUSquared,
     Sigmoid,
     SiLU,
     Softmax,

@@ -41,6 +42,7 @@ from mlx.nn.layers.activations import (
     prelu,
     relu,
     relu6,
+    relu_squared,
     selu,
     sigmoid,
     silu,

@@ -71,6 +71,17 @@ def relu6(x):
     return mx.minimum(mx.maximum(x, 0), 6.0)
 
 
+@partial(mx.compile, shapeless=True)
+def relu_squared(x):
+    r"""Applies the Rectified Linear Unit squared.
+
+    Applies :math:`\max(x, 0)^2` element wise.
+
+    Reference: https://arxiv.org/abs/2109.08668v2
+    """
+    return relu(x).square()
+
+
 @partial(mx.compile, shapeless=True)
 def softmax(x, axis=-1):
     r"""Applies the Softmax function.

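A brief usage sketch of the new functional form once this change is applied; it also shows the equivalence to composing relu and square by hand (the example values are illustrative, not taken from the diff):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([-2.0, -0.5, 0.0, 1.5, 3.0])
    y = nn.relu_squared(x)  # max(x, 0)^2, element wise -> [0, 0, 0, 2.25, 9]

    # Same result as composing the existing primitives:
    assert mx.array_equal(y, mx.square(nn.relu(x)))
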
@@ -420,6 +431,18 @@ class ReLU6(Module):
     """
 
 
+@_make_activation_module(relu_squared)
+class ReLUSquared(Module):
+    r"""Applies the Rectified Linear Unit squared.
+
+    Applies :math:`\max(x, 0)^2` element wise.
+
+    Reference: https://arxiv.org/abs/2109.08668v2
+
+    See :func:`relu_squared` for the functional equivalent.
+    """
+
+
 @_make_activation_module(softmax)
 class Softmax(Module):
     r"""Applies the Softmax function.

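And a sketch of the module form used as a drop-in activation layer, assuming the usual MLX container and linear layers (the toy model below is illustrative only):

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Sequential(
        nn.Linear(16, 32),
        nn.ReLUSquared(),  # takes no arguments, like the other activation modules here
        nn.Linear(32, 1),
    )
    out = model(mx.random.normal(shape=(4, 16)))
    print(out.shape)  # (4, 1)
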
@@ -855,6 +855,13 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_relu_squared(self):
+        x = mx.array([-1.0, 0.0, 1.0, 2.0, 3.0])
+        y = nn.relu_squared(x)
+        self.assertTrue(mx.array_equal(y, mx.array([0.0, 0.0, 1.0, 4.0, 9.0])))
+        self.assertEqual(y.shape, (5,))
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_leaky_relu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.leaky_relu(x)