Add softmin, hardshrink, hardtanh (#1180)

---------

Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
This commit is contained in:
Nikhil Mehta
2024-06-04 15:48:18 -07:00
committed by GitHub
parent 83b11bc58d
commit 0b7d71fd2f
14 changed files with 110 additions and 20 deletions

View File

@@ -950,7 +950,6 @@ class TestBlas(mlx_tests.MLXTestCase):
lhs_indices=lhs_indices,
rhs_indices=rhs_indices,
):
a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)
b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)
@@ -1066,7 +1065,6 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
def test_gather_matmul_grad(self):
lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)