Add softmin, hardshrink, hardtanh (#1180)

---------

Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
This commit is contained in:
Nikhil Mehta
2024-06-04 15:48:18 -07:00
committed by GitHub
parent 83b11bc58d
commit 0b7d71fd2f
14 changed files with 110 additions and 20 deletions

View File

@@ -382,7 +382,6 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
def test_compile_kwargs(self):
@mx.compile
def fun(x, y, z):
return x + y + z
@@ -479,7 +478,6 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
def test_compile_with_constant(self):
# Test float
@partial(mx.compile)
def fun(x, y):
@@ -582,7 +580,6 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertEqual(counter[0], 2)
def test_compile_inf(self):
@mx.compile
def fun(x):
return mx.isinf(x + 2)
@@ -591,7 +588,6 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertEqual(out.item(), False)
def test_unsupported_input_types(self):
class MyClass:
value = 1