diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt
index 29beca859..f094fdfe8 100644
--- a/python/src/CMakeLists.txt
+++ b/python/src/CMakeLists.txt
@@ -17,6 +17,7 @@ nanobind_add_module(
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
diff --git a/python/src/cuda.cpp b/python/src/cuda.cpp
new file mode 100644
index 000000000..13b3a0154
--- /dev/null
+++ b/python/src/cuda.cpp
@@ -0,0 +1,19 @@
+// Copyright © 2023-2025 Apple Inc.
+
+#include <nanobind/nanobind.h>
+
+#include "mlx/backend/cuda/cuda.h"
+
+namespace mx = mlx::core;
+namespace nb = nanobind;
+
+void init_cuda(nb::module_& m) {
+  nb::module_ cuda = m.def_submodule("cuda", "mlx.cuda");
+
+  cuda.def(
+      "is_available",
+      &mx::cu::is_available,
+      R"pbdoc(
+      Check if the CUDA back-end is available.
+      )pbdoc");
+}
diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp
index eaddecb26..d89e48300 100644
--- a/python/src/mlx.cpp
+++ b/python/src/mlx.cpp
@@ -12,6 +12,7 @@ void init_array(nb::module_&);
 void init_device(nb::module_&);
 void init_stream(nb::module_&);
 void init_metal(nb::module_&);
+void init_cuda(nb::module_&);
 void init_memory(nb::module_&);
 void init_ops(nb::module_&);
 void init_transforms(nb::module_&);
@@ -35,6 +36,7 @@ NB_MODULE(core, m) {
   init_stream(m);
   init_array(m);
   init_metal(m);
+  init_cuda(m);
   init_memory(m);
   init_ops(m);
   init_transforms(m);
diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py
index f79a62a15..518c26c70 100644
--- a/python/tests/test_fast.py
+++ b/python/tests/test_fast.py
@@ -581,18 +581,28 @@ class TestFast(mlx_tests.MLXTestCase):
         )(x)
         self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))

-    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
     def test_custom_kernel_basic(self):
+        if mx.metal.is_available():
+            source = """
+                uint elem = thread_position_in_grid.x;
+                out1[elem] = a[elem];
+            """
+            custom_kernel = mx.fast.metal_kernel
+        elif mx.cuda.is_available():
+            source = """
+                auto elem = cooperative_groups::this_grid().thread_rank();
+                out1[elem] = a[elem];
+            """
+            custom_kernel = mx.fast.cuda_kernel
+
         mx.random.seed(7)
         a = mx.random.normal(shape=(2, 2))
-        kernel = mx.fast.metal_kernel(
+        kernel = custom_kernel(
             name="basic",
             input_names=["a"],
             output_names=["out1"],
-            source="""
-                uint elem = thread_position_in_grid.x;
-                out1[elem] = a[elem];
-            """,
+            source=source,
         )
         out = kernel(
             inputs=[a],
@@ -604,17 +614,10 @@ class TestFast(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(out[0], a))

-    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
     def test_custom_kernel_args(self):
-        mx.random.seed(7)
-        a = mx.random.normal(shape=(3, 6))
-        c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
-
-        kernel = mx.fast.metal_kernel(
-            name="arg_test",
-            input_names=["a", "b", "c", "d"],
-            output_names=["out1", "out2"],
-            source="""
+        if mx.metal.is_available():
+            source = """
                 uint elem = thread_position_in_grid.x;
                 T tmp = a[0];
                 if (e) {
@@ -623,7 +626,30 @@ class TestFast(mlx_tests.MLXTestCase):
                     out1[elem] = 1;
                 }
                 out2[elem] = a[1] + b[2] + c[1] - d;
-            """,
+            """
+            custom_kernel = mx.fast.metal_kernel
+        elif mx.cuda.is_available():
+            source = """
+                auto elem = cooperative_groups::this_grid().thread_rank();
+                T tmp = a[0];
+                if (e) {
+                    out1[elem] = a[1] + b[2] + static_cast<float>(c[3]) + d[0] + f;
+                } else {
+                    out1[elem] = 1;
+                }
+                out2[elem] = a[1] + b[2] + static_cast<float>(c[1]) - d[0];
+            """
+            custom_kernel = mx.fast.cuda_kernel
+
+        mx.random.seed(7)
+        a = mx.random.normal(shape=(3, 6))
+        c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
+
+        kernel = custom_kernel(
+            name="arg_test",
+            input_names=["a", "b", "c", "d"],
+            output_names=["out1", "out2"],
+            source=source,
         )
         out = kernel(
             inputs=[
@@ -647,27 +673,43 @@ class TestFast(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(out[0], mx.full((3, 2), 14.0484)))
         self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))

-    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
     def test_custom_kernel_strides(self):
+        if mx.metal.is_available():
+            source = """
+                uint elem = thread_position_in_grid.x;
+                uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
+                T tmp = inp[loc];
+                out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
+            """
+            source_contig = """
+                uint elem = thread_position_in_grid.x;
+                T tmp = inp[elem];
+                out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
+            """
+            custom_kernel = mx.fast.metal_kernel
+        elif mx.cuda.is_available():
+            source = """
+                auto elem = cooperative_groups::this_grid().thread_rank();
+                auto loc = elem_to_loc(elem, inp_shape.data(), inp_strides.data(), inp_ndim);
+                T tmp = inp[loc];
+                out[elem] = exp(tmp) * WARP_SIZE;
+            """
+            source_contig = """
+                auto elem = cooperative_groups::this_grid().thread_rank();
+                T tmp = inp[elem];
+                out[elem] = exp(tmp) * WARP_SIZE;
+            """
+            custom_kernel = mx.fast.cuda_kernel
+
         mx.random.seed(7)
         a = mx.random.normal(shape=(3, 6))
-        source = """
-            uint elem = thread_position_in_grid.x;
-            uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
-            T tmp = inp[loc];
-            out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
-        """
-        source_contig = """
-            uint elem = thread_position_in_grid.x;
-            T tmp = inp[elem];
-            out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
-        """

         # non contiguous
         a = mx.tile(a[::2], [4, 1])

         for contig in [True, False]:
-            kernel = mx.fast.metal_kernel(
+            kernel = custom_kernel(
                 name="myexp" + str(contig),
                 input_names=["inp"],
                 output_names=["out"],
@@ -685,24 +727,41 @@ class TestFast(mlx_tests.MLXTestCase):
             )
         self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))

-    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
     def test_custom_kernel_helper(self):
-        mx.random.seed(7)
-        a = mx.random.normal(shape=(2, 2))
-        kernel = mx.fast.metal_kernel(
-            name="helper",
-            input_names=["a"],
-            output_names=["out1"],
-            header="""
+        if mx.metal.is_available():
+            header = """
             template <typename T>
             T do_exp(T x) {
                 return metal::precise::exp(x);
             }
-            """,
-            source="""
+            """
+            source = """
                 uint elem = thread_position_in_grid.x;
                 out1[elem] = do_exp(a[elem]);
-            """,
+            """
+            custom_kernel = mx.fast.metal_kernel
+        elif mx.cuda.is_available():
+            header = """
+            template <typename T>
+            __device__ T do_exp(T x) {
+                return exp(x);
+            }
+            """
+            source = """
+                auto elem = cooperative_groups::this_grid().thread_rank();
+                out1[elem] = do_exp(a[elem]);
+            """
+            custom_kernel = mx.fast.cuda_kernel
+
+        mx.random.seed(7)
+        a = mx.random.normal(shape=(2, 2))
+        kernel = custom_kernel(
+            name="helper",
+            input_names=["a"],
+            output_names=["out1"],
+            header=header,
+            source=source,
         )
         out = kernel(
             inputs=[a],
@@ -714,16 +773,21 @@ class TestFast(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(out[0], mx.exp(a)))

-    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
     def test_custom_kernel_attributes(self):
+        if mx.metal.is_available():
+            source = "out[0] = threads_per_threadgroup.x;"
+            custom_kernel = mx.fast.metal_kernel
+        elif mx.cuda.is_available():
+            source = "out[0] = blockDim.x;"
+            custom_kernel = mx.fast.cuda_kernel
+
         a = mx.zeros(shape=(1, 1))
-        kernel = mx.fast.metal_kernel(
+        kernel = custom_kernel(
             name="test_fun",
             input_names=["a"],
             output_names=["out"],
-            source="""
-                out[0] = threads_per_threadgroup.x;
-            """,
+            source=source,
         )
         out = kernel(
             inputs=[a],
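
Below is a minimal usage sketch (not part of the patch) of the backend-dispatch pattern the updated tests rely on. It only uses APIs that appear in this diff (mx.metal.is_available, the new mx.cuda.is_available, mx.fast.metal_kernel, and mx.fast.cuda_kernel); the helper name make_copy_kernel, the kernel name "copy", and the input/output names are illustrative placeholders, and the exact signature of mx.fast.cuda_kernel is assumed to mirror mx.fast.metal_kernel as the tests above suggest.

# Hypothetical sketch: build an elementwise copy kernel on whichever GPU backend is present.
import mlx.core as mx

def make_copy_kernel():
    if mx.metal.is_available():
        # Metal source uses the thread-position attribute, as in the tests above.
        source = """
            uint elem = thread_position_in_grid.x;
            out[elem] = inp[elem];
        """
        return mx.fast.metal_kernel(
            name="copy", input_names=["inp"], output_names=["out"], source=source
        )
    elif mx.cuda.is_available():
        # CUDA source indexes via cooperative_groups, as in the tests above.
        source = """
            auto elem = cooperative_groups::this_grid().thread_rank();
            out[elem] = inp[elem];
        """
        return mx.fast.cuda_kernel(
            name="copy", input_names=["inp"], output_names=["out"], source=source
        )
    raise RuntimeError("No GPU backend available")

The returned kernel would then be launched the same way the tests call their kernels (inputs plus launch geometry and output shapes/dtypes), with the dispatch above keeping a single test body valid on both backends.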