mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Simplifications for MLX C (#1396)
* simplifications for MLX C
* use vectors instead of map
* update examples
This commit is contained in:
parent
7cca1727af
commit
ba3e913c7a
@@ -19,17 +19,19 @@ Let's write a custom kernel that computes ``exp`` elementwise:
     kernel = mx.fast.metal_kernel(
         name="myexp",
+        input_names=["inp"],
+        output_names=["out"],
         source=source,
     )
     outputs = kernel(
-        inputs={"inp": a},
-        template={"T": mx.float32},
+        inputs=[a],
+        template=[("T", mx.float32)],
         grid=(a.size, 1, 1),
         threadgroup=(256, 1, 1),
-        output_shapes={"out": a.shape},
-        output_dtypes={"out": a.dtype},
+        output_shapes=[a.shape],
+        output_dtypes=[a.dtype],
     )
-    return outputs["out"]
+    return outputs[0]

 a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
 b = exp_elementwise(a)
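Taken together, the updated docs example now reads as follows. This is a sketch assembled from the hunk above; the kernel body string (``source``) is the illustrative ``exp`` body from the surrounding docs rather than part of this diff:

    import mlx.core as mx

    # Illustrative kernel body: the generated signature provides `inp` and `out`.
    source = """
        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);
    """

    def exp_elementwise(a: mx.array):
        kernel = mx.fast.metal_kernel(
            name="myexp",
            input_names=["inp"],    # kernel parameter names for the inputs
            output_names=["out"],   # kernel parameter names for the outputs
            source=source,
        )
        outputs = kernel(
            inputs=[a],                    # positional, matched to input_names
            template=[("T", mx.float32)],  # list of (name, value) pairs
            grid=(a.size, 1, 1),
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],       # positional, matched to output_names
            output_dtypes=[a.dtype],
        )
        return outputs[0]                  # outputs is a plain list

    a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
    b = exp_elementwise(a)
    assert mx.allclose(b, mx.exp(a))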
@@ -40,16 +42,16 @@ Let's write a custom kernel that computes ``exp`` elementwise:

 The full function signature will be generated using:

-* The keys and shapes/dtypes of ``inputs``
+* The shapes/dtypes of ``inputs``
   In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
   so we will add ``const device float16_t* inp`` to the signature.
   ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
   in ``source``.
-* The keys and values of ``output_shapes`` and ``output_dtypes``
+* The list of ``output_dtypes``
   In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
   so we add ``device float16_t* out``.
 * Template parameters passed using ``template``
-  In the above, ``template={"T": mx.float32}`` adds a template of ``template <typename T>`` to the function
+  In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
   and instantiates the template with ``custom_kernel_myexp_float<float>``.
   Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
 * Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
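A quick way to see exactly what these rules produce is to pass ``verbose=True`` when invoking the kernel, which prints the full generated Metal source. A minimal sketch reusing the ``myexp`` example above:

    outputs = kernel(
        inputs=[a],
        template=[("T", mx.float32)],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
        verbose=True,  # print the generated source, including the signature
    )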
@@ -104,18 +106,20 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely
     kernel = mx.fast.metal_kernel(
         name="myexp_strided",
+        input_names=["inp"],
+        output_names=["out"],
         source=source
     )
     outputs = kernel(
-        inputs={"inp": a},
-        template={"T": mx.float32},
+        inputs=[a],
+        template=[("T", mx.float32)],
         grid=(a.size, 1, 1),
         threadgroup=(256, 1, 1),
-        output_shapes={"out": a.shape},
-        output_dtypes={"out": a.dtype},
+        output_shapes=[a.shape],
+        output_dtypes=[a.dtype],
         ensure_row_contiguous=False,
     )
-    return outputs["out"]
+    return outputs[0]

 a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
 # make non-contiguous
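With ``ensure_row_contiguous=False`` the input may be non-contiguous, so the kernel body has to map a linear element index to a strided memory location itself; ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are made available for this whenever they appear in ``source``. A sketch of such a body (illustrative only, not the exact docs kernel):

    source = """
        uint elem = thread_position_in_grid.x;
        // Unravel `elem` over inp_shape and re-ravel it with inp_strides.
        uint loc = 0;
        uint idx = elem;
        for (int d = inp_ndim - 1; d >= 0; --d) {
            loc += (idx % inp_shape[d]) * inp_strides[d];
            idx /= inp_shape[d];
        }
        T tmp = inp[loc];
        // Outputs are allocated row contiguous, so `elem` indexes `out` directly.
        out[elem] = metal::exp(tmp);
    """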
@@ -243,17 +247,19 @@ First we'll implement the forward pass as a fused kernel:
     """
     kernel = mx.fast.metal_kernel(
         name="grid_sample",
+        input_names=["x", "grid"],
+        output_names=["out"],
         source=source,
     )
     outputs = kernel(
-        inputs={"x": x, "grid": grid},
-        template={"T": x.dtype},
-        output_shapes={"out": out_shape},
-        output_dtypes={"out": x.dtype},
+        inputs=[x, grid],
+        template=[("T", x.dtype)],
+        output_shapes=[out_shape],
+        output_dtypes=[x.dtype],
         grid=(np.prod(out_shape), 1, 1),
         threadgroup=(256, 1, 1),
     )
-    return outputs["out"]
+    return outputs[0]

 For a reasonably sized input such as:

@@ -389,6 +395,8 @@ We can then implement the backwards pass as follows:
     """
     kernel = mx.fast.metal_kernel(
         name="grid_sample_grad",
+        input_names=["x", "grid", "cotangent"],
+        output_names=["x_grad", "grid_grad"],
         source=source,
         atomic_outputs=True,
     )
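Since this kernel is created with ``atomic_outputs=True``, its outputs are declared as ``device atomic<T>*`` in the generated signature and must be written with atomic operations, while ``init_value=0`` (used in the next hunk) zero-fills them before launch. A small self-contained sketch of that pattern, unrelated to ``grid_sample`` but showing the same mechanics (the kernel body and names here are assumptions, not part of this diff):

    # Histogram-style kernel: several threads may accumulate into the same
    # output element, hence atomic outputs plus init_value=0.
    source = """
        uint elem = thread_position_in_grid.x;
        uint bin = inp[elem] % 4;
        atomic_fetch_add_explicit(&out[bin], T(1), memory_order_relaxed);
    """
    kernel = mx.fast.metal_kernel(
        name="count_mod4",
        input_names=["inp"],
        output_names=["out"],
        source=source,
        atomic_outputs=True,
    )
    counts = kernel(
        inputs=[mx.arange(16, dtype=mx.uint32)],
        template=[("T", mx.float32)],
        grid=(16, 1, 1),
        threadgroup=(16, 1, 1),
        output_shapes=[(4,)],
        output_dtypes=[mx.float32],
        init_value=0,        # zero-fill the atomic output before the kernel runs
    )[0]
    # Expected: counts == [4.0, 4.0, 4.0, 4.0]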
@@ -398,15 +406,15 @@ We can then implement the backwards pass as follows:
     C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
     grid_size = B * gN * gM * C_padded
     outputs = kernel(
-        inputs={"x": x, "grid": grid, "cotangent": cotangent},
-        template={"T": x.dtype},
-        output_shapes={"x_grad": x.shape, "grid_grad": grid.shape},
-        output_dtypes={"x_grad": x.dtype, "grid_grad": x.dtype},
+        inputs=[x, grid, cotangent],
+        template=[("T", x.dtype)],
+        output_shapes=[x.shape, grid.shape],
+        output_dtypes=[x.dtype, x.dtype],
         grid=(grid_size, 1, 1),
         threadgroup=(256, 1, 1),
         init_value=0,
     )
-    return outputs["x_grad"], outputs["grid_grad"]
+    return outputs[0], outputs[1]

 There's an even larger speed up for the vjp:

mlx/fast.cpp (241 lines changed)
@@ -515,7 +515,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::optional<array>& mask,
-    const std::optional<int>& memory_efficient_threshold,
+    const std::optional<int> memory_efficient_threshold,
     StreamOrDevice s) {
   for (const auto& tensor : {queries, keys, values}) {
     if (tensor.ndim() != 4) {
@@ -916,47 +916,23 @@ array affine_dequantize(
   return fallback({w, scales, biases})[0];
 }

-void validate_output_shapes(
-    std::map<std::string, std::vector<int>> output_shapes,
-    std::map<std::string, Dtype> output_dtypes) {
-  // Make sure output shapes and dtypes have the same keys
-  bool validated = true;
-  if (output_shapes.size() == 0) {
-    throw std::invalid_argument(
-        "[metal_kernel] Must specify at least one output.");
-  }
-  if (output_shapes.size() != output_dtypes.size()) {
-    validated = false;
-  } else {
-    for (const auto& kv : output_shapes) {
-      if (output_dtypes.find(kv.first) == output_dtypes.end()) {
-        validated = false;
-        break;
-      }
-    }
-  }
-  if (!validated) {
-    throw std::invalid_argument(
-        "[metal_kernel] `output_shapes` and `output_dtypes` must have the same keys.");
-  }
-}
-
 void write_signature(
     std::string func_name,
-    std::string& source,
-    std::map<std::string, array>& inputs,
-    std::map<std::string, std::vector<int>>& output_shapes,
-    std::map<std::string, Dtype>& output_dtypes,
-    std::optional<std::map<std::string, TemplateArg>> template_args,
+    const std::string& source,
+    const std::vector<std::string>& input_names,
+    const std::vector<array>& inputs,
+    const std::vector<std::string>& output_names,
+    const std::vector<Dtype>& output_dtypes,
+    const std::vector<std::pair<std::string, TemplateArg>>& template_args,
     std::vector<CustomKernelShapeInfo>& shape_infos,
     bool atomic_outputs,
     std::ostringstream& kernel_source) {
   // Auto-generate a function signature based on `template_args`
   // and the dtype/shape of the arrays passed as `inputs`.
-  if (template_args && template_args.value().size() > 0) {
+  if (!template_args.empty()) {
     kernel_source << "template <";
     int i = 0;
-    for (const auto& [name, arg] : template_args.value()) {
+    for (const auto& [name, arg] : template_args) {
       std::string param_type;
       if (std::holds_alternative<int>(arg)) {
         param_type = "int";
@@ -1008,7 +984,9 @@ void write_signature(
   int index = 0;
   constexpr int max_constant_array_size = 8;
   // Add inputs
-  for (const auto& [name, arr] : inputs) {
+  for (int i = 0; i < inputs.size(); ++i) {
+    const auto& name = input_names[i];
+    const auto& arr = inputs[i];
     auto dtype = get_type_string(arr.dtype());
     bool is_constant =
         arr.is_available() && arr.size() < max_constant_array_size;
@@ -1042,7 +1020,9 @@ void write_signature(
     shape_infos.push_back(shape_info);
   }
   // Add outputs
-  for (const auto& [name, dtype] : output_dtypes) {
+  for (int i = 0; i < output_names.size(); ++i) {
+    const auto& name = output_names[i];
+    const auto& dtype = output_dtypes[i];
     kernel_source << " device ";
     auto type_string = get_type_string(dtype);
     if (atomic_outputs) {
@@ -1051,7 +1031,7 @@ void write_signature(
       kernel_source << type_string;
     }
     kernel_source << "* " << name << " [[buffer(" << index << ")]]";
-    if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) {
+    if (index < inputs.size() + output_names.size() - 1 || attrs.size() > 0) {
       kernel_source << "," << std::endl;
     } else {
       kernel_source << ") {" << std::endl;
@@ -1073,7 +1053,8 @@ void write_signature(
   kernel_source << "}" << std::endl;
 }

-std::string write_template(std::map<std::string, TemplateArg>& template_args) {
+std::string write_template(
+    const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
   std::ostringstream template_def;
   template_def << "<";
   int i = 0;
@@ -1094,107 +1075,115 @@ std::string write_template(std::map<std::string, TemplateArg>& template_args) {
   return template_def.str();
 }

-std::map<std::string, array> MetalKernel::operator()(
-    std::map<std::string, array>& inputs,
-    std::map<std::string, std::vector<int>> output_shapes,
-    std::map<std::string, Dtype> output_dtypes,
-    std::tuple<int, int, int> grid,
-    std::tuple<int, int, int> threadgroup,
-    std::optional<std::map<std::string, TemplateArg>> template_args,
-    std::optional<float> init_value,
-    bool verbose,
-    StreamOrDevice s_) {
-  validate_output_shapes(output_shapes, output_dtypes);
-
-  auto s = to_stream(s_);
-  if (s.device != Device::gpu) {
-    throw std::invalid_argument(
-        "[metal_kernel] MetalKernel only works on GPU.");
-  }
-
-  std::ostringstream func_name;
-
-  std::string template_def = "";
-  bool needs_template = template_args && template_args.value().size() > 0;
-  std::string hash_key = "";
-  if (needs_template) {
-    std::regex disallowed_chars("\\<|\\>|(, )");
-    template_def = write_template(template_args.value());
-    hash_key = std::regex_replace(template_def, disallowed_chars, "_");
-    hash_key.pop_back();
-  }
-
-  func_name << "custom_kernel_" << name_ << hash_key;
-  std::string kernel_name = func_name.str();
-
-  std::ostringstream kernel_source;
-  kernel_source << header_ << std::endl;
-
-  std::vector<CustomKernelShapeInfo> shape_infos;
-  write_signature(
-      func_name.str(),
-      source_,
-      inputs,
-      output_shapes,
-      output_dtypes,
-      template_args,
-      shape_infos,
-      atomic_outputs_,
-      kernel_source);
-
-  if (needs_template) {
-    template_def = func_name.str() + template_def;
-    kernel_source << std::endl
-                  << "template [[host_name(\"" << kernel_name
-                  << "\")]] [[kernel]] decltype(" << template_def << ") "
-                  << template_def << ";" << std::endl;
-  }
-
-  if (verbose) {
-    std::cout << "Generated source code for `" << name_ << "`:" << std::endl
-              << "```" << std::endl
-              << kernel_source.str() << std::endl
-              << "```" << std::endl;
-  }
-
-  std::vector<array> in_arrs;
-  for (const auto& kv : inputs) {
-    in_arrs.push_back(kv.second);
-  }
-
-  std::vector<std::string> out_keys;
-  std::vector<std::vector<int>> out_shapes;
-  for (const auto& [name, shape] : output_shapes) {
-    out_keys.push_back(name);
-    out_shapes.push_back(shape);
-  }
-
-  std::vector<Dtype> out_dtypes;
-  for (const auto& kv : output_dtypes) {
-    out_dtypes.push_back(kv.second);
-  }
-
-  std::map<std::string, array> outputs;
-  auto outputs_vec = array::make_arrays(
-      out_shapes,
-      out_dtypes,
-      std::make_shared<CustomKernel>(
-          s,
-          kernel_name,
-          kernel_source.str(),
-          grid,
-          threadgroup,
-          shape_infos,
-          ensure_row_contiguous_,
-          init_value),
-      in_arrs);
-
-  int i = 0;
-  for (const auto& key : out_keys) {
-    outputs.insert({key, outputs_vec[i]});
-    i++;
-  }
-  return outputs;
+MetalKernelFunction metal_kernel(
+    const std::string& name,
+    const std::vector<std::string>& input_names,
+    const std::vector<std::string>& output_names,
+    const std::string& source,
+    const std::string& header /* = "" */,
+    bool ensure_row_contiguous /* = true */,
+    bool atomic_outputs /* = false */) {
+  if (output_names.empty()) {
+    throw std::invalid_argument(
+        "[metal_kernel] Must specify at least one output.");
+  }
+
+  return [=](const std::vector<array>& inputs,
+             const std::vector<std::vector<int>>& output_shapes,
+             const std::vector<Dtype>& output_dtypes,
+             std::tuple<int, int, int> grid,
+             std::tuple<int, int, int> threadgroup,
+             const std::vector<std::pair<std::string, TemplateArg>>&
+                 template_args = {},
+             std::optional<float> init_value = std::nullopt,
+             bool verbose = false,
+             StreamOrDevice s_ = {}) {
+    if (inputs.size() != input_names.size()) {
+      std::ostringstream msg;
+      msg << "[metal_kernel] Expected `inputs` to have size "
+          << input_names.size() << " but got size " << inputs.size() << "."
+          << std::endl;
+      throw std::invalid_argument(msg.str());
+    }
+    if (output_shapes.size() != output_names.size()) {
+      std::ostringstream msg;
+      msg << "[metal_kernel] Expected `output_shapes` to have size "
+          << output_names.size() << " but got size " << output_shapes.size()
+          << "." << std::endl;
+      throw std::invalid_argument(msg.str());
+    }
+    if (output_dtypes.size() != output_names.size()) {
+      std::ostringstream msg;
+      msg << "[metal_kernel] Expected `output_dtypes` to have size "
+          << output_names.size() << " but got size " << output_dtypes.size()
+          << "." << std::endl;
+      throw std::invalid_argument(msg.str());
+    }
+
+    auto s = to_stream(s_);
+    if (s.device != Device::gpu) {
+      throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
+    }
+
+    std::ostringstream func_name;
+
+    std::string template_def = "";
+    std::string hash_key = "";
+    if (!template_args.empty()) {
+      std::regex disallowed_chars("\\<|\\>|(, )");
+      template_def = write_template(template_args);
+      hash_key = std::regex_replace(template_def, disallowed_chars, "_");
+      hash_key.pop_back();
+    }
+
+    func_name << "custom_kernel_" << name << hash_key;
+    std::string kernel_name = func_name.str();
+
+    std::ostringstream kernel_source;
+    kernel_source << header << std::endl;
+
+    std::vector<CustomKernelShapeInfo> shape_infos;
+    write_signature(
+        func_name.str(),
+        source,
+        input_names,
+        inputs,
+        output_names,
+        output_dtypes,
+        template_args,
+        shape_infos,
+        atomic_outputs,
+        kernel_source);
+
+    if (!template_args.empty()) {
+      template_def = func_name.str() + template_def;
+      kernel_source << std::endl
+                    << "template [[host_name(\"" << kernel_name
+                    << "\")]] [[kernel]] decltype(" << template_def << ") "
+                    << template_def << ";" << std::endl;
+    }
+
+    if (verbose) {
+      std::cout << "Generated source code for `" << name << "`:" << std::endl
+                << "```" << std::endl
+                << kernel_source.str() << std::endl
+                << "```" << std::endl;
+    }
+
+    return array::make_arrays(
+        output_shapes,
+        output_dtypes,
+        std::make_shared<CustomKernel>(
+            s,
+            kernel_name,
+            kernel_source.str(),
+            grid,
+            threadgroup,
+            shape_infos,
+            ensure_row_contiguous,
+            init_value),
+        inputs);
+  };
 }

 } // namespace mlx::core::fast
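On the Python side these new length checks surface when the lists passed at call time do not line up with ``input_names``/``output_names``. A sketch (assuming, as is usual for nanobind, that ``std::invalid_argument`` maps to ``ValueError``):

    kernel = mx.fast.metal_kernel(
        name="myexp",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )
    try:
        kernel(
            inputs=[a, a],  # two inputs, but input_names lists only one
            template=[("T", mx.float32)],
            grid=(a.size, 1, 1),
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
        )
    except ValueError as e:
        print(e)  # [metal_kernel] Expected `inputs` to have size 1 but got size 2.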
mlx/fast.h (53 lines changed)
@@ -2,7 +2,6 @@

 #pragma once

-#include <map>
 #include <optional>

 #include "mlx/utils.h"
@@ -39,7 +38,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::optional<array>& mask = std::nullopt,
-    const std::optional<int>& memory_efficient_threshold = std::nullopt,
+    const std::optional<int> memory_efficient_threshold = std::nullopt,
     StreamOrDevice s = {});

 std::tuple<array, array, array> affine_quantize(
@@ -66,37 +65,25 @@ array affine_dequantize(

 typedef std::variant<int, bool, Dtype> TemplateArg;

-class MetalKernel {
- public:
-  MetalKernel(
-      const std::string& name,
-      const std::string& source,
-      const std::string& header = "",
-      bool ensure_row_contiguous = true,
-      bool atomic_outputs = false)
-      : name_(name),
-        source_(source),
-        header_(header),
-        ensure_row_contiguous_(ensure_row_contiguous),
-        atomic_outputs_(atomic_outputs) {}
-
-  std::map<std::string, array> operator()(
-      std::map<std::string, array>& inputs,
-      std::map<std::string, std::vector<int>> output_shapes,
-      std::map<std::string, Dtype> output_dtypes,
-      std::tuple<int, int, int> grid,
-      std::tuple<int, int, int> threadgroup,
-      std::optional<std::map<std::string, TemplateArg>> template_args =
-          std::nullopt,
-      std::optional<float> init_value = std::nullopt,
-      bool verbose = false,
-      StreamOrDevice s = {});
-
- private:
-  std::string name_;
-  std::string source_;
-  std::string header_;
-  bool ensure_row_contiguous_;
-  bool atomic_outputs_;
-};
+typedef std::function<std::vector<array>(
+    const std::vector<array>&,
+    const std::vector<std::vector<int>>&,
+    const std::vector<Dtype>&,
+    std::tuple<int, int, int>,
+    std::tuple<int, int, int>,
+    std::vector<std::pair<std::string, TemplateArg>>,
+    std::optional<float>,
+    bool,
+    StreamOrDevice)>
+    MetalKernelFunction;
+
+MetalKernelFunction metal_kernel(
+    const std::string& name,
+    const std::vector<std::string>& input_names,
+    const std::vector<std::string>& output_names,
+    const std::string& source,
+    const std::string& header = "",
+    bool ensure_row_contiguous = true,
+    bool atomic_outputs = false);

 } // namespace mlx::core::fast
@@ -1425,8 +1425,8 @@ array where(
 array nan_to_num(
     const array& a,
     float nan /* = 0.0f */,
-    const std::optional<float>& posinf_ /* = std::nullopt */,
-    const std::optional<float>& neginf_ /* = std::nullopt */,
+    const std::optional<float> posinf_ /* = std::nullopt */,
+    const std::optional<float> neginf_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   Dtype dtype = a.dtype();
   if (!issubdtype(dtype, inexact)) {
@@ -416,8 +416,8 @@ array where(
 array nan_to_num(
     const array& a,
     float nan = 0.0f,
-    const std::optional<float>& posinf = std::nullopt,
-    const std::optional<float>& neginf = std::nullopt,
+    const std::optional<float> posinf = std::nullopt,
+    const std::optional<float> neginf = std::nullopt,
     StreamOrDevice s = {});

 /** True if all elements in the array are true (or non-zero). **/
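These ``nan_to_num`` hunks only change how the optional clamp values are passed internally (by value instead of by const reference); the user-facing behaviour is unchanged. For reference, a typical call from Python (a sketch, assuming the usual keyword names from the declaration above):

    x = mx.array([float("nan"), float("inf"), -float("inf"), 1.0])
    y = mx.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6)
    # y == [0.0, 1e6, -1e6, 1.0]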
@@ -1,8 +1,8 @@
 // Copyright © 2023-2024 Apple Inc.

 #include <nanobind/nanobind.h>
-#include <nanobind/stl/map.h>
 #include <nanobind/stl/optional.h>
+#include <nanobind/stl/pair.h>
 #include <nanobind/stl/string.h>
 #include <nanobind/stl/tuple.h>
 #include <nanobind/stl/variant.h>
@@ -193,39 +193,130 @@ void init_fast(nb::module_& parent_module) {
         array: The quantized version of ``w``
       )pbdoc");

-  nb::class_<fast::MetalKernel>(
-      m,
+  m.def(
       "metal_kernel",
+      [](const std::string& name,
+         const std::vector<std::string>& input_names,
+         const std::vector<std::string>& output_names,
+         const std::string& source,
+         const std::string& header,
+         bool ensure_row_contiguous,
+         bool atomic_outputs) {
+        auto kernel = fast::metal_kernel(
+            name,
+            input_names,
+            output_names,
+            source,
+            header,
+            ensure_row_contiguous,
+            atomic_outputs);
+        return nb::cpp_function(
+            [kernel = std::move(kernel)](
+                const std::vector<ScalarOrArray>& inputs_,
+                const std::vector<std::vector<int>>& output_shapes,
+                const std::vector<Dtype>& output_dtypes,
+                std::tuple<int, int, int> grid,
+                std::tuple<int, int, int> threadgroup,
+                const std::optional<
+                    std::vector<std::pair<std::string, nb::object>>>&
+                    template_args_ = std::nullopt,
+                std::optional<float> init_value = std::nullopt,
+                bool verbose = false,
+                StreamOrDevice s = {}) {
+              std::vector<array> inputs;
+              for (const auto& value : inputs_) {
+                inputs.push_back(to_array(value, std::nullopt));
+              }
+              std::vector<std::pair<std::string, fast::TemplateArg>>
+                  template_args;
+              if (template_args_) {
+                for (const auto& [name, value] : template_args_.value()) {
+                  // Handle bool, int and dtype template args
+                  if (nb::isinstance<bool>(value)) {
+                    bool bool_val = nb::cast<bool>(value);
+                    template_args.emplace_back(name, bool_val);
+                  } else if (nb::isinstance<int>(value)) {
+                    int int_val = nb::cast<int>(value);
+                    template_args.emplace_back(name, int_val);
+                  } else if (nb::isinstance<Dtype>(value)) {
+                    Dtype dtype = nb::cast<Dtype>(value);
+                    template_args.emplace_back(name, dtype);
+                  } else {
+                    throw std::invalid_argument(
+                        "[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
+                  }
+                }
+              }
+              return kernel(
+                  inputs,
+                  output_shapes,
+                  output_dtypes,
+                  grid,
+                  threadgroup,
+                  template_args,
+                  init_value,
+                  verbose,
+                  s);
+            },
+            nb::kw_only(),
+            "inputs"_a,
+            "output_shapes"_a,
+            "output_dtypes"_a,
+            "grid"_a,
+            "threadgroup"_a,
+            "template"_a = nb::none(),
+            "init_value"_a = nb::none(),
+            "verbose"_a = false,
+            "stream"_a = nb::none(),
+            nb::sig(
+                "def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
+            R"pbdoc(
+              Run the kernel.
+
+              Args:
+                inputs (List[array]): The inputs passed to the Metal kernel.
+                output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
+                output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
+                grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
+                threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
+                template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
+                  These will be added as template arguments to the kernel definition. Default: ``None``.
+                init_value (float, optional): Optional value to use to initialize all of the output arrays.
+                  By default, output arrays are uninitialized. Default: ``None``.
+                verbose (bool, optional): Whether to print the full generated source code of the kernel
+                  when it is run. Default: ``False``.
+                stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
+
+              Returns:
+                List[array]: The list of output arrays.
+            )pbdoc");
+      },
+      "name"_a,
+      "input_names"_a,
+      "output_names"_a,
+      "source"_a,
+      "header"_a = "",
+      "ensure_row_contiguous"_a = true,
+      "atomic_outputs"_a = false,
       R"pbdoc(
         A jit-compiled custom Metal kernel defined from a source string.
-      )pbdoc")
-      .def(
-          nb::init<
-              const std::string&,
-              const std::string&,
-              const std::string&,
-              bool,
-              bool>(),
-          "name"_a,
-          "source"_a,
-          "header"_a = "",
-          "ensure_row_contiguous"_a = true,
-          "atomic_outputs"_a = false,
-          R"pbdoc(
-            Initialize a metal_kernel.

         Args:
           name (str): Name for the kernel.
-          source (str): Source code. This is the body of a function in Metal,
-            the function signature will be generated for you. The names of the inputs/outputs
-            are determined by the ``inputs`` and ``output_shapes``/``output_dtypes``
-            used when the kernel is called.
-          header (str): Header source code to include before the main function.
-            Useful for helper functions or includes that should live outside of the main function body.
-          ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
-            before the kernel runs. Default: ``True``.
-          atomic_outputs (bool): Whether to use atomic outputs in the function signature
-            e.g. ``device atomic<float>``. Default: ``False``.
+          input_names (List[str]): The parameter names of the inputs in the
+            function signature.
+          output_names (List[str]): The parameter names of the outputs in the
+            function signature.
+          source (str): Source code. This is the body of a function in Metal,
+            the function signature will be automatically generated.
+          header (str): Header source code to include before the main function.
+            Useful for helper functions or includes that should live outside of
+            the main function body.
+          ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
+            before the kernel runs. Default: ``True``.
+          atomic_outputs (bool): Whether to use atomic outputs in the function signature
+            e.g. ``device atomic<float>``. Default: ``False``.

         Returns:
           Callable ``metal_kernel``.

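Because the returned ``nb::cpp_function`` is declared with ``nb::kw_only()``, every call-time argument of the kernel must be passed by keyword from Python. A sketch:

    outputs = kernel(
        inputs=[a],
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
    )
    # Positional calls such as kernel([a], [a.shape], ...) would raise a TypeError.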
@@ -242,103 +333,23 @@ void init_fast(nb::module_& parent_module) {

         kernel = mx.fast.metal_kernel(
             name="myexp",
+            input_names=["inp"],
+            output_names=["out"],
             source=source
         )
         outputs = kernel(
-            inputs={"inp": a},
-            template={"T": mx.float32},
+            inputs=[a],
+            template=[("T", mx.float32)],
             grid=(a.size, 1, 1),
             threadgroup=(256, 1, 1),
-            output_shapes={"out": a.shape},
-            output_dtypes={"out": a.dtype},
+            output_shapes=[a.shape],
+            output_dtypes=[a.dtype],
             verbose=True,
         )
-        return outputs["out"]
+        return outputs[0]

         a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
         b = exp_elementwise(a)
         assert mx.allclose(b, mx.exp(a))
-      )pbdoc")
-      .def(
-          "__call__",
-          [](fast::MetalKernel& kernel,
-             std::map<std::string, ScalarOrArray>& inputs_,
-             std::map<std::string, std::vector<int>>& output_shapes,
-             std::map<std::string, Dtype>& output_dtypes,
-             std::tuple<int, int, int> grid,
-             std::tuple<int, int, int> threadgroup,
-             std::optional<std::map<std::string, nb::handle>> template_args_,
-             std::optional<float> init_value,
-             bool verbose,
-             StreamOrDevice s) {
-            std::map<std::string, array> inputs;
-            for (const auto& [name, value] : inputs_) {
-              auto arr = to_array(value, std::nullopt);
-              inputs.insert({name, arr});
-            }
-            std::map<std::string, fast::TemplateArg> template_args;
-            if (template_args_) {
-              for (const auto& [name, value] : template_args_.value()) {
-                // Handle bool, int and dtype template args
-                if (nb::isinstance<bool>(value)) {
-                  bool bool_val = nb::cast<bool>(value);
-                  template_args.insert({name, bool_val});
-                } else if (nb::isinstance<int>(value)) {
-                  int int_val = nb::cast<int>(value);
-                  template_args.insert({name, int_val});
-                } else if (nb::isinstance<Dtype>(value)) {
-                  Dtype dtype = nb::cast<Dtype>(value);
-                  template_args.insert({name, dtype});
-                } else {
-                  throw std::invalid_argument(
-                      "[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
-                }
-              }
-            }
-            return kernel(
-                inputs,
-                output_shapes,
-                output_dtypes,
-                grid,
-                threadgroup,
-                template_args,
-                init_value,
-                verbose,
-                s);
-          },
-          nb::kw_only(),
-          "inputs"_a,
-          "output_shapes"_a,
-          "output_dtypes"_a,
-          "grid"_a,
-          "threadgroup"_a,
-          "template"_a = nb::none(),
-          "init_value"_a = nb::none(),
-          "verbose"_a = false,
-          "stream"_a = nb::none(),
-          nb::sig(
-              "def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
-          R"pbdoc(
-            Run the kernel.
-
-            Args:
-              inputs (Mapping[str, array]): Inputs. These will be added to the function signature and passed to the Metal kernel.
-                The keys will be the names of the arguments to the kernel.
-              output_shapes (Mapping[str, Sequence[int]]): Output shapes. A dict mapping
-                output variable names to shapes. These will be added to the function signature.
-              output_dtypes (Mapping[str, Dtype]): Output dtypes. A dict mapping output variable
-                names to dtypes. Must have the same keys as ``output_shapes``.
-              grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
-              threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
-              template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments.
-                These will be added as template arguments to the kernel definition. Default: ``None``.
-              init_value (float, optional): Optional value to use to initialize all of the output arrays.
-                By default, output arrays are uninitialized. Default: ``None``.
-              verbose (bool, optional): Whether to print the full generated source code of the kernel
-                when it is run. Default: ``False``.
-              stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
-
-            Returns:
-              dict[str, array]: Dictionary of output arrays based on ``output_shapes``/``output_dtypes``.
-          )pbdoc");
+      )pbdoc");
 }
@@ -562,20 +562,22 @@ class TestFast(mlx_tests.MLXTestCase):
         a = mx.random.normal(shape=(2, 2))
         kernel = mx.fast.metal_kernel(
             name="basic",
+            input_names=["a"],
+            output_names=["out1"],
             source="""
                 uint elem = thread_position_in_grid.x;
                 out1[elem] = a[elem];
            """,
         )
         out = kernel(
-            inputs={"a": a},
+            inputs=[a],
             grid=(4, 1, 1),
             threadgroup=(2, 1, 1),
-            output_shapes={"out1": (2, 2)},
-            output_dtypes={"out1": mx.float32},
+            output_shapes=[(2, 2)],
+            output_dtypes=[mx.float32],
             stream=mx.gpu,
         )
-        self.assertTrue(mx.allclose(out["out1"], a))
+        self.assertTrue(mx.allclose(out[0], a))

     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_args(self):
@@ -585,6 +587,8 @@ class TestFast(mlx_tests.MLXTestCase):

         kernel = mx.fast.metal_kernel(
             name="arg_test",
+            input_names=["a", "b", "c", "d"],
+            output_names=["out1", "out2"],
             source="""
                 uint elem = thread_position_in_grid.x;
                 T tmp = a[0];
@@ -597,26 +601,26 @@ class TestFast(mlx_tests.MLXTestCase):
            """,
         )
         out = kernel(
-            inputs={
-                "a": a,
-                "b": mx.array([3, 4, 5]),
-                "c": c,
-                "d": 7.3,
-            },
-            template={
-                "e": True,
-                "f": 3,
-                "T": mx.float16,
-            },
+            inputs=[
+                a,
+                mx.array([3, 4, 5]),
+                c,
+                7.3,
+            ],
+            template=[
+                ("e", True),
+                ("f", 3),
+                ("T", mx.float16),
+            ],
             grid=(6, 1, 1),
             threadgroup=(2, 1, 1),
-            output_shapes={"out1": (2, 2), "out2": (3, 2)},
-            output_dtypes={"out1": mx.float32, "out2": mx.int32},
+            output_shapes=[(2, 2), (3, 2)],
+            output_dtypes=[mx.float32, mx.int32],
             stream=mx.gpu,
         )

-        self.assertTrue(mx.allclose(out["out1"], mx.full((2, 2), 14.0484)))
-        self.assertTrue(mx.allclose(out["out2"], mx.full((3, 2), -2, dtype=mx.int32)))
+        self.assertTrue(mx.allclose(out[0], mx.full((2, 2), 14.0484)))
+        self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))

     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_strides(self):
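As this test exercises, the entries of ``inputs`` may be scalars as well as arrays; on the binding side each entry is converted with ``to_array`` before dispatch, so a literal such as ``7.3`` becomes an array parameter named after the corresponding entry of ``input_names``. A trimmed sketch of just that aspect:

    out = kernel(
        inputs=[a, mx.array([3, 4, 5]), c, 7.3],  # the scalar is converted to an array for "d"
        template=[("e", True), ("f", 3), ("T", mx.float16)],
        grid=(6, 1, 1),
        threadgroup=(2, 1, 1),
        output_shapes=[(2, 2), (3, 2)],
        output_dtypes=[mx.float32, mx.int32],
        stream=mx.gpu,
    )
    out1, out2 = out  # returned in output_names order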
@@ -640,19 +644,21 @@ class TestFast(mlx_tests.MLXTestCase):
         for contig in [True, False]:
             kernel = mx.fast.metal_kernel(
                 name="myexp" + str(contig),
+                input_names=["inp"],
+                output_names=["out"],
                 source=source_contig if contig else source,
                 ensure_row_contiguous=contig,
             )
             outputs = kernel(
-                inputs={"inp": a},
-                template={"T": mx.float32},
+                inputs=[a],
+                template=[("T", mx.float32)],
                 grid=(a.size, 1, 1),
                 threadgroup=(256, 1, 1),
-                output_shapes={"out": a.shape},
-                output_dtypes={"out": a.dtype},
+                output_shapes=[a.shape],
+                output_dtypes=[a.dtype],
                 stream=mx.gpu,
             )
-            self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"]))
+            self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))

     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_helper(self):
@@ -660,6 +666,8 @@ class TestFast(mlx_tests.MLXTestCase):
         a = mx.random.normal(shape=(2, 2))
         kernel = mx.fast.metal_kernel(
             name="helper",
+            input_names=["a"],
+            output_names=["out1"],
             header="""
                 template <typename T>
                 T do_exp(T x) {
@@ -672,14 +680,14 @@ class TestFast(mlx_tests.MLXTestCase):
            """,
         )
         out = kernel(
-            inputs={"a": a},
+            inputs=[a],
             grid=(4, 1, 1),
             threadgroup=(2, 1, 1),
-            output_shapes={"out1": (2, 2)},
-            output_dtypes={"out1": mx.float32},
+            output_shapes=[(2, 2)],
+            output_dtypes=[mx.float32],
             stream=mx.gpu,
         )
-        self.assertTrue(mx.allclose(out["out1"], mx.exp(a)))
+        self.assertTrue(mx.allclose(out[0], mx.exp(a)))


 if __name__ == "__main__":