mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 09:58:17 +08:00
Add softmin, hardshrink, hardtanh (#1180)
--------- Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
This commit is contained in:
@@ -382,7 +382,6 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
|
||||
|
||||
def test_compile_kwargs(self):
|
||||
|
||||
@mx.compile
|
||||
def fun(x, y, z):
|
||||
return x + y + z
|
||||
@@ -479,7 +478,6 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
|
||||
|
||||
def test_compile_with_constant(self):
|
||||
|
||||
# Test float
|
||||
@partial(mx.compile)
|
||||
def fun(x, y):
|
||||
@@ -582,7 +580,6 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(counter[0], 2)
|
||||
|
||||
def test_compile_inf(self):
|
||||
|
||||
@mx.compile
|
||||
def fun(x):
|
||||
return mx.isinf(x + 2)
|
||||
@@ -591,7 +588,6 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(out.item(), False)
|
||||
|
||||
def test_unsupported_input_types(self):
|
||||
|
||||
class MyClass:
|
||||
value = 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user