A bit of refactoring

This commit is contained in:
Angelos Katharopoulos
2025-08-18 19:18:41 -07:00
parent 3938aaaf24
commit d2ae81b413
9 changed files with 408 additions and 206 deletions

View File

@@ -17,6 +17,66 @@ namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
namespace {
struct PyCustomKernelFunction {
PyCustomKernelFunction(mx::fast::CustomKernelFunction kernel, const char* tag)
: kernel_(std::move(kernel)), tag_(tag) {}
std::vector<mx::array> operator()(
const std::vector<ScalarOrArray>& inputs_,
const std::vector<mx::Shape>& output_shapes,
const std::vector<mx::Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::optional<std::vector<std::pair<std::string, nb::object>>>&
template_args_ = std::nullopt,
std::optional<float> init_value = std::nullopt,
bool verbose = false,
mx::StreamOrDevice s = {}) const {
std::vector<mx::array> inputs;
for (const auto& value : inputs_) {
inputs.push_back(to_array(value, std::nullopt));
}
std::vector<std::pair<std::string, mx::fast::TemplateArg>> template_args;
if (template_args_) {
for (const auto& [name, value] : template_args_.value()) {
// Handle bool, int and dtype template args
if (nb::isinstance<bool>(value)) {
bool bool_val = nb::cast<bool>(value);
template_args.emplace_back(name, bool_val);
} else if (nb::isinstance<int>(value)) {
int int_val = nb::cast<int>(value);
template_args.emplace_back(name, int_val);
} else if (nb::isinstance<mx::Dtype>(value)) {
mx::Dtype dtype = nb::cast<mx::Dtype>(value);
template_args.emplace_back(name, dtype);
} else {
std::ostringstream msg;
msg << tag_
<< " Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.";
throw std::invalid_argument(msg.str());
}
}
}
return kernel_(
inputs,
output_shapes,
output_dtypes,
grid,
threadgroup,
template_args,
init_value,
verbose,
s);
}
mx::fast::CustomKernelFunction kernel_;
const char* tag_;
};
} // namespace
void init_fast(nb::module_& parent_module) {
auto m =
parent_module.def_submodule("fast", "mlx.core.fast: fast operations");
@@ -240,53 +300,7 @@ void init_fast(nb::module_& parent_module) {
ensure_row_contiguous,
atomic_outputs);
return nb::cpp_function(
[kernel = std::move(kernel)](
const std::vector<ScalarOrArray>& inputs_,
const std::vector<mx::Shape>& output_shapes,
const std::vector<mx::Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::optional<
std::vector<std::pair<std::string, nb::object>>>&
template_args_ = std::nullopt,
std::optional<float> init_value = std::nullopt,
bool verbose = false,
mx::StreamOrDevice s = {}) {
std::vector<mx::array> inputs;
for (const auto& value : inputs_) {
inputs.push_back(to_array(value, std::nullopt));
}
std::vector<std::pair<std::string, mx::fast::TemplateArg>>
template_args;
if (template_args_) {
for (const auto& [name, value] : template_args_.value()) {
// Handle bool, int and dtype template args
if (nb::isinstance<bool>(value)) {
bool bool_val = nb::cast<bool>(value);
template_args.emplace_back(name, bool_val);
} else if (nb::isinstance<int>(value)) {
int int_val = nb::cast<int>(value);
template_args.emplace_back(name, int_val);
} else if (nb::isinstance<mx::Dtype>(value)) {
mx::Dtype dtype = nb::cast<mx::Dtype>(value);
template_args.emplace_back(name, dtype);
} else {
throw std::invalid_argument(
"[metal_kernel] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
}
}
}
return kernel(
inputs,
output_shapes,
output_dtypes,
grid,
threadgroup,
template_args,
init_value,
verbose,
s);
},
PyCustomKernelFunction(std::move(kernel), "[metal_kernel]"),
nb::kw_only(),
"inputs"_a,
"output_shapes"_a,
@@ -386,7 +400,123 @@ void init_fast(nb::module_& parent_module) {
)pbdoc");
m.def(
"precompiled_custom_kernel",
"cuda_kernel",
[](const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header,
bool ensure_row_contiguous,
int shared_mem) {
auto kernel = mx::fast::cuda_kernel(
name,
input_names,
output_names,
source,
header,
ensure_row_contiguous,
shared_mem);
return nb::cpp_function(
PyCustomKernelFunction(std::move(kernel), "[cuda_kernel]"),
nb::kw_only(),
"inputs"_a,
"output_shapes"_a,
"output_dtypes"_a,
"grid"_a,
"threadgroup"_a,
"template"_a = nb::none(),
"init_value"_a = nb::none(),
"verbose"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
R"pbdoc(
Run the kernel.
Args:
inputs (List[array]): The inputs passed to the CUDA kernel.
output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
These will be added as template arguments to the kernel definition. Default: ``None``.
init_value (float, optional): Optional value to use to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
verbose (bool, optional): Whether to print the full generated source code of the kernel
when it is run. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
Returns:
List[array]: The list of output arrays.)pbdoc");
},
"name"_a,
"input_names"_a,
"output_names"_a,
"source"_a,
"header"_a = "",
"ensure_row_contiguous"_a = true,
"shared_memory"_a = 0,
R"pbdoc(
A jit-compiled custom CUDA kernel defined from a source string.
This is the CUDA equivalent of :ref:`custom_metal_kernels`.
Args:
name (str): Name for the kernel.
input_names (List[str]): The parameter names of the inputs in the
function signature.
output_names (List[str]): The parameter names of the outputs in the
function signature.
source (str): Source code. This is the body of a function in CUDA,
the function signature will be automatically generated.
header (str): Header source code to include before the main function.
Useful for helper functions or includes that should live outside of
the main function body.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``True``.
shared_memory (int): The dynamic shared memory to request for the
kernel. A value of 0 means no dynamic shared memory. Default: ``0``.
Returns:
Callable ``cuda_kernel``.
Example:
.. code-block:: python
def exp_elementwise(a: mx.array):
source = '''
auto elem = cooperative_groups::this_grid().thread_rank();
T tmp = inp[elem];
out[elem] = exp(tmp);
'''
kernel = mx.fast.cuda_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source
)
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
verbose=True,
)
return outputs[0]
a = mx.random.normal(shape=(16, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
)pbdoc");
m.def(
"precompiled_cuda_kernel",
[](const std::string& name,
const nb::bytes compiled_source,
const std::vector<ScalarOrArray>& inputs_,
@@ -420,14 +550,14 @@ void init_fast(nb::module_& parent_module) {
std::string vtype_name =
nb::cast<std::string>(vtype.attr("__name__"));
std::ostringstream msg;
msg << "[precompiled_custom_kernel] Invalid scalar argument type. "
msg << "[precompiled_cuda_kernel] Invalid scalar argument type. "
<< "Received " << vtype_name
<< " but must be one of bool, int or float";
throw std::invalid_argument(msg.str());
}
}
return mx::fast::precompiled_custom_kernel(
return mx::fast::precompiled_cuda_kernel(
name,
std::string(
static_cast<const char*>(compiled_source.data()),
@@ -457,5 +587,27 @@ void init_fast(nb::module_& parent_module) {
"ensure_row_contiguous"_a = false,
"stream"_a = nb::none(),
R"pbdoc(
Run a precompiled CUDA kernel defined from PTX or cubin.
This op is still experimental and various parts of the API may change.
Args:
name (str): Name for the kernel
compiled_source (bytes): The precompiled kernel in raw bytes.
inputs (List[array]): The inputs passed to the CUDA kernel.
output_shapes (List[Sequence[int]]): The list of shapes for each output.
output_dtypes (List[Dtype]): The list of data types for each output.
scalars (List[Union[bool, int, float]]): A list of scalar arguments to
pass to the kernel.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
shared_memory (int): The dynamic shared memory to request for the
kernel. A value of 0 means no dynamic shared memory. Default: ``0``.
init_value (float, optional): Optional value to use to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
)pbdoc");
}