mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 09:58:17 +08:00
Add softmin, hardshrink, hardtanh (#1180)
--------- Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
This commit is contained in:
@@ -950,7 +950,6 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
lhs_indices=lhs_indices,
|
||||
rhs_indices=rhs_indices,
|
||||
):
|
||||
|
||||
a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)
|
||||
b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)
|
||||
|
||||
@@ -1066,7 +1065,6 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
|
||||
|
||||
def test_gather_matmul_grad(self):
|
||||
|
||||
lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
|
||||
rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user