Compare commits

...

7 Commits

43 changed files with 278 additions and 35 deletions

View File

@ -234,6 +234,7 @@ jobs:
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
build_release:
parameters:

View File

@ -224,6 +224,13 @@ def relu6(x):
mx.eval(y)
def relu_squared(x):
    """Benchmark body: chain 100 ``nn.relu_squared`` calls, then evaluate.

    The activation is applied to its own output 100 times in a row and the
    final array is handed to ``mx.eval``. Nothing is returned.
    """
    out = x
    for _ in range(100):
        out = nn.relu_squared(out)
    mx.eval(out)
def softplus(x):
y = x
for i in range(100):
@ -458,6 +465,9 @@ if __name__ == "__main__":
elif args.benchmark == "relu6":
print(bench(relu6, x))
elif args.benchmark == "relu_squared":
print(bench(relu_squared, x))
elif args.benchmark == "celu":
print(bench(celu, x))

View File

@ -157,6 +157,15 @@ def relu6(x):
sync_if_needed(x)
@torch.no_grad()
def relu_squared(x):
    """Benchmark body: apply ReLU followed by squaring, 100 times in a row.

    Each iteration feeds the previous result through
    ``torch.nn.functional.relu`` and then ``torch.square``. Afterwards
    ``sync_if_needed`` is called with the original input ``x``. Nothing is
    returned.
    """
    out = x
    for _ in range(100):
        out = torch.square(torch.nn.functional.relu(out))
    sync_if_needed(x)
@torch.no_grad()
def softplus(x):
y = x
@ -407,6 +416,9 @@ if __name__ == "__main__":
elif args.benchmark == "relu6":
print(bench(relu6, x))
elif args.benchmark == "relu_squared":
print(bench(relu_squared, x))
elif args.benchmark == "softplus":
print(bench(softplus, x))

View File

@ -207,6 +207,8 @@ if __name__ == "__main__":
compare_filtered("elu --size 32x16x1024 --cpu")
compare_filtered("relu6 --size 32x16x1024")
compare_filtered("relu6 --size 32x16x1024 --cpu")
compare_filtered("relu_squared --size 32x16x1024")
compare_filtered("relu_squared --size 32x16x1024 --cpu")
compare_filtered("softplus --size 32x16x1024")
compare_filtered("softplus --size 32x16x1024 --cpu")
compare_filtered("celu --size 32x16x1024")

View File

@ -28,6 +28,7 @@ simple functions.
prelu
relu
relu6
relu_squared
selu
sigmoid
silu

View File

@ -51,6 +51,7 @@ Layers
RMSNorm
ReLU
ReLU6
ReLUSquared
RNN
RoPE
SELU

View File

@ -16,6 +16,7 @@ from mlx.nn.layers.activations import (
PReLU,
ReLU,
ReLU6,
ReLUSquared,
Sigmoid,
SiLU,
Softmax,
@ -41,6 +42,7 @@ from mlx.nn.layers.activations import (
prelu,
relu,
relu6,
relu_squared,
selu,
sigmoid,
silu,

View File

@ -71,6 +71,17 @@ def relu6(x):
return mx.minimum(mx.maximum(x, 0), 6.0)
@partial(mx.compile, shapeless=True)
def relu_squared(x):
    r"""Applies the squared Rectified Linear Unit.

    Computes :math:`\max(x, 0)^2` element wise, i.e. the output of
    :func:`relu` multiplied by itself.

    Reference: https://arxiv.org/abs/2109.08668v2
    """
    rectified = relu(x)
    return rectified.square()
@partial(mx.compile, shapeless=True)
def softmax(x, axis=-1):
r"""Applies the Softmax function.
@ -420,6 +431,18 @@ class ReLU6(Module):
"""
@_make_activation_module(relu_squared)
class ReLUSquared(Module):
    r"""Applies the squared Rectified Linear Unit.

    Computes :math:`\max(x, 0)^2` element wise.

    Reference: https://arxiv.org/abs/2109.08668v2

    See :func:`relu_squared` for the functional equivalent.
    """
@_make_activation_module(softmax)
class Softmax(Module):
r"""Applies the Softmax function.

5
python/tests/__main__.py Normal file
View File

@ -0,0 +1,5 @@
from . import mlx_tests
# Mirrors the marker set by the standard library's ``unittest/__main__.py``.
# NOTE(review): unittest is understood to look for ``__unittest`` in a frame's
# globals to hide runner frames from failure tracebacks -- confirm.
__unittest = True

# ``module=None`` is the same argument ``python -m unittest`` passes, so tests
# come from the command line (e.g. ``discover``) rather than from a module.
mlx_tests.MLXTestRunner(module=None)

143
python/tests/cuda_skip.py Normal file
View File

@ -0,0 +1,143 @@
# Tests that ``MLXTestRunner.createTests`` removes from the collected suite
# when its GPU-without-Metal condition holds (treated there as the CUDA
# backend). Each entry has the form "<TestCase class>.<test method>", which is
# what the runner builds from the last two dot-separated parts of a test's
# ``id()``. Entries must match the test names exactly, including any
# misspellings in the method names themselves.
cuda_skip = {
    "TestArray.test_api",
    "TestArray.test_setitem",
    "TestAutograd.test_cumprod_grad",
    "TestAutograd.test_slice_grads",
    "TestAutograd.test_split_against_slice",
    "TestAutograd.test_stop_gradient",
    "TestAutograd.test_topk_grad",
    "TestAutograd.test_update_state",
    "TestAutograd.test_vjp",
    "TestBF16.test_arg_reduction_ops",
    "TestBF16.test_binary_ops",
    "TestBF16.test_reduction_ops",
    "TestBlas.test_block_masked_matmul",
    "TestBlas.test_complex_gemm",
    "TestBlas.test_gather_matmul",
    "TestBlas.test_gather_matmul_grad",
    "TestBlas.test_matmul_batched",
    "TestBlas.test_matrix_vector_attn",
    "TestCompile.test_compile_dynamic_dims",
    "TestCompile.test_compile_inf",
    "TestCompile.test_inf_constant",
    "TestConv.test_1d_conv_with_2d",
    "TestConv.test_asymmetric_padding",
    "TestConv.test_basic_grad_shapes",
    "TestConv.test_conv2d_unaligned_channels",
    "TestConv.test_conv_1d_groups_flipped",
    "TestConv.test_conv_general_flip_grad",
    "TestConv.test_conv_groups_grad",
    "TestConv.test_numpy_conv",
    "TestConv.test_repeated_conv",
    "TestConv.test_torch_conv_1D",
    "TestConv.test_torch_conv_1D_grad",
    "TestConv.test_torch_conv_2D",
    "TestConv.test_torch_conv_2D_grad",
    "TestConv.test_torch_conv_3D",
    "TestConv.test_torch_conv_3D_grad",
    "TestConv.test_torch_conv_depthwise",
    "TestConv.test_torch_conv_general",
    "TestConvTranspose.test_torch_conv_tranpose_1d_output_padding",
    "TestConvTranspose.test_torch_conv_transpose_1D",
    "TestConvTranspose.test_torch_conv_transpose_1D_grad",
    "TestConvTranspose.test_torch_conv_transpose_2D",
    "TestConvTranspose.test_torch_conv_transpose_2D_grad",
    "TestConvTranspose.test_torch_conv_transpose_2d_output_padding",
    "TestConvTranspose.test_torch_conv_transpose_3D",
    "TestConvTranspose.test_torch_conv_transpose_3D_grad",
    "TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
    "TestEinsum.test_attention",
    "TestEinsum.test_ellipses",
    "TestEinsum.test_opt_einsum_test_cases",
    "TestEval.test_multi_output_eval_during_transform",
    "TestExportImport.test_export_conv",
    "TestFast.test_rope_grad",
    "TestFFT.test_fft",
    "TestFFT.test_fft_big_powers_of_two",
    "TestFFT.test_fft_contiguity",
    "TestFFT.test_fft_exhaustive",
    "TestFFT.test_fft_grads",
    "TestFFT.test_fft_into_ifft",
    "TestFFT.test_fft_large_numbers",
    "TestFFT.test_fft_shared_mem",
    "TestFFT.test_fftn",
    "TestInit.test_orthogonal",
    "TestLinalg.test_cholesky",
    "TestLinalg.test_cholesky_inv",
    "TestLinalg.test_eig",
    "TestLinalg.test_eigh",
    "TestLinalg.test_inverse",
    "TestLinalg.test_lu",
    "TestLinalg.test_lu_factor",
    "TestLinalg.test_pseudo_inverse",
    "TestLinalg.test_qr_factorization",
    "TestLinalg.test_svd_decomposition",
    "TestLinalg.test_tri_inverse",
    "TestLoad.test_load_f8_e4m3",
    "TestLosses.test_binary_cross_entropy",
    "TestMemory.test_memory_info",
    "TestLayers.test_conv1d",
    "TestLayers.test_conv2d",
    "TestLayers.test_elu",
    "TestLayers.test_group_norm",
    "TestLayers.test_hard_shrink",
    "TestLayers.test_pooling",
    "TestLayers.test_quantized_embedding",
    "TestLayers.test_sin_pe",
    "TestLayers.test_softshrink",
    "TestLayers.test_upsample",
    "TestOps.test_argpartition",
    "TestOps.test_array_equal",
    "TestOps.test_as_strided",
    "TestOps.test_atleast_1d",
    "TestOps.test_atleast_2d",
    "TestOps.test_atleast_3d",
    "TestOps.test_binary_ops",
    "TestOps.test_bitwise_grad",
    "TestOps.test_complex_ops",
    "TestOps.test_divmod",
    "TestOps.test_dynamic_slicing",
    "TestOps.test_hadamard",
    "TestOps.test_hadamard_grad_vmap",
    "TestOps.test_irregular_binary_ops",
    "TestOps.test_isfinite",
    "TestOps.test_kron",
    "TestOps.test_log",
    "TestOps.test_log10",
    "TestOps.test_log1p",
    "TestOps.test_log2",
    "TestOps.test_logaddexp",
    "TestOps.test_logcumsumexp",
    "TestOps.test_partition",
    "TestOps.test_scans",
    "TestOps.test_slice_update_reversed",
    "TestOps.test_softmax",
    "TestOps.test_sort",
    "TestOps.test_tensordot",
    "TestOps.test_tile",
    "TestOps.test_view",
    "TestQuantized.test_gather_matmul_grad",
    "TestQuantized.test_gather_qmm",
    "TestQuantized.test_gather_qmm_sorted",
    "TestQuantized.test_non_multiples",
    "TestQuantized.test_qmm",
    "TestQuantized.test_qmm_jvp",
    "TestQuantized.test_qmm_shapes",
    "TestQuantized.test_qmm_vjp",
    "TestQuantized.test_qmv",
    "TestQuantized.test_quantize_dequantize",
    "TestQuantized.test_qvm",
    "TestQuantized.test_qvm_splitk",
    "TestQuantized.test_small_matrix",
    "TestQuantized.test_throw",
    "TestQuantized.test_vjp_scales_biases",
    "TestReduce.test_axis_permutation_sums",
    "TestReduce.test_dtypes",
    "TestReduce.test_expand_sums",
    "TestReduce.test_many_reduction_axes",
    "TestUpsample.test_torch_upsample",
    "TestVmap.test_unary",
    "TestVmap.test_vmap_conv",
    "TestVmap.test_vmap_inverse",
    "TestVmap.test_vmap_svd",
}

View File

@ -9,6 +9,42 @@ import mlx.core as mx
import numpy as np
class MLXTestRunner(unittest.TestProgram):
    """``unittest`` test program that drops the tests listed in ``cuda_skip``.

    Filtering only happens when the selected device is the GPU and Metal is
    not available; in every other configuration the collected suite is left
    untouched.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def createTests(self, *args, **kwargs):
        """Collect the tests as usual, then remove those in ``cuda_skip``."""
        super().createTests(*args, **kwargs)

        # The DEVICE environment variable names an attribute of ``mx``
        # (e.g. "cpu" or "gpu"); fall back to the default device when unset.
        requested = os.getenv("DEVICE", None)
        if requested is None:
            device = mx.default_device()
        else:
            device = getattr(mx, requested)

        # Assume the CUDA backend when running on a GPU without Metal; any
        # other configuration keeps the full suite.
        on_gpu = device == mx.gpu
        if not on_gpu or mx.metal.is_available():
            return

        from cuda_skip import cuda_skip

        def leaf_tests(node):
            # Walk nested suites depth-first, yielding the individual test
            # cases in their original order.
            if isinstance(node, unittest.TestSuite):
                for child in node:
                    yield from leaf_tests(child)
            else:
                yield node

        kept = unittest.TestSuite()
        for case in leaf_tests(self.test):
            # Skip-list entries are "<TestClass>.<test_method>".
            short_id = ".".join(case.id().split(".")[-2:])
            if short_id in cuda_skip:
                print(f"Skipping {short_id}")
            else:
                kept.addTest(case)
        self.test = kept
class MLXTestCase(unittest.TestCase):
@property
def is_apple_silicon(self):

View File

@ -130,4 +130,4 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -2033,4 +2033,4 @@ class TestArray(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -799,4 +799,4 @@ class TestAutograd(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -193,4 +193,4 @@ class TestBF16(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -1236,4 +1236,4 @@ class TestBlas(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -981,4 +981,4 @@ class TestCompile(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -38,4 +38,4 @@ class TestConstants(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -1188,4 +1188,4 @@ class TestConv(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -807,4 +807,4 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -38,7 +38,7 @@ class TestDevice(mlx_tests.MLXTestCase):
# Restore device
mx.set_default_device(device)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
@unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
def test_device_context(self):
default = mx.default_device()
diff = mx.cpu if default == mx.gpu else mx.gpu
@ -114,4 +114,4 @@ class TestStream(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -294,4 +294,4 @@ class TestDouble(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -360,4 +360,4 @@ class TestEinsum(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -172,7 +172,7 @@ class TestEval(mlx_tests.MLXTestCase):
post = mx.get_peak_memory()
self.assertEqual(pre, post)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
@unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
def test_multistream_deadlock(self):
s1 = mx.default_stream(mx.gpu)
s2 = mx.new_stream(mx.gpu)
@ -197,4 +197,4 @@ class TestEval(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -348,4 +348,4 @@ class TestExportImport(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -772,4 +772,4 @@ class TestFast(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -607,7 +607,7 @@ class TestSDPA(mlx_tests.MLXTestCase):
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_sdpa_prommote_mask(self):
def test_sdpa_promote_mask(self):
mask = mx.array(2.0, mx.bfloat16)
D = 64
Nq = 4
@ -653,4 +653,4 @@ class TestSDPA(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main(failfast=True)
mlx_tests.MLXTestRunner(failfast=True)

View File

@ -320,4 +320,4 @@ class TestFFT(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -34,4 +34,4 @@ class TestGraph(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -136,4 +136,4 @@ class TestInit(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -545,4 +545,4 @@ class TestLinalg(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -400,4 +400,4 @@ class TestLoad(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -414,4 +414,4 @@ class TestLosses(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -60,4 +60,4 @@ class TestMemory(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -855,6 +855,13 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
def test_relu_squared(self):
    """relu_squared zeroes non-positive inputs and squares positive ones."""
    inputs = mx.array([-1.0, 0.0, 1.0, 2.0, 3.0])
    expected = mx.array([0.0, 0.0, 1.0, 4.0, 9.0])

    out = nn.relu_squared(inputs)

    self.assertTrue(mx.array_equal(out, expected))
    # Shape and dtype of the result are checked as well.
    self.assertEqual(out.shape, (5,))
    self.assertEqual(out.dtype, mx.float32)
def test_leaky_relu(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.leaky_relu(x)
@ -1907,4 +1914,4 @@ class TestLayers(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -3127,4 +3127,4 @@ class TestBroadcast(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -527,4 +527,4 @@ class TestSchedulers(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -576,4 +576,4 @@ class TestQuantized(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -389,4 +389,4 @@ class TestRandom(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -155,4 +155,4 @@ class TestReduce(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main(failfast=True)
mlx_tests.MLXTestRunner(failfast=True)

View File

@ -48,4 +48,4 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -97,4 +97,4 @@ class TestUpsample(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -725,4 +725,4 @@ class TestVmap(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()