mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Working custom kernels jointly
This commit is contained in:
@@ -384,4 +384,76 @@ void init_fast(nb::module_& parent_module) {
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"precompiled_custom_kernel",
|
||||
[](const std::string& name,
|
||||
const std::string& compiled_source,
|
||||
const std::vector<ScalarOrArray>& inputs_,
|
||||
const std::vector<mx::Shape>& output_shapes,
|
||||
const std::vector<mx::Dtype>& output_dtypes,
|
||||
const std::vector<nb::object>& scalars_,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
int shared_memory,
|
||||
std::optional<float> init_value = std::nullopt,
|
||||
bool ensure_row_contiguous = false,
|
||||
mx::StreamOrDevice s = {}) {
|
||||
// Collect the inputs and cast them to array
|
||||
std::vector<mx::array> inputs;
|
||||
for (const auto& value : inputs_) {
|
||||
inputs.push_back(to_array(value, std::nullopt));
|
||||
}
|
||||
|
||||
// Collect the scalar inputs
|
||||
std::vector<mx::fast::ScalarArg> scalars;
|
||||
scalars.reserve(scalars_.size());
|
||||
for (const auto& v : scalars_) {
|
||||
if (nb::isinstance<bool>(v)) {
|
||||
scalars.push_back(nb::cast<bool>(v));
|
||||
} else if (nb::isinstance<int>(v)) {
|
||||
scalars.push_back(nb::cast<int>(v));
|
||||
} else if (nb::isinstance<float>(v)) {
|
||||
scalars.push_back(nb::cast<float>(v));
|
||||
} else {
|
||||
nb::object vtype = v.attr("__class__");
|
||||
std::string vtype_name =
|
||||
nb::cast<std::string>(vtype.attr("__name__"));
|
||||
std::ostringstream msg;
|
||||
msg << "[precompiled_custom_kernel] Invalid scalar argument type. "
|
||||
<< "Received " << vtype_name
|
||||
<< " but must be one of bool, int or float";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
return mx::fast::precompiled_custom_kernel(
|
||||
name,
|
||||
compiled_source,
|
||||
inputs,
|
||||
output_shapes,
|
||||
output_dtypes,
|
||||
scalars,
|
||||
grid,
|
||||
threadgroup,
|
||||
shared_memory,
|
||||
init_value,
|
||||
ensure_row_contiguous,
|
||||
s);
|
||||
},
|
||||
nb::kw_only(),
|
||||
"name"_a,
|
||||
"compiled_source"_a,
|
||||
"inputs"_a,
|
||||
"output_shapes"_a,
|
||||
"output_dtypes"_a,
|
||||
"scalars"_a,
|
||||
"grid"_a,
|
||||
"threadgroup"_a,
|
||||
"shared_memory"_a = 0,
|
||||
"init_value"_a = nb::none(),
|
||||
"ensure_row_contiguous"_a = false,
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
)pbdoc");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user