mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
format
This commit is contained in:
parent
3713832e5e
commit
a315af8981
@ -6,6 +6,7 @@ from mlx.nn.layers.activations import (
|
|||||||
GELU,
|
GELU,
|
||||||
GLU,
|
GLU,
|
||||||
SELU,
|
SELU,
|
||||||
|
ClippedSiLU,
|
||||||
HardShrink,
|
HardShrink,
|
||||||
Hardswish,
|
Hardswish,
|
||||||
HardTanh,
|
HardTanh,
|
||||||
@ -18,7 +19,6 @@ from mlx.nn.layers.activations import (
|
|||||||
ReLU6,
|
ReLU6,
|
||||||
Sigmoid,
|
Sigmoid,
|
||||||
SiLU,
|
SiLU,
|
||||||
ClippedSiLU,
|
|
||||||
Softmax,
|
Softmax,
|
||||||
Softmin,
|
Softmin,
|
||||||
Softplus,
|
Softplus,
|
||||||
@ -27,6 +27,7 @@ from mlx.nn.layers.activations import (
|
|||||||
Step,
|
Step,
|
||||||
Tanh,
|
Tanh,
|
||||||
celu,
|
celu,
|
||||||
|
clipped_silu,
|
||||||
elu,
|
elu,
|
||||||
gelu,
|
gelu,
|
||||||
gelu_approx,
|
gelu_approx,
|
||||||
@ -45,7 +46,6 @@ from mlx.nn.layers.activations import (
|
|||||||
selu,
|
selu,
|
||||||
sigmoid,
|
sigmoid,
|
||||||
silu,
|
silu,
|
||||||
clipped_silu,
|
|
||||||
softmax,
|
softmax,
|
||||||
softmin,
|
softmin,
|
||||||
softplus,
|
softplus,
|
||||||
|
@ -955,7 +955,7 @@ class TestLayers(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||||
self.assertEqual(y.shape, (3,))
|
self.assertEqual(y.shape, (3,))
|
||||||
self.assertEqual(y.dtype, mx.float32)
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
def test_silu(self):
|
def test_silu(self):
|
||||||
x = mx.array([1.0, -1.0, 0.0])
|
x = mx.array([1.0, -1.0, 0.0])
|
||||||
y = nn.silu(x)
|
y = nn.silu(x)
|
||||||
@ -983,12 +983,12 @@ class TestLayers(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||||
self.assertEqual(y.shape, (3,))
|
self.assertEqual(y.shape, (3,))
|
||||||
self.assertEqual(y.dtype, mx.float32)
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
x_extreme = mx.array([200.0, -200.0])
|
x_extreme = mx.array([200.0, -200.0])
|
||||||
y_clipped = nn.clipped_silu(x_extreme, a_min=-50, a_max=50)
|
y_clipped = nn.clipped_silu(x_extreme, a_min=-50, a_max=50)
|
||||||
expected_clipped = mx.array([50.0, -50.0])
|
expected_clipped = mx.array([50.0, -50.0])
|
||||||
self.assertTrue(mx.all(mx.abs(y_clipped - expected_clipped) < epsilon))
|
self.assertTrue(mx.all(mx.abs(y_clipped - expected_clipped) < epsilon))
|
||||||
|
|
||||||
y_custom = nn.ClippedSiLU(a_min=-10, a_max=10)(x_extreme)
|
y_custom = nn.ClippedSiLU(a_min=-10, a_max=10)(x_extreme)
|
||||||
expected_custom = mx.array([10.0, -10.0])
|
expected_custom = mx.array([10.0, -10.0])
|
||||||
self.assertTrue(mx.all(mx.abs(y_custom - expected_custom) < epsilon))
|
self.assertTrue(mx.all(mx.abs(y_custom - expected_custom) < epsilon))
|
||||||
|
Loading…
Reference in New Issue
Block a user