Add all tests except the custom caching

Angelos Katharopoulos 2025-08-18 23:45:13 -07:00
parent 14efd9c35a
commit bffadc2cb9
4 changed files with 132 additions and 46 deletions

python/src/CMakeLists.txt

@@ -17,6 +17,7 @@ nanobind_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp

python/src/cuda.cpp Normal file

@@ -0,0 +1,19 @@
// Copyright © 2023-2025 Apple Inc.
#include <nanobind/nanobind.h>
#include "mlx/backend/cuda/cuda.h"
namespace mx = mlx::core;
namespace nb = nanobind;
void init_cuda(nb::module_& m) {
nb::module_ cuda = m.def_submodule("cuda", "mlx.cuda");
cuda.def(
"is_available",
&mx::cu::is_available,
R"pbdoc(
Check if the CUDA back-end is available.
)pbdoc");
}
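
The new submodule mirrors the existing mx.metal bindings. A minimal availability check from Python (an illustrative sketch, assuming the usual import mlx.core as mx):

import mlx.core as mx

# True only when MLX was built with the CUDA backend.
print(mx.cuda.is_available())
# Existing Metal counterpart, unchanged by this commit.
print(mx.metal.is_available())
# Backend-agnostic check used by the updated test decorators below.
print(mx.is_available(mx.gpu))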

python/src/mlx.cpp

@@ -12,6 +12,7 @@ void init_array(nb::module_&);
void init_device(nb::module_&);
void init_stream(nb::module_&);
void init_metal(nb::module_&);
void init_cuda(nb::module_&);
void init_memory(nb::module_&);
void init_ops(nb::module_&);
void init_transforms(nb::module_&);
@@ -35,6 +36,7 @@ NB_MODULE(core, m) {
init_stream(m);
init_array(m);
init_metal(m);
init_cuda(m);
init_memory(m);
init_ops(m);
init_transforms(m);

python/tests/test_fast.py

@@ -581,18 +581,28 @@ class TestFast(mlx_tests.MLXTestCase):
)(x)
self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_basic(self):
if mx.metal.is_available():
source = """
uint elem = thread_position_in_grid.x;
out1[elem] = a[elem];
"""
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
out1[elem] = a[elem];
"""
custom_kernel = mx.fast.cuda_kernel
mx.random.seed(7)
a = mx.random.normal(shape=(2, 2))
kernel = mx.fast.metal_kernel(
kernel = custom_kernel(
name="basic",
input_names=["a"],
output_names=["out1"],
source="""
uint elem = thread_position_in_grid.x;
out1[elem] = a[elem];
""",
source=source,
)
out = kernel(
inputs=[a],
@@ -604,17 +614,10 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(out[0], a))
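
The hunk cuts off before the launch arguments; for context, a full invocation of such a kernel follows the existing mx.fast.metal_kernel calling convention, roughly like this (a sketch, not part of this diff; the grid and output values are illustrative):

out = kernel(
    inputs=[a],
    grid=(4, 1, 1),            # one thread per element of the (2, 2) input
    threadgroup=(4, 1, 1),
    output_shapes=[(2, 2)],
    output_dtypes=[mx.float32],
)
# out is a list with one array per entry in output_names.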
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_args(self):
mx.random.seed(7)
a = mx.random.normal(shape=(3, 6))
c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
kernel = mx.fast.metal_kernel(
name="arg_test",
input_names=["a", "b", "c", "d"],
output_names=["out1", "out2"],
source="""
if mx.metal.is_available():
source = """
uint elem = thread_position_in_grid.x;
T tmp = a[0];
if (e) {
@@ -623,7 +626,30 @@ class TestFast(mlx_tests.MLXTestCase):
out1[elem] = 1;
}
out2[elem] = a[1] + b[2] + c[1] - d;
""",
"""
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
T tmp = a[0];
if (e) {
out1[elem] = a[1] + b[2] + static_cast<float>(c[3]) + d[0] + f;
} else {
out1[elem] = 1;
}
out2[elem] = a[1] + b[2] + static_cast<float>(c[1]) - d[0];
"""
custom_kernel = mx.fast.cuda_kernel
mx.random.seed(7)
a = mx.random.normal(shape=(3, 6))
c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
kernel = custom_kernel(
name="arg_test",
input_names=["a", "b", "c", "d"],
output_names=["out1", "out2"],
source=source,
)
out = kernel(
inputs=[
@@ -647,27 +673,43 @@ class TestFast(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(out[0], mx.full((3, 2), 14.0484)))
self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_strides(self):
if mx.metal.is_available():
source = """
uint elem = thread_position_in_grid.x;
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
"""
source_contig = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
"""
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
auto loc = elem_to_loc(elem, inp_shape.data(), inp_strides.data(), inp_ndim);
T tmp = inp[loc];
out[elem] = exp(tmp) * WARP_SIZE;
"""
source_contig = """
auto elem = cooperative_groups::this_grid().thread_rank();
T tmp = inp[elem];
out[elem] = exp(tmp) * WARP_SIZE;
"""
custom_kernel = mx.fast.cuda_kernel
mx.random.seed(7)
a = mx.random.normal(shape=(3, 6))
source = """
uint elem = thread_position_in_grid.x;
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
"""
source_contig = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
"""
# non contiguous
a = mx.tile(a[::2], [4, 1])
for contig in [True, False]:
kernel = mx.fast.metal_kernel(
kernel = custom_kernel(
name="myexp" + str(contig),
input_names=["inp"],
output_names=["out"],
@@ -685,24 +727,41 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))
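
Both variants index through elem_to_loc to turn a flat thread index into a strided memory offset. Conceptually the helper performs the standard shape/stride walk, sketched here in Python for illustration:

def elem_to_loc(elem, shape, strides):
    # Peel off the coordinate in each dimension, innermost first,
    # and scale it by that dimension's stride.
    loc = 0
    for dim, stride in zip(reversed(shape), reversed(strides)):
        loc += (elem % dim) * stride
        elem //= dim
    return loc

# A row-major (3, 6) array has strides (6, 1), so flat element 7 is offset 7;
# after a[::2] the row stride doubles to 12 and element 7 maps to offset 13.
assert elem_to_loc(7, [3, 6], [6, 1]) == 7
assert elem_to_loc(7, [2, 6], [12, 1]) == 13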
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_helper(self):
mx.random.seed(7)
a = mx.random.normal(shape=(2, 2))
kernel = mx.fast.metal_kernel(
name="helper",
input_names=["a"],
output_names=["out1"],
header="""
if mx.metal.is_available():
header = """
template <typename T>
T do_exp(T x) {
return metal::precise::exp(x);
}
""",
source="""
"""
source = """
uint elem = thread_position_in_grid.x;
out1[elem] = do_exp(a[elem]);
""",
"""
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
header = """
template <typename T>
__device__ T do_exp(T x) {
return exp(x);
}
"""
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
out1[elem] = do_exp(a[elem]);
"""
custom_kernel = mx.fast.cuda_kernel
mx.random.seed(7)
a = mx.random.normal(shape=(2, 2))
kernel = custom_kernel(
name="helper",
input_names=["a"],
output_names=["out1"],
header=header,
source=source,
)
out = kernel(
inputs=[a],
@@ -714,16 +773,21 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(out[0], mx.exp(a)))
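
Every updated test repeats the same select-by-backend pattern. If it keeps growing, the pattern could be factored into a helper along these lines (hypothetical, not part of this commit; both constructors already share the same keyword interface):

def gpu_kernel(metal_source, cuda_source, **kwargs):
    # Prefer Metal when present, otherwise fall back to CUDA.
    if mx.metal.is_available():
        return mx.fast.metal_kernel(source=metal_source, **kwargs)
    if mx.cuda.is_available():
        return mx.fast.cuda_kernel(source=cuda_source, **kwargs)
    raise RuntimeError("No GPU backend available")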
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_attributes(self):
if mx.metal.is_available():
source = "out[0] = threads_per_threadgroup.x;"
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = "out[0] = blockDim.x;"
custom_kernel = mx.fast.cuda_kernel
a = mx.zeros(shape=(1, 1))
kernel = mx.fast.metal_kernel(
kernel = custom_kernel(
name="test_fun",
input_names=["a"],
output_names=["out"],
source="""
out[0] = threads_per_threadgroup.x;
""",
source=source,
)
out = kernel(
inputs=[a],
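
Taken together, the Metal and CUDA kernel sources in these tests differ only in their built-in identifiers; the correspondences used are:

- thread_position_in_grid.x <-> cooperative_groups::this_grid().thread_rank()
- threads_per_threadgroup.x <-> blockDim.x
- threads_per_simdgroup <-> WARP_SIZE
- metal::precise::exp <-> exp (CUDA's device-side exp)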