Add softmin, hardshrink, hardtanh (#1180)

---------

Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
Author: Nikhil Mehta
Date: 2024-06-04 15:48:18 -07:00
Committed by: GitHub
Parent: 83b11bc58d
Commit: 0b7d71fd2f
14 changed files with 110 additions and 20 deletions
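The three new activations land in mlx.nn. A minimal usage sketch of what the commit adds, with expected outputs taken from the tests below (treat exact defaults, such as hard_tanh's bounds, as assumptions rather than documented API):

import mlx.core as mx
import mlx.nn as nn

# softmin is softmax over negated inputs: small entries get large weights.
nn.softmin(mx.array([1.0, 2.0, 3.0]))  # ~[0.6652, 0.2447, 0.0900]

# hard_tanh clips elementwise, into [-1, 1] by default.
nn.hard_tanh(mx.array([1.0, -2.0, 0.0, 0.5, 2.0]))  # [1.0, -1.0, 0.0, 0.5, 1.0]

# hard_shrink zeroes entries whose magnitude falls below a threshold.
nn.hard_shrink(mx.array([1.0, -0.5, 0.0, 0.5, -1.5]), lambd=1.0)  # [1.0, 0.0, 0.0, 0.0, -1.5]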


@@ -950,7 +950,6 @@ class TestBlas(mlx_tests.MLXTestCase):
                lhs_indices=lhs_indices,
                rhs_indices=rhs_indices,
            ):
                a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)
                b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)

@@ -1066,7 +1065,6 @@ class TestBlas(mlx_tests.MLXTestCase):
        self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))

    def test_gather_matmul_grad(self):
        lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
        rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)


@@ -382,7 +382,6 @@ class TestCompile(mlx_tests.MLXTestCase):
        self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))

    def test_compile_kwargs(self):
        @mx.compile
        def fun(x, y, z):
            return x + y + z

@@ -479,7 +478,6 @@ class TestCompile(mlx_tests.MLXTestCase):
        self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))

    def test_compile_with_constant(self):
        # Test float
        @partial(mx.compile)
        def fun(x, y):

@@ -582,7 +580,6 @@ class TestCompile(mlx_tests.MLXTestCase):
        self.assertEqual(counter[0], 2)

    def test_compile_inf(self):
        @mx.compile
        def fun(x):
            return mx.isinf(x + 2)

@@ -591,7 +588,6 @@ class TestCompile(mlx_tests.MLXTestCase):
        self.assertEqual(out.item(), False)

    def test_unsupported_input_types(self):
        class MyClass:
            value = 1


@@ -672,7 +672,6 @@ class TestConv(mlx_tests.MLXTestCase):
            np_dtype=np.float32,
            atol=1e-5,
        ):
            with self.subTest(
                in_shape=in_shape,
                wt_shape=wt_shape,

@@ -684,7 +683,6 @@ class TestConv(mlx_tests.MLXTestCase):
                flip=flip,
                np_dtype=np_dtype,
            ):
                scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
                in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
                wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)

@@ -710,7 +708,6 @@ class TestConv(mlx_tests.MLXTestCase):
        def conv_general_pt(
            inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip
        ):
            C = inp.size()[1]
            ndim = inp.ndim - 2
            map_ints = lambda x: [x] * ndim if isinstance(x, int) else x


@@ -15,7 +15,6 @@ def mlx_primitives_sdpa(q, k, v, scale):

# SDPA for GQA (n_heads > n_kv_heads, n_kv_heads > 1, n_heads % n_kv_heads == 0)
def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
    n_repeats = q.shape[1] // k.shape[1]

    # borrowing kv cache tiling from mlx-examples/llms/mistral/mistral.py
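A sketch of one way that tiling step can look, assuming q, k, v are laid out as (B, n_heads, L, D) and following the mistral.py pattern the comment cites; the repeat helper name is illustrative, not taken from this diff:

B, n_heads = q.shape[0], q.shape[1]
L = k.shape[2]

def repeat(a):
    # Tile each kv head n_repeats times along a new axis, then fold the
    # tiles back into a full (B, n_heads, L, D) layout.
    a = mx.concatenate([mx.expand_dims(a, 2)] * n_repeats, axis=2)
    return a.reshape([B, n_heads, L, -1])

k, v = map(repeat, (k, v))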
@@ -34,7 +33,6 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale):

class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
    def test_fast_sdpa(self):
        # Not yet supported:
        # * K pre-transposed in kernel, V pre-transposed in kernel
        np.random.seed(0)


@@ -7,7 +7,6 @@ import mlx_tests

class TestMetal(mlx_tests.MLXTestCase):
    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_memory_info(self):
        old_limit = mx.metal.set_cache_limit(0)


@@ -55,9 +55,7 @@ class TestBase(mlx_tests.MLXTestCase):
        m.apply_to_modules(assert_training)

    def test_module_attributes(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.val = None

@@ -806,6 +804,15 @@ class TestLayers(mlx_tests.MLXTestCase):
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)

    def test_softmin(self):
        x = mx.array([1.0, 2.0, 3.0])
        y = nn.softmin(x)
        epsilon = 1e-4
        expected_y = mx.array([0.6652, 0.2447, 0.0900])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)
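The expected values follow from softmin being softmax over negated inputs; a reference sketch (the axis keyword here is my own, not from the diff):

def softmin_ref(x, axis=-1):
    # exp(-x) / sum(exp(-x)); for [1, 2, 3] this is ~[0.6652, 0.2447, 0.0900].
    return mx.softmax(-x, axis=axis)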
    def test_softplus(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.softplus(x)

@@ -899,6 +906,28 @@ class TestLayers(mlx_tests.MLXTestCase):
        out = nn.glu(x)
        self.assertEqualArray(out, y)

    def test_hard_tanh(self):
        x = mx.array([1.0, -2.0, 0.0, 0.5, 2.0])
        y = nn.hard_tanh(x)
        expected_y = mx.array([1.0, -1.0, 0.0, 0.5, 1.0])
        self.assertTrue(mx.array_equal(y, expected_y))
        self.assertEqual(y.shape, (5,))
        self.assertEqual(y.dtype, mx.float32)
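hard_tanh is elementwise clipping; a sketch consistent with the expected values above (the min_val/max_val parameter names and defaults are an assumption based on this test):

def hard_tanh_ref(x, min_val=-1.0, max_val=1.0):
    # Clamp into [min_val, max_val]: [1, -2, 0, 0.5, 2] -> [1, -1, 0, 0.5, 1].
    return mx.minimum(mx.maximum(x, min_val), max_val)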
    def test_hard_shrink(self):
        x = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
        y = nn.hard_shrink(x)
        expected_y = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
        self.assertTrue(mx.array_equal(y, expected_y))
        self.assertEqual(y.shape, (5,))
        self.assertEqual(y.dtype, mx.float32)

        y = nn.hard_shrink(x, lambd=1.0)
        expected_y = mx.array([1.0, 0.0, 0.0, 0.0, -1.5])
        self.assertTrue(mx.array_equal(y, expected_y))
        self.assertEqual(y.shape, (5,))
        self.assertEqual(y.dtype, mx.float32)
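Note that the lambd=1.0 check keeps 1.0 itself, so the threshold comparison is inclusive of the boundary; a sketch matching that behavior (the default lambd=0.5 is inferred from convention, not shown in this diff):

def hard_shrink_ref(x, lambd=0.5):
    # Keep entries with |x| >= lambd, zero the rest; the inclusive boundary
    # is why 1.0 survives the lambd=1.0 check above.
    return mx.where(mx.abs(x) >= lambd, x, 0.0)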
    def test_rope(self):
        for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
            rope = nn.RoPE(4, **kwargs)