mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Simplifications for MLX C (#1396)
* simplifications for MLX C * use vectors instead of map * update examples
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/map.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/pair.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/tuple.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
@@ -193,39 +193,130 @@ void init_fast(nb::module_& parent_module) {
|
||||
array: The quantized version of ``w``
|
||||
)pbdoc");
|
||||
|
||||
nb::class_<fast::MetalKernel>(
|
||||
m,
|
||||
m.def(
|
||||
"metal_kernel",
|
||||
[](const std::string& name,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::string& source,
|
||||
const std::string& header,
|
||||
bool ensure_row_contiguous,
|
||||
bool atomic_outputs) {
|
||||
auto kernel = fast::metal_kernel(
|
||||
name,
|
||||
input_names,
|
||||
output_names,
|
||||
source,
|
||||
header,
|
||||
ensure_row_contiguous,
|
||||
atomic_outputs);
|
||||
return nb::cpp_function(
|
||||
[kernel = std::move(kernel)](
|
||||
const std::vector<ScalarOrArray>& inputs_,
|
||||
const std::vector<std::vector<int>>& output_shapes,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
const std::optional<
|
||||
std::vector<std::pair<std::string, nb::object>>>&
|
||||
template_args_ = std::nullopt,
|
||||
std::optional<float> init_value = std::nullopt,
|
||||
bool verbose = false,
|
||||
StreamOrDevice s = {}) {
|
||||
std::vector<array> inputs;
|
||||
for (const auto& value : inputs_) {
|
||||
inputs.push_back(to_array(value, std::nullopt));
|
||||
}
|
||||
std::vector<std::pair<std::string, fast::TemplateArg>>
|
||||
template_args;
|
||||
if (template_args_) {
|
||||
for (const auto& [name, value] : template_args_.value()) {
|
||||
// Handle bool, int and dtype template args
|
||||
if (nb::isinstance<bool>(value)) {
|
||||
bool bool_val = nb::cast<bool>(value);
|
||||
template_args.emplace_back(name, bool_val);
|
||||
} else if (nb::isinstance<int>(value)) {
|
||||
int int_val = nb::cast<int>(value);
|
||||
template_args.emplace_back(name, int_val);
|
||||
} else if (nb::isinstance<Dtype>(value)) {
|
||||
Dtype dtype = nb::cast<Dtype>(value);
|
||||
template_args.emplace_back(name, dtype);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
|
||||
}
|
||||
}
|
||||
}
|
||||
return kernel(
|
||||
inputs,
|
||||
output_shapes,
|
||||
output_dtypes,
|
||||
grid,
|
||||
threadgroup,
|
||||
template_args,
|
||||
init_value,
|
||||
verbose,
|
||||
s);
|
||||
},
|
||||
nb::kw_only(),
|
||||
"inputs"_a,
|
||||
"output_shapes"_a,
|
||||
"output_dtypes"_a,
|
||||
"grid"_a,
|
||||
"threadgroup"_a,
|
||||
"template"_a = nb::none(),
|
||||
"init_value"_a = nb::none(),
|
||||
"verbose"_a = false,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
|
||||
R"pbdoc(
|
||||
Run the kernel.
|
||||
|
||||
Args:
|
||||
inputs (List[array]): The inputs passed to the Metal kernel.
|
||||
output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
|
||||
output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
|
||||
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
|
||||
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
|
||||
template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
|
||||
These will be added as template arguments to the kernel definition. Default: ``None``.
|
||||
init_value (float, optional): Optional value to use to initialize all of the output arrays.
|
||||
By default, output arrays are uninitialized. Default: ``None``.
|
||||
verbose (bool, optional): Whether to print the full generated source code of the kernel
|
||||
when it is run. Default: ``False``.
|
||||
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
List[array]: The list of output arrays.
|
||||
)pbdoc");
|
||||
},
|
||||
"name"_a,
|
||||
"input_names"_a,
|
||||
"output_names"_a,
|
||||
"source"_a,
|
||||
"header"_a = "",
|
||||
"ensure_row_contiguous"_a = true,
|
||||
"atomic_outputs"_a = false,
|
||||
R"pbdoc(
|
||||
A jit-compiled custom Metal kernel defined from a source string.
|
||||
)pbdoc")
|
||||
.def(
|
||||
nb::init<
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
bool,
|
||||
bool>(),
|
||||
"name"_a,
|
||||
"source"_a,
|
||||
"header"_a = "",
|
||||
"ensure_row_contiguous"_a = true,
|
||||
"atomic_outputs"_a = false,
|
||||
R"pbdoc(
|
||||
Initialize a metal_kernel.
|
||||
|
||||
Args:
|
||||
name (str): Name for the kernel.
|
||||
source (str): Source code. This is the body of a function in Metal,
|
||||
the function signature will be generated for you. The names of the inputs/outputs
|
||||
are determined by the ``inputs`` and ``output_shapes``/``output_dtypes``
|
||||
used when the kernel is called.
|
||||
header (str): Header source code to include before the main function.
|
||||
Useful for helper functions or includes that should live outside of the main function body.
|
||||
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
|
||||
before the kernel runs. Default: ``True``.
|
||||
atomic_outputs (bool): Whether to use atomic outputs in the function signature
|
||||
e.g. ``device atomic<float>``. Default: ``False``.
|
||||
name (str): Name for the kernel.
|
||||
input_names (List[str]): The parameter names of the inputs in the
|
||||
function signature.
|
||||
output_names (List[str]): The parameter names of the outputs in the
|
||||
function signature.
|
||||
source (str): Source code. This is the body of a function in Metal,
|
||||
the function signature will be automatically generated.
|
||||
header (str): Header source code to include before the main function.
|
||||
Useful for helper functions or includes that should live outside of
|
||||
the main function body.
|
||||
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
|
||||
before the kernel runs. Default: ``True``.
|
||||
atomic_outputs (bool): Whether to use atomic outputs in the function signature
|
||||
e.g. ``device atomic<float>``. Default: ``False``.
|
||||
|
||||
Returns:
|
||||
Callable ``metal_kernel``.
|
||||
|
||||
@@ -242,103 +333,23 @@ void init_fast(nb::module_& parent_module) {
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs={"inp": a},
|
||||
template={"T": mx.float32},
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
grid=(a.size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes={"out": a.shape},
|
||||
output_dtypes={"out": a.dtype},
|
||||
output_shapes=[a.shape],
|
||||
output_dtypes=[a.dtype],
|
||||
verbose=True,
|
||||
)
|
||||
return outputs["out"]
|
||||
return outputs[0]
|
||||
|
||||
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
)pbdoc")
|
||||
.def(
|
||||
"__call__",
|
||||
[](fast::MetalKernel& kernel,
|
||||
std::map<std::string, ScalarOrArray>& inputs_,
|
||||
std::map<std::string, std::vector<int>>& output_shapes,
|
||||
std::map<std::string, Dtype>& output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::optional<std::map<std::string, nb::handle>> template_args_,
|
||||
std::optional<float> init_value,
|
||||
bool verbose,
|
||||
StreamOrDevice s) {
|
||||
std::map<std::string, array> inputs;
|
||||
for (const auto& [name, value] : inputs_) {
|
||||
auto arr = to_array(value, std::nullopt);
|
||||
inputs.insert({name, arr});
|
||||
}
|
||||
std::map<std::string, fast::TemplateArg> template_args;
|
||||
if (template_args_) {
|
||||
for (const auto& [name, value] : template_args_.value()) {
|
||||
// Handle bool, int and dtype template args
|
||||
if (nb::isinstance<bool>(value)) {
|
||||
bool bool_val = nb::cast<bool>(value);
|
||||
template_args.insert({name, bool_val});
|
||||
} else if (nb::isinstance<int>(value)) {
|
||||
int int_val = nb::cast<int>(value);
|
||||
template_args.insert({name, int_val});
|
||||
} else if (nb::isinstance<Dtype>(value)) {
|
||||
Dtype dtype = nb::cast<Dtype>(value);
|
||||
template_args.insert({name, dtype});
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
|
||||
}
|
||||
}
|
||||
}
|
||||
return kernel(
|
||||
inputs,
|
||||
output_shapes,
|
||||
output_dtypes,
|
||||
grid,
|
||||
threadgroup,
|
||||
template_args,
|
||||
init_value,
|
||||
verbose,
|
||||
s);
|
||||
},
|
||||
nb::kw_only(),
|
||||
"inputs"_a,
|
||||
"output_shapes"_a,
|
||||
"output_dtypes"_a,
|
||||
"grid"_a,
|
||||
"threadgroup"_a,
|
||||
"template"_a = nb::none(),
|
||||
"init_value"_a = nb::none(),
|
||||
"verbose"_a = false,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
|
||||
R"pbdoc(
|
||||
Run the kernel.
|
||||
|
||||
Args:
|
||||
inputs (Mapping[str, array]): Inputs. These will be added to the function signature and passed to the Metal kernel.
|
||||
The keys will be the names of the arguments to the kernel.
|
||||
output_shapes (Mapping[str, Sequence[int]]): Output shapes. A dict mapping
|
||||
output variable names to shapes. These will be added to the function signature.
|
||||
output_dtypes (Mapping[str, Dtype]): Output dtypes. A dict mapping output variable
|
||||
names to dtypes. Must have the same keys as ``output_shapes``.
|
||||
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
|
||||
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
|
||||
template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments.
|
||||
These will be added as template arguments to the kernel definition. Default: ``None``.
|
||||
init_value (float, optional): Optional value to use to initialize all of the output arrays.
|
||||
By default, output arrays are uninitialized. Default: ``None``.
|
||||
verbose (bool, optional): Whether to print the full generated source code of the kernel
|
||||
when it is run. Default: ``False``.
|
||||
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
dict[str, array]: Dictionary of output arrays based on ``output_shapes``/``output_dtypes``.
|
||||
)pbdoc");
|
||||
)pbdoc");
|
||||
}
|
||||
|
@@ -562,20 +562,22 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
a = mx.random.normal(shape=(2, 2))
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="basic",
|
||||
input_names=["a"],
|
||||
output_names=["out1"],
|
||||
source="""
|
||||
uint elem = thread_position_in_grid.x;
|
||||
out1[elem] = a[elem];
|
||||
""",
|
||||
)
|
||||
out = kernel(
|
||||
inputs={"a": a},
|
||||
inputs=[a],
|
||||
grid=(4, 1, 1),
|
||||
threadgroup=(2, 1, 1),
|
||||
output_shapes={"out1": (2, 2)},
|
||||
output_dtypes={"out1": mx.float32},
|
||||
output_shapes=[(2, 2)],
|
||||
output_dtypes=[mx.float32],
|
||||
stream=mx.gpu,
|
||||
)
|
||||
self.assertTrue(mx.allclose(out["out1"], a))
|
||||
self.assertTrue(mx.allclose(out[0], a))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_custom_kernel_args(self):
|
||||
@@ -585,6 +587,8 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="arg_test",
|
||||
input_names=["a", "b", "c", "d"],
|
||||
output_names=["out1", "out2"],
|
||||
source="""
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = a[0];
|
||||
@@ -597,26 +601,26 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
""",
|
||||
)
|
||||
out = kernel(
|
||||
inputs={
|
||||
"a": a,
|
||||
"b": mx.array([3, 4, 5]),
|
||||
"c": c,
|
||||
"d": 7.3,
|
||||
},
|
||||
template={
|
||||
"e": True,
|
||||
"f": 3,
|
||||
"T": mx.float16,
|
||||
},
|
||||
inputs=[
|
||||
a,
|
||||
mx.array([3, 4, 5]),
|
||||
c,
|
||||
7.3,
|
||||
],
|
||||
template=[
|
||||
("e", True),
|
||||
("f", 3),
|
||||
("T", mx.float16),
|
||||
],
|
||||
grid=(6, 1, 1),
|
||||
threadgroup=(2, 1, 1),
|
||||
output_shapes={"out1": (2, 2), "out2": (3, 2)},
|
||||
output_dtypes={"out1": mx.float32, "out2": mx.int32},
|
||||
output_shapes=[(2, 2), (3, 2)],
|
||||
output_dtypes=[mx.float32, mx.int32],
|
||||
stream=mx.gpu,
|
||||
)
|
||||
|
||||
self.assertTrue(mx.allclose(out["out1"], mx.full((2, 2), 14.0484)))
|
||||
self.assertTrue(mx.allclose(out["out2"], mx.full((3, 2), -2, dtype=mx.int32)))
|
||||
self.assertTrue(mx.allclose(out[0], mx.full((2, 2), 14.0484)))
|
||||
self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_custom_kernel_strides(self):
|
||||
@@ -640,19 +644,21 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
for contig in [True, False]:
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp" + str(contig),
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source_contig if contig else source,
|
||||
ensure_row_contiguous=contig,
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs={"inp": a},
|
||||
template={"T": mx.float32},
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
grid=(a.size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes={"out": a.shape},
|
||||
output_dtypes={"out": a.dtype},
|
||||
output_shapes=[a.shape],
|
||||
output_dtypes=[a.dtype],
|
||||
stream=mx.gpu,
|
||||
)
|
||||
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"]))
|
||||
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_custom_kernel_helper(self):
|
||||
@@ -660,6 +666,8 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
a = mx.random.normal(shape=(2, 2))
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="helper",
|
||||
input_names=["a"],
|
||||
output_names=["out1"],
|
||||
header="""
|
||||
template <typename T>
|
||||
T do_exp(T x) {
|
||||
@@ -672,14 +680,14 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
""",
|
||||
)
|
||||
out = kernel(
|
||||
inputs={"a": a},
|
||||
inputs=[a],
|
||||
grid=(4, 1, 1),
|
||||
threadgroup=(2, 1, 1),
|
||||
output_shapes={"out1": (2, 2)},
|
||||
output_dtypes={"out1": mx.float32},
|
||||
output_shapes=[(2, 2)],
|
||||
output_dtypes=[mx.float32],
|
||||
stream=mx.gpu,
|
||||
)
|
||||
self.assertTrue(mx.allclose(out["out1"], mx.exp(a)))
|
||||
self.assertTrue(mx.allclose(out[0], mx.exp(a)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user