mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-27 08:46:41 +08:00
Add all tests except the custom caching
This commit is contained in:
parent
14efd9c35a
commit
bffadc2cb9
@ -17,6 +17,7 @@ nanobind_add_module(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
|
19
python/src/cuda.cpp
Normal file
19
python/src/cuda.cpp
Normal file
@ -0,0 +1,19 @@
|
||||
// Copyright © 2023-2025 Apple Inc.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
|
||||
void init_cuda(nb::module_& m) {
|
||||
nb::module_ cuda = m.def_submodule("cuda", "mlx.cuda");
|
||||
|
||||
cuda.def(
|
||||
"is_available",
|
||||
&mx::cu::is_available,
|
||||
R"pbdoc(
|
||||
Check if the CUDA back-end is available.
|
||||
)pbdoc");
|
||||
}
|
@ -12,6 +12,7 @@ void init_array(nb::module_&);
|
||||
void init_device(nb::module_&);
|
||||
void init_stream(nb::module_&);
|
||||
void init_metal(nb::module_&);
|
||||
void init_cuda(nb::module_&);
|
||||
void init_memory(nb::module_&);
|
||||
void init_ops(nb::module_&);
|
||||
void init_transforms(nb::module_&);
|
||||
@ -35,6 +36,7 @@ NB_MODULE(core, m) {
|
||||
init_stream(m);
|
||||
init_array(m);
|
||||
init_metal(m);
|
||||
init_cuda(m);
|
||||
init_memory(m);
|
||||
init_ops(m);
|
||||
init_transforms(m);
|
||||
|
@ -581,18 +581,28 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
)(x)
|
||||
self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
|
||||
def test_custom_kernel_basic(self):
|
||||
if mx.metal.is_available():
|
||||
source = """
|
||||
uint elem thread_position_in_grid.x;
|
||||
out1[elem] = a[elem];
|
||||
"""
|
||||
custom_kernel = mx.fast.metal_kernel
|
||||
elif mx.cuda.is_available():
|
||||
source = """
|
||||
auto elem = cooperative_groups::this_grid().thread_rank();
|
||||
out1[elem] = a[elem];
|
||||
"""
|
||||
custom_kernel = mx.fast.cuda_kernel
|
||||
|
||||
mx.random.seed(7)
|
||||
a = mx.random.normal(shape=(2, 2))
|
||||
kernel = mx.fast.metal_kernel(
|
||||
kernel = custom_kernel(
|
||||
name="basic",
|
||||
input_names=["a"],
|
||||
output_names=["out1"],
|
||||
source="""
|
||||
uint elem = thread_position_in_grid.x;
|
||||
out1[elem] = a[elem];
|
||||
""",
|
||||
source=source,
|
||||
)
|
||||
out = kernel(
|
||||
inputs=[a],
|
||||
@ -604,17 +614,10 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(mx.allclose(out[0], a))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
|
||||
def test_custom_kernel_args(self):
|
||||
mx.random.seed(7)
|
||||
a = mx.random.normal(shape=(3, 6))
|
||||
c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="arg_test",
|
||||
input_names=["a", "b", "c", "d"],
|
||||
output_names=["out1", "out2"],
|
||||
source="""
|
||||
if mx.metal.is_available():
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = a[0];
|
||||
if (e) {
|
||||
@ -623,7 +626,30 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
out1[elem] = 1;
|
||||
}
|
||||
out2[elem] = a[1] + b[2] + c[1] - d;
|
||||
""",
|
||||
"""
|
||||
custom_kernel = mx.fast.metal_kernel
|
||||
elif mx.cuda.is_available():
|
||||
source = """
|
||||
auto elem = cooperative_groups::this_grid().thread_rank();
|
||||
T tmp = a[0];
|
||||
if (e) {
|
||||
out1[elem] = a[1] + b[2] + static_cast<float>(c[3]) + d[0] + f;
|
||||
} else {
|
||||
out1[elem] = 1;
|
||||
}
|
||||
out2[elem] = a[1] + b[2] + static_cast<float>(c[1]) - d[0];
|
||||
"""
|
||||
custom_kernel = mx.fast.cuda_kernel
|
||||
|
||||
mx.random.seed(7)
|
||||
a = mx.random.normal(shape=(3, 6))
|
||||
c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
|
||||
|
||||
kernel = custom_kernel(
|
||||
name="arg_test",
|
||||
input_names=["a", "b", "c", "d"],
|
||||
output_names=["out1", "out2"],
|
||||
source=source,
|
||||
)
|
||||
out = kernel(
|
||||
inputs=[
|
||||
@ -647,27 +673,43 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(out[0], mx.full((3, 2), 14.0484)))
|
||||
self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
|
||||
def test_custom_kernel_strides(self):
|
||||
if mx.metal.is_available():
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
|
||||
"""
|
||||
source_contig = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
|
||||
"""
|
||||
custom_kernel = mx.fast.metal_kernel
|
||||
elif mx.cuda.is_available():
|
||||
source = """
|
||||
auto elem = cooperative_groups::this_grid().thread_rank();
|
||||
auto loc = elem_to_loc(elem, inp_shape.data(), inp_strides.data(), inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
out[elem] = exp(tmp) * WARP_SIZE;
|
||||
"""
|
||||
source_contig = """
|
||||
auto elem = cooperative_groups::this_grid().thread_rank();
|
||||
T tmp = inp[elem];
|
||||
out[elem] = exp(tmp) * WARP_SIZE;
|
||||
"""
|
||||
custom_kernel = mx.fast.cuda_kernel
|
||||
|
||||
mx.random.seed(7)
|
||||
a = mx.random.normal(shape=(3, 6))
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
|
||||
"""
|
||||
source_contig = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
|
||||
"""
|
||||
|
||||
# non contiguous
|
||||
a = mx.tile(a[::2], [4, 1])
|
||||
|
||||
for contig in [True, False]:
|
||||
kernel = mx.fast.metal_kernel(
|
||||
kernel = custom_kernel(
|
||||
name="myexp" + str(contig),
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
@ -685,24 +727,41 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
|
||||
def test_custom_kernel_helper(self):
|
||||
mx.random.seed(7)
|
||||
a = mx.random.normal(shape=(2, 2))
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="helper",
|
||||
input_names=["a"],
|
||||
output_names=["out1"],
|
||||
header="""
|
||||
if mx.metal.is_available():
|
||||
header = """
|
||||
template <typename T>
|
||||
T do_exp(T x) {
|
||||
return metal::precise::exp(x);
|
||||
}
|
||||
""",
|
||||
source="""
|
||||
"""
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
out1[elem] = do_exp(a[elem]);
|
||||
""",
|
||||
"""
|
||||
custom_kernel = mx.fast.metal_kernel
|
||||
elif mx.cuda.is_available():
|
||||
header = """
|
||||
template <typename T>
|
||||
__device__ T do_exp(T x) {
|
||||
return exp(x);
|
||||
}
|
||||
"""
|
||||
source = """
|
||||
auto elem = cooperative_groups::this_grid().thread_rank();
|
||||
out1[elem] = do_exp(a[elem]);
|
||||
"""
|
||||
custom_kernel = mx.fast.cuda_kernel
|
||||
|
||||
mx.random.seed(7)
|
||||
a = mx.random.normal(shape=(2, 2))
|
||||
kernel = custom_kernel(
|
||||
name="helper",
|
||||
input_names=["a"],
|
||||
output_names=["out1"],
|
||||
header=header,
|
||||
source=source,
|
||||
)
|
||||
out = kernel(
|
||||
inputs=[a],
|
||||
@ -714,16 +773,21 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(mx.allclose(out[0], mx.exp(a)))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
|
||||
def test_custom_kernel_attributes(self):
|
||||
if mx.metal.is_available():
|
||||
source = "out[0] = threads_per_threadgroup.x;"
|
||||
custom_kernel = mx.fast.metal_kernel
|
||||
elif mx.cuda.is_available():
|
||||
source = "out[0] = blockDim.x;"
|
||||
custom_kernel = mx.fast.cuda_kernel
|
||||
|
||||
a = mx.zeros(shape=(1, 1))
|
||||
kernel = mx.fast.metal_kernel(
|
||||
kernel = custom_kernel(
|
||||
name="test_fun",
|
||||
input_names=["a"],
|
||||
output_names=["out"],
|
||||
source="""
|
||||
out[0] = threads_per_threadgroup.x;
|
||||
""",
|
||||
source=source,
|
||||
)
|
||||
out = kernel(
|
||||
inputs=[a],
|
||||
|
Loading…
Reference in New Issue
Block a user