mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
format
This commit is contained in:
parent
3713832e5e
commit
a315af8981
@ -6,6 +6,7 @@ from mlx.nn.layers.activations import (
|
||||
GELU,
|
||||
GLU,
|
||||
SELU,
|
||||
ClippedSiLU,
|
||||
HardShrink,
|
||||
Hardswish,
|
||||
HardTanh,
|
||||
@ -18,7 +19,6 @@ from mlx.nn.layers.activations import (
|
||||
ReLU6,
|
||||
Sigmoid,
|
||||
SiLU,
|
||||
ClippedSiLU,
|
||||
Softmax,
|
||||
Softmin,
|
||||
Softplus,
|
||||
@ -27,6 +27,7 @@ from mlx.nn.layers.activations import (
|
||||
Step,
|
||||
Tanh,
|
||||
celu,
|
||||
clipped_silu,
|
||||
elu,
|
||||
gelu,
|
||||
gelu_approx,
|
||||
@ -45,7 +46,6 @@ from mlx.nn.layers.activations import (
|
||||
selu,
|
||||
sigmoid,
|
||||
silu,
|
||||
clipped_silu,
|
||||
softmax,
|
||||
softmin,
|
||||
softplus,
|
||||
|
@ -955,7 +955,7 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
|
||||
def test_silu(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.silu(x)
|
||||
@ -983,12 +983,12 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
|
||||
x_extreme = mx.array([200.0, -200.0])
|
||||
y_clipped = nn.clipped_silu(x_extreme, a_min=-50, a_max=50)
|
||||
expected_clipped = mx.array([50.0, -50.0])
|
||||
self.assertTrue(mx.all(mx.abs(y_clipped - expected_clipped) < epsilon))
|
||||
|
||||
|
||||
y_custom = nn.ClippedSiLU(a_min=-10, a_max=10)(x_extreme)
|
||||
expected_custom = mx.array([10.0, -10.0])
|
||||
self.assertTrue(mx.all(mx.abs(y_custom - expected_custom) < epsilon))
|
||||
|
Loading…
Reference in New Issue
Block a user