mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Compare commits
3 Commits
3eaa743030
...
5e5d379a2d
Author | SHA1 | Date | |
---|---|---|---|
![]() |
5e5d379a2d | ||
![]() |
4fda5fbdf9 | ||
![]() |
7c99acb799 |
@ -234,6 +234,7 @@ jobs:
|
||||
command: |
|
||||
source env/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
|
||||
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
|
||||
|
||||
build_release:
|
||||
parameters:
|
||||
|
@ -5,28 +5,33 @@ template <typename T, typename AccT = float, int N_READS = 4>
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant int& axis_size,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint _lid [[thread_position_in_threadgroup]],
|
||||
uint2 gid [[threadgroup_position_in_grid]],
|
||||
uint2 tid [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]],
|
||||
uint2 _lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
int lid = _lid;
|
||||
int lid = _lid.x;
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
constexpr int elem_per_group = SIMD_SIZE * 32 * N_READS;
|
||||
|
||||
threadgroup AccT local_max[SIMD_SIZE];
|
||||
threadgroup AccT local_normalizer[SIMD_SIZE];
|
||||
|
||||
AccT ld[N_READS];
|
||||
|
||||
in += gid * size_t(axis_size) + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
const int axis_offset = tid.y * elem_per_group;
|
||||
in += gid.x * size_t(axis_size) + lid * N_READS + axis_offset;
|
||||
if (axis_offset + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] = AccT(in[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] =
|
||||
((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
|
||||
ld[i] = ((axis_offset + lid * N_READS + i) < axis_size)
|
||||
? AccT(in[i])
|
||||
: Limits<AccT>::min;
|
||||
}
|
||||
}
|
||||
if (simd_group_id == 0) {
|
||||
@ -55,6 +60,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
|
||||
maxval = local_max[0];
|
||||
|
||||
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
|
||||
out += gid.x * grid_dim.y + tid.y;
|
||||
AccT normalizer = 0;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer += fast::exp(ld[i] - maxval);
|
||||
@ -67,7 +73,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
|
||||
if (simd_group_id == 0) {
|
||||
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
|
||||
out[0] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -62,15 +62,37 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
const int n_reads = 4;
|
||||
const int looped_limit = LOGSUMEXP_LOOPED_LIMIT;
|
||||
|
||||
std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
|
||||
bool split = n_rows < 4 && axis_size > 4 * looped_limit;
|
||||
bool looped = !split && axis_size > looped_limit;
|
||||
std::string kernel_name = looped ? "looped_" : "block_";
|
||||
kernel_name += "logsumexp_";
|
||||
kernel_name += type_to_name(out);
|
||||
|
||||
auto kernel = get_logsumexp_kernel(d, kernel_name, out);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
if (split) {
|
||||
auto tmp_size = ceildiv(axis_size, looped_limit);
|
||||
auto tmp_shape = Shape{n_rows, static_cast<int>(tmp_size)};
|
||||
array tmp(tmp_shape, in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc(tmp.nbytes()));
|
||||
size_t threadgroup_size = 1024;
|
||||
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
auto grid_dims = MTL::Size(n_threads, tmp_size, 1);
|
||||
auto group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(tmp, 1);
|
||||
compute_encoder.set_bytes(axis_size, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
d.add_temporary(tmp, s.index);
|
||||
in = tmp;
|
||||
axis_size = tmp_size;
|
||||
}
|
||||
|
||||
{
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
if (!looped) {
|
||||
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
|
||||
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
|
||||
size_t threadgroup_size = simd_size * simds_needed;
|
||||
|
5
python/tests/__main__.py
Normal file
5
python/tests/__main__.py
Normal file
@ -0,0 +1,5 @@
|
||||
from . import mlx_tests
|
||||
|
||||
__unittest = True
|
||||
|
||||
mlx_tests.MLXTestRunner(module=None)
|
143
python/tests/cuda_skip.py
Normal file
143
python/tests/cuda_skip.py
Normal file
@ -0,0 +1,143 @@
|
||||
cuda_skip = {
|
||||
"TestArray.test_api",
|
||||
"TestArray.test_setitem",
|
||||
"TestAutograd.test_cumprod_grad",
|
||||
"TestAutograd.test_slice_grads",
|
||||
"TestAutograd.test_split_against_slice",
|
||||
"TestAutograd.test_stop_gradient",
|
||||
"TestAutograd.test_topk_grad",
|
||||
"TestAutograd.test_update_state",
|
||||
"TestAutograd.test_vjp",
|
||||
"TestBF16.test_arg_reduction_ops",
|
||||
"TestBF16.test_binary_ops",
|
||||
"TestBF16.test_reduction_ops",
|
||||
"TestBlas.test_block_masked_matmul",
|
||||
"TestBlas.test_complex_gemm",
|
||||
"TestBlas.test_gather_matmul",
|
||||
"TestBlas.test_gather_matmul_grad",
|
||||
"TestBlas.test_matmul_batched",
|
||||
"TestBlas.test_matrix_vector_attn",
|
||||
"TestCompile.test_compile_dynamic_dims",
|
||||
"TestCompile.test_compile_inf",
|
||||
"TestCompile.test_inf_constant",
|
||||
"TestConv.test_1d_conv_with_2d",
|
||||
"TestConv.test_asymmetric_padding",
|
||||
"TestConv.test_basic_grad_shapes",
|
||||
"TestConv.test_conv2d_unaligned_channels",
|
||||
"TestConv.test_conv_1d_groups_flipped",
|
||||
"TestConv.test_conv_general_flip_grad",
|
||||
"TestConv.test_conv_groups_grad",
|
||||
"TestConv.test_numpy_conv",
|
||||
"TestConv.test_repeated_conv",
|
||||
"TestConv.test_torch_conv_1D",
|
||||
"TestConv.test_torch_conv_1D_grad",
|
||||
"TestConv.test_torch_conv_2D",
|
||||
"TestConv.test_torch_conv_2D_grad",
|
||||
"TestConv.test_torch_conv_3D",
|
||||
"TestConv.test_torch_conv_3D_grad",
|
||||
"TestConv.test_torch_conv_depthwise",
|
||||
"TestConv.test_torch_conv_general",
|
||||
"TestConvTranspose.test_torch_conv_tranpose_1d_output_padding",
|
||||
"TestConvTranspose.test_torch_conv_transpose_1D",
|
||||
"TestConvTranspose.test_torch_conv_transpose_1D_grad",
|
||||
"TestConvTranspose.test_torch_conv_transpose_2D",
|
||||
"TestConvTranspose.test_torch_conv_transpose_2D_grad",
|
||||
"TestConvTranspose.test_torch_conv_transpose_2d_output_padding",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3D",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
||||
"TestEinsum.test_attention",
|
||||
"TestEinsum.test_ellipses",
|
||||
"TestEinsum.test_opt_einsum_test_cases",
|
||||
"TestEval.test_multi_output_eval_during_transform",
|
||||
"TestExportImport.test_export_conv",
|
||||
"TestFast.test_rope_grad",
|
||||
"TestFFT.test_fft",
|
||||
"TestFFT.test_fft_big_powers_of_two",
|
||||
"TestFFT.test_fft_contiguity",
|
||||
"TestFFT.test_fft_exhaustive",
|
||||
"TestFFT.test_fft_grads",
|
||||
"TestFFT.test_fft_into_ifft",
|
||||
"TestFFT.test_fft_large_numbers",
|
||||
"TestFFT.test_fft_shared_mem",
|
||||
"TestFFT.test_fftn",
|
||||
"TestInit.test_orthogonal",
|
||||
"TestLinalg.test_cholesky",
|
||||
"TestLinalg.test_cholesky_inv",
|
||||
"TestLinalg.test_eig",
|
||||
"TestLinalg.test_eigh",
|
||||
"TestLinalg.test_inverse",
|
||||
"TestLinalg.test_lu",
|
||||
"TestLinalg.test_lu_factor",
|
||||
"TestLinalg.test_pseudo_inverse",
|
||||
"TestLinalg.test_qr_factorization",
|
||||
"TestLinalg.test_svd_decomposition",
|
||||
"TestLinalg.test_tri_inverse",
|
||||
"TestLoad.test_load_f8_e4m3",
|
||||
"TestLosses.test_binary_cross_entropy",
|
||||
"TestMemory.test_memory_info",
|
||||
"TestLayers.test_conv1d",
|
||||
"TestLayers.test_conv2d",
|
||||
"TestLayers.test_elu",
|
||||
"TestLayers.test_group_norm",
|
||||
"TestLayers.test_hard_shrink",
|
||||
"TestLayers.test_pooling",
|
||||
"TestLayers.test_quantized_embedding",
|
||||
"TestLayers.test_sin_pe",
|
||||
"TestLayers.test_softshrink",
|
||||
"TestLayers.test_upsample",
|
||||
"TestOps.test_argpartition",
|
||||
"TestOps.test_array_equal",
|
||||
"TestOps.test_as_strided",
|
||||
"TestOps.test_atleast_1d",
|
||||
"TestOps.test_atleast_2d",
|
||||
"TestOps.test_atleast_3d",
|
||||
"TestOps.test_binary_ops",
|
||||
"TestOps.test_bitwise_grad",
|
||||
"TestOps.test_complex_ops",
|
||||
"TestOps.test_divmod",
|
||||
"TestOps.test_dynamic_slicing",
|
||||
"TestOps.test_hadamard",
|
||||
"TestOps.test_hadamard_grad_vmap",
|
||||
"TestOps.test_irregular_binary_ops",
|
||||
"TestOps.test_isfinite",
|
||||
"TestOps.test_kron",
|
||||
"TestOps.test_log",
|
||||
"TestOps.test_log10",
|
||||
"TestOps.test_log1p",
|
||||
"TestOps.test_log2",
|
||||
"TestOps.test_logaddexp",
|
||||
"TestOps.test_logcumsumexp",
|
||||
"TestOps.test_partition",
|
||||
"TestOps.test_scans",
|
||||
"TestOps.test_slice_update_reversed",
|
||||
"TestOps.test_softmax",
|
||||
"TestOps.test_sort",
|
||||
"TestOps.test_tensordot",
|
||||
"TestOps.test_tile",
|
||||
"TestOps.test_view",
|
||||
"TestQuantized.test_gather_matmul_grad",
|
||||
"TestQuantized.test_gather_qmm",
|
||||
"TestQuantized.test_gather_qmm_sorted",
|
||||
"TestQuantized.test_non_multiples",
|
||||
"TestQuantized.test_qmm",
|
||||
"TestQuantized.test_qmm_jvp",
|
||||
"TestQuantized.test_qmm_shapes",
|
||||
"TestQuantized.test_qmm_vjp",
|
||||
"TestQuantized.test_qmv",
|
||||
"TestQuantized.test_quantize_dequantize",
|
||||
"TestQuantized.test_qvm",
|
||||
"TestQuantized.test_qvm_splitk",
|
||||
"TestQuantized.test_small_matrix",
|
||||
"TestQuantized.test_throw",
|
||||
"TestQuantized.test_vjp_scales_biases",
|
||||
"TestReduce.test_axis_permutation_sums",
|
||||
"TestReduce.test_dtypes",
|
||||
"TestReduce.test_expand_sums",
|
||||
"TestReduce.test_many_reduction_axes",
|
||||
"TestUpsample.test_torch_upsample",
|
||||
"TestVmap.test_unary",
|
||||
"TestVmap.test_vmap_conv",
|
||||
"TestVmap.test_vmap_inverse",
|
||||
"TestVmap.test_vmap_svd",
|
||||
}
|
@ -9,6 +9,42 @@ import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MLXTestRunner(unittest.TestProgram):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def createTests(self, *args, **kwargs):
|
||||
super().createTests(*args, **kwargs)
|
||||
|
||||
# Asume CUDA backend in this case
|
||||
device = os.getenv("DEVICE", None)
|
||||
if device is not None:
|
||||
device = getattr(mx, device)
|
||||
else:
|
||||
device = mx.default_device()
|
||||
|
||||
if not (device == mx.gpu and not mx.metal.is_available()):
|
||||
return
|
||||
|
||||
from cuda_skip import cuda_skip
|
||||
|
||||
filtered_suite = unittest.TestSuite()
|
||||
|
||||
def filter_and_add(t):
|
||||
if isinstance(t, unittest.TestSuite):
|
||||
for sub_t in t:
|
||||
filter_and_add(sub_t)
|
||||
else:
|
||||
t_id = ".".join(t.id().split(".")[-2:])
|
||||
if t_id in cuda_skip:
|
||||
print(f"Skipping {t_id}")
|
||||
else:
|
||||
filtered_suite.addTest(t)
|
||||
|
||||
filter_and_add(self.test)
|
||||
self.test = filtered_suite
|
||||
|
||||
|
||||
class MLXTestCase(unittest.TestCase):
|
||||
@property
|
||||
def is_apple_silicon(self):
|
||||
|
@ -130,4 +130,4 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -2033,4 +2033,4 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -799,4 +799,4 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -193,4 +193,4 @@ class TestBF16(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -1236,4 +1236,4 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -981,4 +981,4 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -38,4 +38,4 @@ class TestConstants(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -1188,4 +1188,4 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -807,4 +807,4 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -38,7 +38,7 @@ class TestDevice(mlx_tests.MLXTestCase):
|
||||
# Restore device
|
||||
mx.set_default_device(device)
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
|
||||
def test_device_context(self):
|
||||
default = mx.default_device()
|
||||
diff = mx.cpu if default == mx.gpu else mx.gpu
|
||||
@ -114,4 +114,4 @@ class TestStream(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -294,4 +294,4 @@ class TestDouble(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -360,4 +360,4 @@ class TestEinsum(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -172,7 +172,7 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
post = mx.get_peak_memory()
|
||||
self.assertEqual(pre, post)
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
|
||||
def test_multistream_deadlock(self):
|
||||
s1 = mx.default_stream(mx.gpu)
|
||||
s2 = mx.new_stream(mx.gpu)
|
||||
@ -197,4 +197,4 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -348,4 +348,4 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -772,4 +772,4 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -607,7 +607,7 @@ class TestSDPA(mlx_tests.MLXTestCase):
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
def test_sdpa_prommote_mask(self):
|
||||
def test_sdpa_promote_mask(self):
|
||||
mask = mx.array(2.0, mx.bfloat16)
|
||||
D = 64
|
||||
Nq = 4
|
||||
@ -653,4 +653,4 @@ class TestSDPA(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(failfast=True)
|
||||
mlx_tests.MLXTestRunner(failfast=True)
|
||||
|
@ -320,4 +320,4 @@ class TestFFT(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -34,4 +34,4 @@ class TestGraph(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -136,4 +136,4 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -545,4 +545,4 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -400,4 +400,4 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -414,4 +414,4 @@ class TestLosses(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -60,4 +60,4 @@ class TestMemory(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -1907,4 +1907,4 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -760,6 +760,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
|
||||
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
|
||||
|
||||
# Even larger
|
||||
x = mx.random.uniform(shape=(4 * 4096 + 3,))
|
||||
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
|
||||
|
||||
def test_mean(self):
|
||||
x = mx.array(
|
||||
[
|
||||
@ -3127,4 +3131,4 @@ class TestBroadcast(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -527,4 +527,4 @@ class TestSchedulers(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -576,4 +576,4 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -389,4 +389,4 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -155,4 +155,4 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(failfast=True)
|
||||
mlx_tests.MLXTestRunner(failfast=True)
|
||||
|
@ -48,4 +48,4 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -97,4 +97,4 @@ class TestUpsample(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -725,4 +725,4 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
Loading…
Reference in New Issue
Block a user