Add all tests except the custom caching

This commit is contained in:
Angelos Katharopoulos 2025-08-18 23:45:13 -07:00
parent 14efd9c35a
commit bffadc2cb9
4 changed files with 132 additions and 46 deletions

View File

@ -17,6 +17,7 @@ nanobind_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp ${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp

19
python/src/cuda.cpp Normal file
View File

@ -0,0 +1,19 @@
// Copyright © 2023-2025 Apple Inc.

#include <nanobind/nanobind.h>

#include "mlx/backend/cuda/cuda.h"

namespace mx = mlx::core;
namespace nb = nanobind;

// Register the `mlx.core.cuda` submodule on the core extension module.
// Mirrors the structure of init_metal: exposes backend availability only.
void init_cuda(nb::module_& m) {
  // Create the nested submodule; the docstring name matches the import path.
  auto cuda_mod = m.def_submodule("cuda", "mlx.cuda");
  // Bind the backend availability query directly to the C++ function.
  cuda_mod.def(
      "is_available",
      &mx::cu::is_available,
      R"pbdoc(
      Check if the CUDA back-end is available.
      )pbdoc");
}

View File

@ -12,6 +12,7 @@ void init_array(nb::module_&);
void init_device(nb::module_&); void init_device(nb::module_&);
void init_stream(nb::module_&); void init_stream(nb::module_&);
void init_metal(nb::module_&); void init_metal(nb::module_&);
void init_cuda(nb::module_&);
void init_memory(nb::module_&); void init_memory(nb::module_&);
void init_ops(nb::module_&); void init_ops(nb::module_&);
void init_transforms(nb::module_&); void init_transforms(nb::module_&);
@ -35,6 +36,7 @@ NB_MODULE(core, m) {
init_stream(m); init_stream(m);
init_array(m); init_array(m);
init_metal(m); init_metal(m);
init_cuda(m);
init_memory(m); init_memory(m);
init_ops(m); init_ops(m);
init_transforms(m); init_transforms(m);

View File

@ -581,18 +581,28 @@ class TestFast(mlx_tests.MLXTestCase):
)(x) )(x)
self.assertTrue(mx.allclose(vmap_out, vmap_fast_out)) self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_basic(self): def test_custom_kernel_basic(self):
if mx.metal.is_available():
source = """
uint elem = thread_position_in_grid.x;
out1[elem] = a[elem];
"""
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
out1[elem] = a[elem];
"""
custom_kernel = mx.fast.cuda_kernel
mx.random.seed(7) mx.random.seed(7)
a = mx.random.normal(shape=(2, 2)) a = mx.random.normal(shape=(2, 2))
kernel = mx.fast.metal_kernel( kernel = custom_kernel(
name="basic", name="basic",
input_names=["a"], input_names=["a"],
output_names=["out1"], output_names=["out1"],
source=""" source=source,
uint elem = thread_position_in_grid.x;
out1[elem] = a[elem];
""",
) )
out = kernel( out = kernel(
inputs=[a], inputs=[a],
@ -604,16 +614,9 @@ class TestFast(mlx_tests.MLXTestCase):
) )
self.assertTrue(mx.allclose(out[0], a)) self.assertTrue(mx.allclose(out[0], a))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_args(self): def test_custom_kernel_args(self):
mx.random.seed(7) if mx.metal.is_available():
a = mx.random.normal(shape=(3, 6))
c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
kernel = mx.fast.metal_kernel(
name="arg_test",
input_names=["a", "b", "c", "d"],
output_names=["out1", "out2"],
source = """ source = """
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
T tmp = a[0]; T tmp = a[0];
@ -623,7 +626,30 @@ class TestFast(mlx_tests.MLXTestCase):
out1[elem] = 1; out1[elem] = 1;
} }
out2[elem] = a[1] + b[2] + c[1] - d; out2[elem] = a[1] + b[2] + c[1] - d;
""", """
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
T tmp = a[0];
if (e) {
out1[elem] = a[1] + b[2] + static_cast<float>(c[3]) + d[0] + f;
} else {
out1[elem] = 1;
}
out2[elem] = a[1] + b[2] + static_cast<float>(c[1]) - d[0];
"""
custom_kernel = mx.fast.cuda_kernel
mx.random.seed(7)
a = mx.random.normal(shape=(3, 6))
c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
kernel = custom_kernel(
name="arg_test",
input_names=["a", "b", "c", "d"],
output_names=["out1", "out2"],
source=source,
) )
out = kernel( out = kernel(
inputs=[ inputs=[
@ -647,10 +673,9 @@ class TestFast(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(out[0], mx.full((3, 2), 14.0484))) self.assertTrue(mx.allclose(out[0], mx.full((3, 2), 14.0484)))
self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32))) self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_strides(self): def test_custom_kernel_strides(self):
mx.random.seed(7) if mx.metal.is_available():
a = mx.random.normal(shape=(3, 6))
source = """ source = """
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
@ -662,12 +687,29 @@ class TestFast(mlx_tests.MLXTestCase):
T tmp = inp[elem]; T tmp = inp[elem];
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup; out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
""" """
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
auto loc = elem_to_loc(elem, inp_shape.data(), inp_strides.data(), inp_ndim);
T tmp = inp[loc];
out[elem] = exp(tmp) * WARP_SIZE;
"""
source_contig = """
auto elem = cooperative_groups::this_grid().thread_rank();
T tmp = inp[elem];
out[elem] = exp(tmp) * WARP_SIZE;
"""
custom_kernel = mx.fast.cuda_kernel
mx.random.seed(7)
a = mx.random.normal(shape=(3, 6))
# non contiguous # non contiguous
a = mx.tile(a[::2], [4, 1]) a = mx.tile(a[::2], [4, 1])
for contig in [True, False]: for contig in [True, False]:
kernel = mx.fast.metal_kernel( kernel = custom_kernel(
name="myexp" + str(contig), name="myexp" + str(contig),
input_names=["inp"], input_names=["inp"],
output_names=["out"], output_names=["out"],
@ -685,24 +727,41 @@ class TestFast(mlx_tests.MLXTestCase):
) )
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0])) self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_helper(self): def test_custom_kernel_helper(self):
mx.random.seed(7) if mx.metal.is_available():
a = mx.random.normal(shape=(2, 2))
kernel = mx.fast.metal_kernel(
name="helper",
input_names=["a"],
output_names=["out1"],
header = """ header = """
template <typename T> template <typename T>
T do_exp(T x) { T do_exp(T x) {
return metal::precise::exp(x); return metal::precise::exp(x);
} }
""", """
source = """ source = """
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
out1[elem] = do_exp(a[elem]); out1[elem] = do_exp(a[elem]);
""", """
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
header = """
template <typename T>
__device__ T do_exp(T x) {
return exp(x);
}
"""
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
out1[elem] = do_exp(a[elem]);
"""
custom_kernel = mx.fast.cuda_kernel
mx.random.seed(7)
a = mx.random.normal(shape=(2, 2))
kernel = custom_kernel(
name="helper",
input_names=["a"],
output_names=["out1"],
header=header,
source=source,
) )
out = kernel( out = kernel(
inputs=[a], inputs=[a],
@ -714,16 +773,21 @@ class TestFast(mlx_tests.MLXTestCase):
) )
self.assertTrue(mx.allclose(out[0], mx.exp(a))) self.assertTrue(mx.allclose(out[0], mx.exp(a)))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_attributes(self): def test_custom_kernel_attributes(self):
if mx.metal.is_available():
source = "out[0] = threads_per_threadgroup.x;"
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = "out[0] = blockDim.x;"
custom_kernel = mx.fast.cuda_kernel
a = mx.zeros(shape=(1, 1)) a = mx.zeros(shape=(1, 1))
kernel = mx.fast.metal_kernel( kernel = custom_kernel(
name="test_fun", name="test_fun",
input_names=["a"], input_names=["a"],
output_names=["out"], output_names=["out"],
source=""" source=source,
out[0] = threads_per_threadgroup.x;
""",
) )
out = kernel( out = kernel(
inputs=[a], inputs=[a],