Simplifications for MLX C (#1396)

* simplifications for MLX C

* use vectors instead of map

* update examples
Author: Awni Hannun
Date: 2024-09-06 19:16:50 -07:00
Committed by: GitHub
Parent: 7cca1727af
Commit: ba3e913c7a
7 changed files with 334 additions and 331 deletions
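
In short: ``mx.fast.metal_kernel`` now takes the input and output parameter names up front, and the kernel call takes positional lists where it previously took string-keyed dicts, returning a list of arrays instead of a dict. A minimal sketch of the new pattern (the ``copy`` kernel name and body are illustrative, not part of this commit):

import mlx.core as mx

# Input/output names are now fixed when the kernel is created ...
kernel = mx.fast.metal_kernel(
    name="copy",  # illustrative kernel, not from this commit
    input_names=["inp"],
    output_names=["out"],
    source="""
        uint elem = thread_position_in_grid.x;
        out[elem] = inp[elem];
    """,
)

a = mx.random.normal(shape=(16,))

# ... and the call site passes lists, matched positionally to those names.
outputs = kernel(
    inputs=[a],                   # was: inputs={"inp": a}
    grid=(a.size, 1, 1),
    threadgroup=(16, 1, 1),
    output_shapes=[a.shape],      # was: output_shapes={"out": a.shape}
    output_dtypes=[a.dtype],      # was: output_dtypes={"out": a.dtype}
    stream=mx.gpu,
)
out = outputs[0]                  # was: outputs["out"]

The order of ``inputs``, ``output_shapes``, and ``output_dtypes`` must match the order of ``input_names`` and ``output_names`` given at creation.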

python/src/fast.cpp

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/map.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/variant.h>
@@ -193,39 +193,130 @@ void init_fast(nb::module_& parent_module) {
array: The quantized version of ``w``
)pbdoc");
nb::class_<fast::MetalKernel>(
m,
m.def(
"metal_kernel",
[](const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header,
bool ensure_row_contiguous,
bool atomic_outputs) {
auto kernel = fast::metal_kernel(
name,
input_names,
output_names,
source,
header,
ensure_row_contiguous,
atomic_outputs);
return nb::cpp_function(
[kernel = std::move(kernel)](
const std::vector<ScalarOrArray>& inputs_,
const std::vector<std::vector<int>>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::optional<
std::vector<std::pair<std::string, nb::object>>>&
template_args_ = std::nullopt,
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s = {}) {
std::vector<array> inputs;
for (const auto& value : inputs_) {
inputs.push_back(to_array(value, std::nullopt));
}
std::vector<std::pair<std::string, fast::TemplateArg>>
template_args;
if (template_args_) {
for (const auto& [name, value] : template_args_.value()) {
// Handle bool, int and dtype template args
if (nb::isinstance<bool>(value)) {
bool bool_val = nb::cast<bool>(value);
template_args.emplace_back(name, bool_val);
} else if (nb::isinstance<int>(value)) {
int int_val = nb::cast<int>(value);
template_args.emplace_back(name, int_val);
} else if (nb::isinstance<Dtype>(value)) {
Dtype dtype = nb::cast<Dtype>(value);
template_args.emplace_back(name, dtype);
} else {
throw std::invalid_argument(
"[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
}
}
}
return kernel(
inputs,
output_shapes,
output_dtypes,
grid,
threadgroup,
template_args,
init_value,
verbose,
s);
},
nb::kw_only(),
"inputs"_a,
"output_shapes"_a,
"output_dtypes"_a,
"grid"_a,
"threadgroup"_a,
"template"_a = nb::none(),
"init_value"_a = nb::none(),
"verbose"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
R"pbdoc(
Run the kernel.
Args:
inputs (List[array]): The inputs passed to the Metal kernel.
output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
These will be added as template arguments to the kernel definition. Default: ``None``.
init_value (float, optional): Value used to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
verbose (bool, optional): Whether to print the full generated source code of the kernel
when it is run. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
Returns:
List[array]: The list of output arrays.
)pbdoc");
},
"name"_a,
"input_names"_a,
"output_names"_a,
"source"_a,
"header"_a = "",
"ensure_row_contiguous"_a = true,
"atomic_outputs"_a = false,
R"pbdoc(
A jit-compiled custom Metal kernel defined from a source string.
)pbdoc")
.def(
nb::init<
const std::string&,
const std::string&,
const std::string&,
bool,
bool>(),
"name"_a,
"source"_a,
"header"_a = "",
"ensure_row_contiguous"_a = true,
"atomic_outputs"_a = false,
R"pbdoc(
Initialize a metal_kernel.
Args:
name (str): Name for the kernel.
source (str): Source code. This is the body of a function in Metal;
the function signature will be generated for you. The names of the inputs/outputs
are determined by the ``inputs`` and ``output_shapes``/``output_dtypes``
used when the kernel is called.
header (str): Header source code to include before the main function.
Useful for helper functions or includes that should live outside of the main function body.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``True``.
atomic_outputs (bool): Whether to use atomic outputs in the function signature
e.g. ``device atomic<float>``. Default: ``False``.
name (str): Name for the kernel.
input_names (List[str]): The parameter names of the inputs in the
function signature.
output_names (List[str]): The parameter names of the outputs in the
function signature.
source (str): Source code. This is the body of a function in Metal;
the function signature will be automatically generated.
header (str): Header source code to include before the main function.
Useful for helper functions or includes that should live outside of
the main function body.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``True``.
atomic_outputs (bool): Whether to use atomic outputs in the function signature
e.g. ``device atomic<float>``. Default: ``False``.
Returns:
Callable ``metal_kernel``.
@@ -242,103 +333,23 @@ void init_fast(nb::module_& parent_module) {
kernel = mx.fast.metal_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source
)
outputs = kernel(
inputs={"inp": a},
template={"T": mx.float32},
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes={"out": a.shape},
output_dtypes={"out": a.dtype},
output_shapes=[a.shape],
output_dtypes=[a.dtype],
verbose=True,
)
return outputs["out"]
return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
)pbdoc")
.def(
"__call__",
[](fast::MetalKernel& kernel,
std::map<std::string, ScalarOrArray>& inputs_,
std::map<std::string, std::vector<int>>& output_shapes,
std::map<std::string, Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::optional<std::map<std::string, nb::handle>> template_args_,
std::optional<float> init_value,
bool verbose,
StreamOrDevice s) {
std::map<std::string, array> inputs;
for (const auto& [name, value] : inputs_) {
auto arr = to_array(value, std::nullopt);
inputs.insert({name, arr});
}
std::map<std::string, fast::TemplateArg> template_args;
if (template_args_) {
for (const auto& [name, value] : template_args_.value()) {
// Handle bool, int and dtype template args
if (nb::isinstance<bool>(value)) {
bool bool_val = nb::cast<bool>(value);
template_args.insert({name, bool_val});
} else if (nb::isinstance<int>(value)) {
int int_val = nb::cast<int>(value);
template_args.insert({name, int_val});
} else if (nb::isinstance<Dtype>(value)) {
Dtype dtype = nb::cast<Dtype>(value);
template_args.insert({name, dtype});
} else {
throw std::invalid_argument(
"[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
}
}
}
return kernel(
inputs,
output_shapes,
output_dtypes,
grid,
threadgroup,
template_args,
init_value,
verbose,
s);
},
nb::kw_only(),
"inputs"_a,
"output_shapes"_a,
"output_dtypes"_a,
"grid"_a,
"threadgroup"_a,
"template"_a = nb::none(),
"init_value"_a = nb::none(),
"verbose"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
R"pbdoc(
Run the kernel.
Args:
inputs (Mapping[str, array]): Inputs. These will be added to the function signature and passed to the Metal kernel.
The keys will be the names of the arguments to the kernel.
output_shapes (Mapping[str, Sequence[int]]): Output shapes. A dict mapping
output variable names to shapes. These will be added to the function signature.
output_dtypes (Mapping[str, Dtype]): Output dtypes. A dict mapping output variable
names to dtypes. Must have the same keys as ``output_shapes``.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments.
These will be added as template arguments to the kernel definition. Default: ``None``.
init_value (float, optional): Value used to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
verbose (bool, optional): Whether to print the full generated source code of the kernel
when it is run. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
Returns:
dict[str, array]: Dictionary of output arrays based on ``output_shapes``/``output_dtypes``.
)pbdoc");
)pbdoc");
}
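
As the conversion loop in the binding above shows, ``template`` is now an ordered list of ``(name, value)`` pairs, and each value must be a ``bool``, ``int``, or ``mlx.core.Dtype``; anything else throws ``std::invalid_argument`` (surfaced as a ``ValueError`` in Python). A sketch of a call with template arguments, assuming ``kernel`` was created as above with matching names and template parameters ``e``, ``f``, and ``T``:

import mlx.core as mx

a = mx.random.normal(shape=(2, 2))
outputs = kernel(
    inputs=[a],
    # Ordered (name, value) pairs; values are limited to bool, int, or Dtype.
    template=[("e", True), ("f", 3), ("T", mx.float16)],
    grid=(4, 1, 1),
    threadgroup=(2, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
    stream=mx.gpu,
)
# template=[("T", "float16")] would raise: [[metal_kernel]] Invalid template
# argument. Must be `mlx.core.Dtype`, `int` or `bool`.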

python/tests/test_fast.py

@@ -562,20 +562,22 @@ class TestFast(mlx_tests.MLXTestCase):
a = mx.random.normal(shape=(2, 2))
kernel = mx.fast.metal_kernel(
name="basic",
input_names=["a"],
output_names=["out1"],
source="""
uint elem = thread_position_in_grid.x;
out1[elem] = a[elem];
""",
)
out = kernel(
inputs={"a": a},
inputs=[a],
grid=(4, 1, 1),
threadgroup=(2, 1, 1),
output_shapes={"out1": (2, 2)},
output_dtypes={"out1": mx.float32},
output_shapes=[(2, 2)],
output_dtypes=[mx.float32],
stream=mx.gpu,
)
self.assertTrue(mx.allclose(out["out1"], a))
self.assertTrue(mx.allclose(out[0], a))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_custom_kernel_args(self):
@@ -585,6 +587,8 @@ class TestFast(mlx_tests.MLXTestCase):
kernel = mx.fast.metal_kernel(
name="arg_test",
input_names=["a", "b", "c", "d"],
output_names=["out1", "out2"],
source="""
uint elem = thread_position_in_grid.x;
T tmp = a[0];
@@ -597,26 +601,26 @@ class TestFast(mlx_tests.MLXTestCase):
""",
)
out = kernel(
inputs={
"a": a,
"b": mx.array([3, 4, 5]),
"c": c,
"d": 7.3,
},
template={
"e": True,
"f": 3,
"T": mx.float16,
},
inputs=[
a,
mx.array([3, 4, 5]),
c,
7.3,
],
template=[
("e", True),
("f", 3),
("T", mx.float16),
],
grid=(6, 1, 1),
threadgroup=(2, 1, 1),
output_shapes={"out1": (2, 2), "out2": (3, 2)},
output_dtypes={"out1": mx.float32, "out2": mx.int32},
output_shapes=[(2, 2), (3, 2)],
output_dtypes=[mx.float32, mx.int32],
stream=mx.gpu,
)
self.assertTrue(mx.allclose(out["out1"], mx.full((2, 2), 14.0484)))
self.assertTrue(mx.allclose(out["out2"], mx.full((3, 2), -2, dtype=mx.int32)))
self.assertTrue(mx.allclose(out[0], mx.full((2, 2), 14.0484)))
self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_custom_kernel_strides(self):
@@ -640,19 +644,21 @@ class TestFast(mlx_tests.MLXTestCase):
for contig in [True, False]:
kernel = mx.fast.metal_kernel(
name="myexp" + str(contig),
input_names=["inp"],
output_names=["out"],
source=source_contig if contig else source,
ensure_row_contiguous=contig,
)
outputs = kernel(
inputs={"inp": a},
template={"T": mx.float32},
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes={"out": a.shape},
output_dtypes={"out": a.dtype},
output_shapes=[a.shape],
output_dtypes=[a.dtype],
stream=mx.gpu,
)
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"]))
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_custom_kernel_helper(self):
@@ -660,6 +666,8 @@ class TestFast(mlx_tests.MLXTestCase):
a = mx.random.normal(shape=(2, 2))
kernel = mx.fast.metal_kernel(
name="helper",
input_names=["a"],
output_names=["out1"],
header="""
template <typename T>
T do_exp(T x) {
@@ -672,14 +680,14 @@ class TestFast(mlx_tests.MLXTestCase):
""",
)
out = kernel(
inputs={"a": a},
inputs=[a],
grid=(4, 1, 1),
threadgroup=(2, 1, 1),
output_shapes={"out1": (2, 2)},
output_dtypes={"out1": mx.float32},
output_shapes=[(2, 2)],
output_dtypes=[mx.float32],
stream=mx.gpu,
)
self.assertTrue(mx.allclose(out["out1"], mx.exp(a)))
self.assertTrue(mx.allclose(out[0], mx.exp(a)))
if __name__ == "__main__":
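
For reference, the helper pattern from ``test_custom_kernel_helper`` written out in full under the new list-based API. The ``do_exp`` body and the kernel source are assumptions (the hunk above truncates them), chosen to be consistent with the test's ``mx.exp(a)`` assertion:

import mlx.core as mx

a = mx.random.normal(shape=(2, 2))

kernel = mx.fast.metal_kernel(
    name="helper",
    input_names=["a"],
    output_names=["out1"],
    # Helper functions live in `header`, outside the generated kernel body.
    header="""
        template <typename T>
        T do_exp(T x) {
            return metal::exp(x);  // assumed body; the diff truncates it
        }
    """,
    source="""
        uint elem = thread_position_in_grid.x;
        out1[elem] = do_exp(a[elem]);
    """,
)
out = kernel(
    inputs=[a],
    grid=(4, 1, 1),
    threadgroup=(2, 1, 1),
    output_shapes=[(2, 2)],
    output_dtypes=[mx.float32],
    stream=mx.gpu,
)
assert mx.allclose(out[0], mx.exp(a))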