Add softmin, hardshrink, hardtanh (#1180)

---------

Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
Author: Nikhil Mehta
Date:   2024-06-04 15:48:18 -07:00
Committed by: GitHub
parent 83b11bc58d
commit 0b7d71fd2f
14 changed files with 110 additions and 20 deletions
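For context, the three activations named in the title compute the following. This is a minimal sketch assuming the standard PyTorch-style definitions and defaults (lambd=0.5 for hardshrink, clamp range [-1, 1] for hardtanh); the exact MLX signatures added in this commit may differ.

    import mlx.core as mx

    def softmin(x, axis=-1):
        # softmin(x) = softmax(-x): puts the largest weight on the smallest entry.
        return mx.softmax(-x, axis=axis)

    def hardshrink(x, lambd=0.5):
        # Zero out entries with |x| <= lambd; pass the rest through unchanged.
        return mx.where(mx.abs(x) > lambd, x, 0.0)

    def hardtanh(x, min_val=-1.0, max_val=1.0):
        # Clamp x elementwise to [min_val, max_val].
        return mx.clip(x, min_val, max_val)

    x = mx.array([-2.0, -0.3, 0.0, 0.4, 3.0])
    print(softmin(x))    # heaviest weight on the -2.0 entry
    print(hardshrink(x)) # [-2.0, 0.0, 0.0, 0.0, 3.0]
    print(hardtanh(x))   # [-1.0, -0.3, 0.0, 0.4, 1.0]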

@@ -672,7 +672,6 @@ class TestConv(mlx_tests.MLXTestCase):
            np_dtype=np.float32,
            atol=1e-5,
        ):
            with self.subTest(
                in_shape=in_shape,
                wt_shape=wt_shape,
@@ -684,7 +683,6 @@ class TestConv(mlx_tests.MLXTestCase):
                flip=flip,
                np_dtype=np_dtype,
            ):
                scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
                in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
                wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)
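As an aside, the scale in the hunk above is the usual 1 / sqrt(fan_in) initialization. A worked example, assuming MLX's channels-last weight layout (out_channels, *kernel_dims, in_channels), so that fan_in = prod(wt_shape[1:]):

    import math
    import numpy as np

    wt_shape = (16, 3, 3, 4)  # hypothetical (O, kH, kW, I) conv weight shape
    scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))  # fan_in = 3*3*4 = 36, scale = 1/6
    wt_np = np.random.normal(0.0, scale, wt_shape).astype(np.float32)
    print(round(scale, 4))  # 0.1667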
@@ -710,7 +708,6 @@ class TestConv(mlx_tests.MLXTestCase):
        def conv_general_pt(
            inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip
        ):
            C = inp.size()[1]
            ndim = inp.ndim - 2
            map_ints = lambda x: [x] * ndim if isinstance(x, int) else x
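The map_ints lambda in the last hunk normalizes scalar convolution parameters (stride, padding, dilations) to one value per spatial dimension. A standalone illustration with a hypothetical ndim:

    ndim = 2  # e.g. inp.ndim - 2 for a 2-D convolution
    map_ints = lambda x: [x] * ndim if isinstance(x, int) else x

    print(map_ints(1))       # [1, 1]  -- scalar broadcast to every spatial dim
    print(map_ints([2, 3]))  # [2, 3]  -- sequences pass through unchanged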