A bit of refactoring
@@ -17,6 +17,66 @@ namespace mx = mlx::core;
 namespace nb = nanobind;
 using namespace nb::literals;
 
+namespace {
+
+struct PyCustomKernelFunction {
+  PyCustomKernelFunction(mx::fast::CustomKernelFunction kernel, const char* tag)
+      : kernel_(std::move(kernel)), tag_(tag) {}
+
+  std::vector<mx::array> operator()(
+      const std::vector<ScalarOrArray>& inputs_,
+      const std::vector<mx::Shape>& output_shapes,
+      const std::vector<mx::Dtype>& output_dtypes,
+      std::tuple<int, int, int> grid,
+      std::tuple<int, int, int> threadgroup,
+      const std::optional<std::vector<std::pair<std::string, nb::object>>>&
+          template_args_ = std::nullopt,
+      std::optional<float> init_value = std::nullopt,
+      bool verbose = false,
+      mx::StreamOrDevice s = {}) const {
+    std::vector<mx::array> inputs;
+    for (const auto& value : inputs_) {
+      inputs.push_back(to_array(value, std::nullopt));
+    }
+    std::vector<std::pair<std::string, mx::fast::TemplateArg>> template_args;
+    if (template_args_) {
+      for (const auto& [name, value] : template_args_.value()) {
+        // Handle bool, int and dtype template args
+        if (nb::isinstance<bool>(value)) {
+          bool bool_val = nb::cast<bool>(value);
+          template_args.emplace_back(name, bool_val);
+        } else if (nb::isinstance<int>(value)) {
+          int int_val = nb::cast<int>(value);
+          template_args.emplace_back(name, int_val);
+        } else if (nb::isinstance<mx::Dtype>(value)) {
+          mx::Dtype dtype = nb::cast<mx::Dtype>(value);
+          template_args.emplace_back(name, dtype);
+        } else {
+          std::ostringstream msg;
+          msg << tag_
+              << " Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.";
+          throw std::invalid_argument(msg.str());
+        }
+      }
+    }
+    return kernel_(
+        inputs,
+        output_shapes,
+        output_dtypes,
+        grid,
+        threadgroup,
+        template_args,
+        init_value,
+        verbose,
+        s);
+  }
+
+  mx::fast::CustomKernelFunction kernel_;
+  const char* tag_;
+};
+
+} // namespace
+
 void init_fast(nb::module_& parent_module) {
   auto m =
       parent_module.def_submodule("fast", "mlx.core.fast: fast operations");
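The struct above is the heart of the refactor: both `metal_kernel` and `cuda_kernel` bind the same `PyCustomKernelFunction` call operator, so the two callables expose one keyword-only surface and share the template-argument validation. A minimal sketch of that shared surface from Python, assuming a Metal-enabled MLX build (the kernel name and body here are illustrative, not from this commit):

.. code-block:: python

    import mlx.core as mx

    # Both metal_kernel and cuda_kernel callables go through
    # PyCustomKernelFunction, so they accept the same keyword-only
    # arguments: inputs, output_shapes, output_dtypes, grid,
    # threadgroup, template, init_value, verbose, stream.
    kernel = mx.fast.metal_kernel(
        name="copy",  # illustrative name
        input_names=["inp"],
        output_names=["out"],
        source="""
            uint elem = thread_position_in_grid.x;
            T tmp = inp[elem];
            out[elem] = tmp;
        """,
    )
    a = mx.arange(16, dtype=mx.float32)
    (out,) = kernel(
        inputs=[a],
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
        grid=(a.size, 1, 1),
        threadgroup=(16, 1, 1),
        # Template values must be bool, int, or mlx.core.Dtype; anything
        # else raises the "[metal_kernel] Invalid template argument" error.
        template=[("T", mx.float32)],
    )
    assert mx.allclose(out, a)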
@@ -240,53 +300,7 @@ void init_fast(nb::module_& parent_module) {
             ensure_row_contiguous,
             atomic_outputs);
         return nb::cpp_function(
-            [kernel = std::move(kernel)](
-                const std::vector<ScalarOrArray>& inputs_,
-                const std::vector<mx::Shape>& output_shapes,
-                const std::vector<mx::Dtype>& output_dtypes,
-                std::tuple<int, int, int> grid,
-                std::tuple<int, int, int> threadgroup,
-                const std::optional<
-                    std::vector<std::pair<std::string, nb::object>>>&
-                    template_args_ = std::nullopt,
-                std::optional<float> init_value = std::nullopt,
-                bool verbose = false,
-                mx::StreamOrDevice s = {}) {
-              std::vector<mx::array> inputs;
-              for (const auto& value : inputs_) {
-                inputs.push_back(to_array(value, std::nullopt));
-              }
-              std::vector<std::pair<std::string, mx::fast::TemplateArg>>
-                  template_args;
-              if (template_args_) {
-                for (const auto& [name, value] : template_args_.value()) {
-                  // Handle bool, int and dtype template args
-                  if (nb::isinstance<bool>(value)) {
-                    bool bool_val = nb::cast<bool>(value);
-                    template_args.emplace_back(name, bool_val);
-                  } else if (nb::isinstance<int>(value)) {
-                    int int_val = nb::cast<int>(value);
-                    template_args.emplace_back(name, int_val);
-                  } else if (nb::isinstance<mx::Dtype>(value)) {
-                    mx::Dtype dtype = nb::cast<mx::Dtype>(value);
-                    template_args.emplace_back(name, dtype);
-                  } else {
-                    throw std::invalid_argument(
-                        "[metal_kernel] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
-                  }
-                }
-              }
-              return kernel(
-                  inputs,
-                  output_shapes,
-                  output_dtypes,
-                  grid,
-                  threadgroup,
-                  template_args,
-                  init_value,
-                  verbose,
-                  s);
-            },
+            PyCustomKernelFunction(std::move(kernel), "[metal_kernel]"),
             nb::kw_only(),
             "inputs"_a,
             "output_shapes"_a,
@@ -386,7 +400,123 @@ void init_fast(nb::module_& parent_module) {
       )pbdoc");
 
   m.def(
-      "precompiled_custom_kernel",
+      "cuda_kernel",
+      [](const std::string& name,
+         const std::vector<std::string>& input_names,
+         const std::vector<std::string>& output_names,
+         const std::string& source,
+         const std::string& header,
+         bool ensure_row_contiguous,
+         int shared_mem) {
+        auto kernel = mx::fast::cuda_kernel(
+            name,
+            input_names,
+            output_names,
+            source,
+            header,
+            ensure_row_contiguous,
+            shared_mem);
+        return nb::cpp_function(
+            PyCustomKernelFunction(std::move(kernel), "[cuda_kernel]"),
+            nb::kw_only(),
+            "inputs"_a,
+            "output_shapes"_a,
+            "output_dtypes"_a,
+            "grid"_a,
+            "threadgroup"_a,
+            "template"_a = nb::none(),
+            "init_value"_a = nb::none(),
+            "verbose"_a = false,
+            "stream"_a = nb::none(),
+            nb::sig(
+                "def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = False, stream: Union[None, Stream, Device] = None)"),
+            R"pbdoc(
+              Run the kernel.
+
+              Args:
+                inputs (List[array]): The inputs passed to the CUDA kernel.
+                output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
+                output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
+                grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
+                  For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks.
+                threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
+                template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
+                  These will be added as template arguments to the kernel definition. Default: ``None``.
+                init_value (float, optional): Optional value to use to initialize all of the output arrays.
+                  By default, output arrays are uninitialized. Default: ``None``.
+                verbose (bool, optional): Whether to print the full generated source code of the kernel
+                  when it is run. Default: ``False``.
+                stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
+
+              Returns:
+                List[array]: The list of output arrays.)pbdoc");
+      },
+      "name"_a,
+      "input_names"_a,
+      "output_names"_a,
+      "source"_a,
+      "header"_a = "",
+      "ensure_row_contiguous"_a = true,
+      "shared_memory"_a = 0,
+      R"pbdoc(
+        A jit-compiled custom CUDA kernel defined from a source string.
+
+        This is the CUDA equivalent of :ref:`custom_metal_kernels`.
+
+        Args:
+          name (str): Name for the kernel.
+          input_names (List[str]): The parameter names of the inputs in the
+            function signature.
+          output_names (List[str]): The parameter names of the outputs in the
+            function signature.
+          source (str): Source code. This is the body of a function in CUDA,
+            the function signature will be automatically generated.
+          header (str): Header source code to include before the main function.
+            Useful for helper functions or includes that should live outside of
+            the main function body.
+          ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
+            before the kernel runs. Default: ``True``.
+          shared_memory (int): The dynamic shared memory to request for the
+            kernel. A value of 0 means no dynamic shared memory. Default: ``0``.
+
+        Returns:
+          Callable ``cuda_kernel``.
+
+        Example:
+
+          .. code-block:: python
+
+            def exp_elementwise(a: mx.array):
+                source = '''
+                    auto elem = cooperative_groups::this_grid().thread_rank();
+                    T tmp = inp[elem];
+                    out[elem] = exp(tmp);
+                '''
+
+                kernel = mx.fast.cuda_kernel(
+                    name="myexp",
+                    input_names=["inp"],
+                    output_names=["out"],
+                    source=source
+                )
+                outputs = kernel(
+                    inputs=[a],
+                    template=[("T", mx.float32)],
+                    grid=(a.size, 1, 1),
+                    threadgroup=(256, 1, 1),
+                    output_shapes=[a.shape],
+                    output_dtypes=[a.dtype],
+                    verbose=True,
+                )
+                return outputs[0]
+
+            a = mx.random.normal(shape=(16, 16)).astype(mx.float16)
+            b = exp_elementwise(a)
+            assert mx.allclose(b, mx.exp(a))
+        )pbdoc");
+
+  m.def(
+      "precompiled_cuda_kernel",
       [](const std::string& name,
          const nb::bytes compiled_source,
          const std::vector<ScalarOrArray>& inputs_,
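The docstring example above only templates a dtype. As a complement, a hedged sketch (my own, not part of the commit) showing that ``bool`` and ``int`` template values pass through the same validation path as ``Dtype``, assuming a CUDA-enabled MLX build:

.. code-block:: python

    import mlx.core as mx

    # Illustrative only: "T" (Dtype) and "SQUARE" (bool) become template
    # parameters of the generated CUDA kernel, per the bindings above.
    source = '''
        auto elem = cooperative_groups::this_grid().thread_rank();
        T acc = inp[elem];
        if (SQUARE) {
            acc = acc * acc;
        }
        out[elem] = acc;
    '''

    kernel = mx.fast.cuda_kernel(
        name="maybe_square",  # illustrative name
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )
    a = mx.random.normal(shape=(1024,))
    (b,) = kernel(
        inputs=[a],
        template=[("T", mx.float32), ("SQUARE", True)],
        grid=(a.size, 1, 1),  # the grid is in threads, as with metal_kernel
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )
    assert mx.allclose(b, a * a)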
@@ -420,14 +550,14 @@ void init_fast(nb::module_& parent_module) {
             std::string vtype_name =
                 nb::cast<std::string>(vtype.attr("__name__"));
             std::ostringstream msg;
-            msg << "[precompiled_custom_kernel] Invalid scalar argument type. "
+            msg << "[precompiled_cuda_kernel] Invalid scalar argument type. "
                 << "Received " << vtype_name
                 << " but must be one of bool, int or float";
             throw std::invalid_argument(msg.str());
           }
         }
 
-        return mx::fast::precompiled_custom_kernel(
+        return mx::fast::precompiled_cuda_kernel(
            name,
            std::string(
                static_cast<const char*>(compiled_source.data()),
@@ -457,5 +587,27 @@ void init_fast(nb::module_& parent_module) {
       "ensure_row_contiguous"_a = false,
       "stream"_a = nb::none(),
       R"pbdoc(
+        Run a precompiled CUDA kernel defined from PTX or cubin.
+
+        This op is still experimental and various parts of the API may change.
+
+        Args:
+          name (str): Name for the kernel.
+          compiled_source (bytes): The precompiled kernel in raw bytes.
+          inputs (List[array]): The inputs passed to the CUDA kernel.
+          output_shapes (List[Sequence[int]]): The list of shapes for each output.
+          output_dtypes (List[Dtype]): The list of data types for each output.
+          scalars (List[Union[bool, int, float]]): A list of scalar arguments to
+            pass to the kernel.
+          grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
+            For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks.
+          threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
+          shared_memory (int): The dynamic shared memory to request for the
+            kernel. A value of 0 means no dynamic shared memory. Default: ``0``.
+          init_value (float, optional): Optional value to use to initialize all of the output arrays.
+            By default, output arrays are uninitialized. Default: ``None``.
+          ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
+            before the kernel runs. Default: ``False``.
+          stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
       )pbdoc");
 }
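For the precompiled path, a hedged end-to-end sketch using only the arguments documented above. The offline ``nvcc -ptx`` step, the kernel source, and the assumption that arguments reach the raw kernel as inputs, then outputs, then scalars are all illustrative guesses, not details confirmed by this commit:

.. code-block:: python

    import subprocess
    import mlx.core as mx

    # Illustrative kernel, compiled offline to PTX with nvcc (assumed
    # toolchain). The extern "C" symbol is assumed to match `name` below.
    cu_source = r'''
    extern "C" __global__ void scale(
        const float* inp, float* out, int n, float alpha) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = alpha * inp[i];
    }
    '''
    with open("scale.cu", "w") as f:
        f.write(cu_source)
    subprocess.run(["nvcc", "-ptx", "scale.cu", "-o", "scale.ptx"], check=True)

    with open("scale.ptx", "rb") as f:
        ptx = f.read()

    a = mx.arange(1024, dtype=mx.float32)
    (out,) = mx.fast.precompiled_cuda_kernel(
        name="scale",
        compiled_source=ptx,
        inputs=[a],
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
        scalars=[1024, 2.0],   # int and float scalars, as validated above
        grid=(a.size, 1, 1),   # in threads, like metal_kernel
        threadgroup=(256, 1, 1),
    )
    assert mx.allclose(out, 2.0 * a)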