mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-28 07:58:41 +08:00)
Add all tests except the custom caching
This commit is contained in:
parent
14efd9c35a
commit
bffadc2cb9
@@ -17,6 +17,7 @@ nanobind_add_module(
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
python/src/cuda.cpp (new file, 19 lines)
@@ -0,0 +1,19 @@
+// Copyright © 2023-2025 Apple Inc.
+
+#include <nanobind/nanobind.h>
+
+#include "mlx/backend/cuda/cuda.h"
+
+namespace mx = mlx::core;
+namespace nb = nanobind;
+
+void init_cuda(nb::module_& m) {
+  nb::module_ cuda = m.def_submodule("cuda", "mlx.cuda");
+
+  cuda.def(
+      "is_available",
+      &mx::cu::is_available,
+      R"pbdoc(
+      Check if the CUDA back-end is available.
+      )pbdoc");
+}
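
For context, a minimal sketch (not part of the diff) of how the new binding is exercised from Python once the module is built; mx.metal.is_available() is the existing Metal counterpart and mx.is_available(mx.gpu) is the backend-agnostic check the updated tests use:

import mlx.core as mx

# Backend probes exposed by the python bindings. On a CPU-only build both
# back-end specific checks can be False.
print("metal:", mx.metal.is_available())
print("cuda:", mx.cuda.is_available())   # submodule added by this commit
print("gpu:", mx.is_available(mx.gpu))   # backend-agnostic check used in the tests below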
@@ -12,6 +12,7 @@ void init_array(nb::module_&);
 void init_device(nb::module_&);
 void init_stream(nb::module_&);
 void init_metal(nb::module_&);
+void init_cuda(nb::module_&);
 void init_memory(nb::module_&);
 void init_ops(nb::module_&);
 void init_transforms(nb::module_&);
@@ -35,6 +36,7 @@ NB_MODULE(core, m) {
   init_stream(m);
   init_array(m);
   init_metal(m);
+  init_cuda(m);
   init_memory(m);
   init_ops(m);
   init_transforms(m);
@@ -581,18 +581,28 @@ class TestFast(mlx_tests.MLXTestCase):
         )(x)
         self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))
 
-    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
     def test_custom_kernel_basic(self):
+        if mx.metal.is_available():
+            source = """
+                uint elem = thread_position_in_grid.x;
+                out1[elem] = a[elem];
+            """
+            custom_kernel = mx.fast.metal_kernel
+        elif mx.cuda.is_available():
+            source = """
+                auto elem = cooperative_groups::this_grid().thread_rank();
+                out1[elem] = a[elem];
+            """
+            custom_kernel = mx.fast.cuda_kernel
+
         mx.random.seed(7)
         a = mx.random.normal(shape=(2, 2))
-        kernel = mx.fast.metal_kernel(
+        kernel = custom_kernel(
             name="basic",
             input_names=["a"],
             output_names=["out1"],
-            source="""
-                uint elem = thread_position_in_grid.x;
-                out1[elem] = a[elem];
-            """,
+            source=source,
         )
         out = kernel(
             inputs=[a],
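
The same select-a-backend pattern recurs in every test changed below. As a standalone sketch of the idea (the grid/threadgroup/output arguments are not visible in this hunk and follow the documented mx.fast.metal_kernel call convention, so treat them as illustrative rather than part of this commit):

import mlx.core as mx

if mx.metal.is_available():
    source = """
        uint elem = thread_position_in_grid.x;
        out1[elem] = a[elem];
    """
    custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
    source = """
        auto elem = cooperative_groups::this_grid().thread_rank();
        out1[elem] = a[elem];
    """
    custom_kernel = mx.fast.cuda_kernel

a = mx.random.normal(shape=(2, 2))
# Same kernel definition for both backends; only the source string and the
# constructor (metal_kernel vs. cuda_kernel) differ.
kernel = custom_kernel(
    name="copy", input_names=["a"], output_names=["out1"], source=source
)
out = kernel(
    inputs=[a],
    grid=(4, 1, 1),
    threadgroup=(4, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)
assert mx.allclose(out[0], a)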
@@ -604,17 +614,10 @@ class TestFast(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(out[0], a))
 
-    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
     def test_custom_kernel_args(self):
-        mx.random.seed(7)
-        a = mx.random.normal(shape=(3, 6))
-        c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
-
-        kernel = mx.fast.metal_kernel(
-            name="arg_test",
-            input_names=["a", "b", "c", "d"],
-            output_names=["out1", "out2"],
-            source="""
+        if mx.metal.is_available():
+            source = """
                 uint elem = thread_position_in_grid.x;
                 T tmp = a[0];
                 if (e) {
@@ -623,7 +626,30 @@ class TestFast(mlx_tests.MLXTestCase):
                     out1[elem] = 1;
                 }
                 out2[elem] = a[1] + b[2] + c[1] - d;
-            """,
+            """
+            custom_kernel = mx.fast.metal_kernel
+        elif mx.cuda.is_available():
+            source = """
+                auto elem = cooperative_groups::this_grid().thread_rank();
+                T tmp = a[0];
+                if (e) {
+                    out1[elem] = a[1] + b[2] + static_cast<float>(c[3]) + d[0] + f;
+                } else {
+                    out1[elem] = 1;
+                }
+                out2[elem] = a[1] + b[2] + static_cast<float>(c[1]) - d[0];
+            """
+            custom_kernel = mx.fast.cuda_kernel
+
+        mx.random.seed(7)
+        a = mx.random.normal(shape=(3, 6))
+        c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
+
+        kernel = custom_kernel(
+            name="arg_test",
+            input_names=["a", "b", "c", "d"],
+            output_names=["out1", "out2"],
+            source=source,
         )
         out = kernel(
             inputs=[
|
|||||||
self.assertTrue(mx.allclose(out[0], mx.full((3, 2), 14.0484)))
|
self.assertTrue(mx.allclose(out[0], mx.full((3, 2), 14.0484)))
|
||||||
self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))
|
self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))
|
||||||
|
|
||||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
|
||||||
def test_custom_kernel_strides(self):
|
def test_custom_kernel_strides(self):
|
||||||
|
if mx.metal.is_available():
|
||||||
|
source = """
|
||||||
|
uint elem = thread_position_in_grid.x;
|
||||||
|
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||||
|
T tmp = inp[loc];
|
||||||
|
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
|
||||||
|
"""
|
||||||
|
source_contig = """
|
||||||
|
uint elem = thread_position_in_grid.x;
|
||||||
|
T tmp = inp[elem];
|
||||||
|
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
|
||||||
|
"""
|
||||||
|
custom_kernel = mx.fast.metal_kernel
|
||||||
|
elif mx.cuda.is_available():
|
||||||
|
source = """
|
||||||
|
auto elem = cooperative_groups::this_grid().thread_rank();
|
||||||
|
auto loc = elem_to_loc(elem, inp_shape.data(), inp_strides.data(), inp_ndim);
|
||||||
|
T tmp = inp[loc];
|
||||||
|
out[elem] = exp(tmp) * WARP_SIZE;
|
||||||
|
"""
|
||||||
|
source_contig = """
|
||||||
|
auto elem = cooperative_groups::this_grid().thread_rank();
|
||||||
|
T tmp = inp[elem];
|
||||||
|
out[elem] = exp(tmp) * WARP_SIZE;
|
||||||
|
"""
|
||||||
|
custom_kernel = mx.fast.cuda_kernel
|
||||||
|
|
||||||
mx.random.seed(7)
|
mx.random.seed(7)
|
||||||
a = mx.random.normal(shape=(3, 6))
|
a = mx.random.normal(shape=(3, 6))
|
||||||
source = """
|
|
||||||
uint elem = thread_position_in_grid.x;
|
|
||||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
|
||||||
T tmp = inp[loc];
|
|
||||||
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
|
|
||||||
"""
|
|
||||||
source_contig = """
|
|
||||||
uint elem = thread_position_in_grid.x;
|
|
||||||
T tmp = inp[elem];
|
|
||||||
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
|
|
||||||
"""
|
|
||||||
|
|
||||||
# non contiguous
|
# non contiguous
|
||||||
a = mx.tile(a[::2], [4, 1])
|
a = mx.tile(a[::2], [4, 1])
|
||||||
|
|
||||||
for contig in [True, False]:
|
for contig in [True, False]:
|
||||||
kernel = mx.fast.metal_kernel(
|
kernel = custom_kernel(
|
||||||
name="myexp" + str(contig),
|
name="myexp" + str(contig),
|
||||||
input_names=["inp"],
|
input_names=["inp"],
|
||||||
output_names=["out"],
|
output_names=["out"],
|
||||||
@ -685,24 +727,41 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
)
|
)
|
||||||
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))
|
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))
|
||||||
|
|
||||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
|
||||||
def test_custom_kernel_helper(self):
|
def test_custom_kernel_helper(self):
|
||||||
mx.random.seed(7)
|
if mx.metal.is_available():
|
||||||
a = mx.random.normal(shape=(2, 2))
|
header = """
|
||||||
kernel = mx.fast.metal_kernel(
|
|
||||||
name="helper",
|
|
||||||
input_names=["a"],
|
|
||||||
output_names=["out1"],
|
|
||||||
header="""
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T do_exp(T x) {
|
T do_exp(T x) {
|
||||||
return metal::precise::exp(x);
|
return metal::precise::exp(x);
|
||||||
}
|
}
|
||||||
""",
|
"""
|
||||||
source="""
|
source = """
|
||||||
uint elem = thread_position_in_grid.x;
|
uint elem = thread_position_in_grid.x;
|
||||||
out1[elem] = do_exp(a[elem]);
|
out1[elem] = do_exp(a[elem]);
|
||||||
""",
|
"""
|
||||||
|
custom_kernel = mx.fast.metal_kernel
|
||||||
|
elif mx.cuda.is_available():
|
||||||
|
header = """
|
||||||
|
template <typename T>
|
||||||
|
__device__ T do_exp(T x) {
|
||||||
|
return exp(x);
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
source = """
|
||||||
|
auto elem = cooperative_groups::this_grid().thread_rank();
|
||||||
|
out1[elem] = do_exp(a[elem]);
|
||||||
|
"""
|
||||||
|
custom_kernel = mx.fast.cuda_kernel
|
||||||
|
|
||||||
|
mx.random.seed(7)
|
||||||
|
a = mx.random.normal(shape=(2, 2))
|
||||||
|
kernel = custom_kernel(
|
||||||
|
name="helper",
|
||||||
|
input_names=["a"],
|
||||||
|
output_names=["out1"],
|
||||||
|
header=header,
|
||||||
|
source=source,
|
||||||
)
|
)
|
||||||
out = kernel(
|
out = kernel(
|
||||||
inputs=[a],
|
inputs=[a],
|
||||||
@@ -714,16 +773,21 @@ class TestFast(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(out[0], mx.exp(a)))
 
-    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
     def test_custom_kernel_attributes(self):
+        if mx.metal.is_available():
+            source = "out[0] = threads_per_threadgroup.x;"
+            custom_kernel = mx.fast.metal_kernel
+        elif mx.cuda.is_available():
+            source = "out[0] = blockDim.x;"
+            custom_kernel = mx.fast.cuda_kernel
+
         a = mx.zeros(shape=(1, 1))
-        kernel = mx.fast.metal_kernel(
+        kernel = custom_kernel(
             name="test_fun",
             input_names=["a"],
             output_names=["out"],
-            source="""
-                out[0] = threads_per_threadgroup.x;
-            """,
+            source=source,
         )
         out = kernel(
             inputs=[a],