Working custom kernels jointly

This commit is contained in:
Angelos Katharopoulos
2025-08-12 14:30:29 -07:00
parent 0b309e8edc
commit 3b94e37270
8 changed files with 425 additions and 15 deletions

View File

@@ -384,4 +384,76 @@ void init_fast(nb::module_& parent_module) {
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
)pbdoc");
m.def(
"precompiled_custom_kernel",
[](const std::string& name,
const std::string& compiled_source,
const std::vector<ScalarOrArray>& inputs_,
const std::vector<mx::Shape>& output_shapes,
const std::vector<mx::Dtype>& output_dtypes,
const std::vector<nb::object>& scalars_,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
int shared_memory,
std::optional<float> init_value = std::nullopt,
bool ensure_row_contiguous = false,
mx::StreamOrDevice s = {}) {
// Collect the inputs and cast them to array
std::vector<mx::array> inputs;
for (const auto& value : inputs_) {
inputs.push_back(to_array(value, std::nullopt));
}
// Collect the scalar inputs
std::vector<mx::fast::ScalarArg> scalars;
scalars.reserve(scalars_.size());
for (const auto& v : scalars_) {
if (nb::isinstance<bool>(v)) {
scalars.push_back(nb::cast<bool>(v));
} else if (nb::isinstance<int>(v)) {
scalars.push_back(nb::cast<int>(v));
} else if (nb::isinstance<float>(v)) {
scalars.push_back(nb::cast<float>(v));
} else {
nb::object vtype = v.attr("__class__");
std::string vtype_name =
nb::cast<std::string>(vtype.attr("__name__"));
std::ostringstream msg;
msg << "[precompiled_custom_kernel] Invalid scalar argument type. "
<< "Received " << vtype_name
<< " but must be one of bool, int or float";
throw std::invalid_argument(msg.str());
}
}
return mx::fast::precompiled_custom_kernel(
name,
compiled_source,
inputs,
output_shapes,
output_dtypes,
scalars,
grid,
threadgroup,
shared_memory,
init_value,
ensure_row_contiguous,
s);
},
nb::kw_only(),
"name"_a,
"compiled_source"_a,
"inputs"_a,
"output_shapes"_a,
"output_dtypes"_a,
"scalars"_a,
"grid"_a,
"threadgroup"_a,
"shared_memory"_a = 0,
"init_value"_a = nb::none(),
"ensure_row_contiguous"_a = false,
"stream"_a = nb::none(),
R"pbdoc(
)pbdoc");
}