Simplifications for MLX C (#1396)

* simplifications for MLX C
* use vectors instead of map
* update examples
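For orientation, a minimal sketch of the call-site after this change. The signatures are taken from the diffs below; the kernel name, buffer names, shapes, and the Metal source body are placeholder assumptions, not from this commit:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Placeholder Metal body; a real kernel body would go here.
  std::string source =
      "out[thread_position_in_grid.x] = inp[thread_position_in_grid.x];";

  // Build the kernel handle once; inputs/outputs are named here...
  auto kernel = fast::metal_kernel(
      "copy_elementwise", // name (placeholder)
      {"inp"},            // input_names
      {"out"},            // output_names
      source);

  // ...and passed positionally afterwards, aligned with those names,
  // replacing the old string-keyed std::map arguments.
  array x = ones({1024});
  auto outputs = kernel(
      {x},          // inputs
      {{1024}},     // output_shapes
      {float32},    // output_dtypes
      {1024, 1, 1}, // grid
      {256, 1, 1},  // threadgroup
      {},           // template_args
      std::nullopt, // init_value
      false,        // verbose
      Device(Device::gpu));
  // `outputs` is a std::vector<array> ordered like `output_names`,
  // replacing the old std::map<std::string, array> return value.
  eval(outputs);
  return 0;
}

Note that because `MetalKernelFunction` is a `std::function`, the default arguments on the returned lambda are not visible through it, so the trailing arguments are spelled out above.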
mlx/fast.cpp (241 lines changed)
@@ -515,7 +515,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::optional<array>& mask,
-    const std::optional<int>& memory_efficient_threshold,
+    const std::optional<int> memory_efficient_threshold,
     StreamOrDevice s) {
   for (const auto& tensor : {queries, keys, values}) {
     if (tensor.ndim() != 4) {
@@ -916,47 +916,23 @@ array affine_dequantize(
   return fallback({w, scales, biases})[0];
 }
 
-void validate_output_shapes(
-    std::map<std::string, std::vector<int>> output_shapes,
-    std::map<std::string, Dtype> output_dtypes) {
-  // Make sure output shapes and dtypes have the same keys
-  bool validated = true;
-  if (output_shapes.size() == 0) {
-    throw std::invalid_argument(
-        "[metal_kernel] Must specify at least one output.");
-  }
-  if (output_shapes.size() != output_dtypes.size()) {
-    validated = false;
-  } else {
-    for (const auto& kv : output_shapes) {
-      if (output_dtypes.find(kv.first) == output_dtypes.end()) {
-        validated = false;
-        break;
-      }
-    }
-  }
-  if (!validated) {
-    throw std::invalid_argument(
-        "[metal_kernel] `output_shapes` and `output_dtypes` must have the same keys.");
-  }
-}
-
 void write_signature(
     std::string func_name,
-    std::string& source,
-    std::map<std::string, array>& inputs,
-    std::map<std::string, std::vector<int>>& output_shapes,
-    std::map<std::string, Dtype>& output_dtypes,
-    std::optional<std::map<std::string, TemplateArg>> template_args,
+    const std::string& source,
+    const std::vector<std::string>& input_names,
+    const std::vector<array>& inputs,
+    const std::vector<std::string>& output_names,
+    const std::vector<Dtype>& output_dtypes,
+    const std::vector<std::pair<std::string, TemplateArg>>& template_args,
     std::vector<CustomKernelShapeInfo>& shape_infos,
     bool atomic_outputs,
     std::ostringstream& kernel_source) {
   // Auto-generate a function signature based on `template_args`
   // and the dtype/shape of the arrays passed as `inputs`.
-  if (template_args && template_args.value().size() > 0) {
+  if (!template_args.empty()) {
     kernel_source << "template <";
     int i = 0;
-    for (const auto& [name, arg] : template_args.value()) {
+    for (const auto& [name, arg] : template_args) {
       std::string param_type;
       if (std::holds_alternative<int>(arg)) {
         param_type = "int";
@@ -1008,7 +984,9 @@ void write_signature(
   int index = 0;
   constexpr int max_constant_array_size = 8;
   // Add inputs
-  for (const auto& [name, arr] : inputs) {
+  for (int i = 0; i < inputs.size(); ++i) {
+    const auto& name = input_names[i];
+    const auto& arr = inputs[i];
     auto dtype = get_type_string(arr.dtype());
     bool is_constant =
         arr.is_available() && arr.size() < max_constant_array_size;
@@ -1042,7 +1020,9 @@ void write_signature(
     shape_infos.push_back(shape_info);
   }
   // Add outputs
-  for (const auto& [name, dtype] : output_dtypes) {
+  for (int i = 0; i < output_names.size(); ++i) {
+    const auto& name = output_names[i];
+    const auto& dtype = output_dtypes[i];
     kernel_source << " device ";
     auto type_string = get_type_string(dtype);
     if (atomic_outputs) {
@@ -1051,7 +1031,7 @@ void write_signature(
       kernel_source << type_string;
     }
     kernel_source << "* " << name << " [[buffer(" << index << ")]]";
-    if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) {
+    if (index < inputs.size() + output_names.size() - 1 || attrs.size() > 0) {
       kernel_source << "," << std::endl;
     } else {
       kernel_source << ") {" << std::endl;
@@ -1073,7 +1053,8 @@ void write_signature(
   kernel_source << "}" << std::endl;
 }
 
-std::string write_template(std::map<std::string, TemplateArg>& template_args) {
+std::string write_template(
+    const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
   std::ostringstream template_def;
   template_def << "<";
   int i = 0;
@@ -1094,107 +1075,115 @@ std::string write_template(std::map<std::string, TemplateArg>& template_args) {
   return template_def.str();
 }
 
-std::map<std::string, array> MetalKernel::operator()(
-    std::map<std::string, array>& inputs,
-    std::map<std::string, std::vector<int>> output_shapes,
-    std::map<std::string, Dtype> output_dtypes,
-    std::tuple<int, int, int> grid,
-    std::tuple<int, int, int> threadgroup,
-    std::optional<std::map<std::string, TemplateArg>> template_args,
-    std::optional<float> init_value,
-    bool verbose,
-    StreamOrDevice s_) {
-  validate_output_shapes(output_shapes, output_dtypes);
-
-  auto s = to_stream(s_);
-  if (s.device != Device::gpu) {
+MetalKernelFunction metal_kernel(
+    const std::string& name,
+    const std::vector<std::string>& input_names,
+    const std::vector<std::string>& output_names,
+    const std::string& source,
+    const std::string& header /* = "" */,
+    bool ensure_row_contiguous /* = true */,
+    bool atomic_outputs /* = false */) {
+  if (output_names.empty()) {
     throw std::invalid_argument(
-        "[metal_kernel] MetalKernel only works on GPU.");
+        "[metal_kernel] Must specify at least one output.");
   }
 
-  std::ostringstream func_name;
+  return [=](const std::vector<array>& inputs,
+             const std::vector<std::vector<int>>& output_shapes,
+             const std::vector<Dtype>& output_dtypes,
+             std::tuple<int, int, int> grid,
+             std::tuple<int, int, int> threadgroup,
+             const std::vector<std::pair<std::string, TemplateArg>>&
+                 template_args = {},
+             std::optional<float> init_value = std::nullopt,
+             bool verbose = false,
+             StreamOrDevice s_ = {}) {
+    if (inputs.size() != input_names.size()) {
+      std::ostringstream msg;
+      msg << "[metal_kernel] Expected `inputs` to have size "
+          << input_names.size() << " but got size " << inputs.size() << "."
+          << std::endl;
+      throw std::invalid_argument(msg.str());
+    }
+    if (output_shapes.size() != output_names.size()) {
+      std::ostringstream msg;
+      msg << "[metal_kernel] Expected `output_shapes` to have size "
+          << output_names.size() << " but got size " << output_shapes.size()
+          << "." << std::endl;
+      throw std::invalid_argument(msg.str());
+    }
+    if (output_dtypes.size() != output_names.size()) {
+      std::ostringstream msg;
+      msg << "[metal_kernel] Expected `output_dtypes` to have size "
+          << output_names.size() << " but got size " << output_dtypes.size()
+          << "." << std::endl;
+      throw std::invalid_argument(msg.str());
+    }
 
-  std::string template_def = "";
-  bool needs_template = template_args && template_args.value().size() > 0;
-  std::string hash_key = "";
-  if (needs_template) {
-    std::regex disallowed_chars("\\<|\\>|(, )");
-    template_def = write_template(template_args.value());
-    hash_key = std::regex_replace(template_def, disallowed_chars, "_");
-    hash_key.pop_back();
-  }
+    auto s = to_stream(s_);
+    if (s.device != Device::gpu) {
+      throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
+    }
 
-  func_name << "custom_kernel_" << name_ << hash_key;
-  std::string kernel_name = func_name.str();
+    std::ostringstream func_name;
 
-  std::ostringstream kernel_source;
-  kernel_source << header_ << std::endl;
+    std::string template_def = "";
+    std::string hash_key = "";
+    if (!template_args.empty()) {
+      std::regex disallowed_chars("\\<|\\>|(, )");
+      template_def = write_template(template_args);
+      hash_key = std::regex_replace(template_def, disallowed_chars, "_");
+      hash_key.pop_back();
+    }
 
-  std::vector<CustomKernelShapeInfo> shape_infos;
-  write_signature(
-      func_name.str(),
-      source_,
-      inputs,
-      output_shapes,
-      output_dtypes,
-      template_args,
-      shape_infos,
-      atomic_outputs_,
-      kernel_source);
+    func_name << "custom_kernel_" << name << hash_key;
+    std::string kernel_name = func_name.str();
 
-  if (needs_template) {
-    template_def = func_name.str() + template_def;
-    kernel_source << std::endl
-                  << "template [[host_name(\"" << kernel_name
-                  << "\")]] [[kernel]] decltype(" << template_def << ") "
-                  << template_def << ";" << std::endl;
-  }
+    std::ostringstream kernel_source;
+    kernel_source << header << std::endl;
 
-  if (verbose) {
-    std::cout << "Generated source code for `" << name_ << "`:" << std::endl
-              << "```" << std::endl
-              << kernel_source.str() << std::endl
-              << "```" << std::endl;
-  }
+    std::vector<CustomKernelShapeInfo> shape_infos;
+    write_signature(
+        func_name.str(),
+        source,
+        input_names,
+        inputs,
+        output_names,
+        output_dtypes,
+        template_args,
+        shape_infos,
+        atomic_outputs,
+        kernel_source);
 
-  std::vector<array> in_arrs;
-  for (const auto& kv : inputs) {
-    in_arrs.push_back(kv.second);
-  }
+    if (!template_args.empty()) {
+      template_def = func_name.str() + template_def;
+      kernel_source << std::endl
+                    << "template [[host_name(\"" << kernel_name
+                    << "\")]] [[kernel]] decltype(" << template_def << ") "
+                    << template_def << ";" << std::endl;
+    }
 
-  std::vector<std::string> out_keys;
-  std::vector<std::vector<int>> out_shapes;
-  for (const auto& [name, shape] : output_shapes) {
-    out_keys.push_back(name);
-    out_shapes.push_back(shape);
-  }
+    if (verbose) {
+      std::cout << "Generated source code for `" << name << "`:" << std::endl
+                << "```" << std::endl
+                << kernel_source.str() << std::endl
+                << "```" << std::endl;
+    }
 
-  std::vector<Dtype> out_dtypes;
-  for (const auto& kv : output_dtypes) {
-    out_dtypes.push_back(kv.second);
-  }
-
-  std::map<std::string, array> outputs;
-  auto outputs_vec = array::make_arrays(
-      out_shapes,
-      out_dtypes,
-      std::make_shared<CustomKernel>(
-          s,
-          kernel_name,
-          kernel_source.str(),
-          grid,
-          threadgroup,
-          shape_infos,
-          ensure_row_contiguous_,
-          init_value),
-      in_arrs);
-
-  int i = 0;
-  for (const auto& key : out_keys) {
-    outputs.insert({key, outputs_vec[i]});
-    i++;
-  }
-  return outputs;
+    return array::make_arrays(
+        output_shapes,
+        output_dtypes,
+        std::make_shared<CustomKernel>(
+            s,
+            kernel_name,
+            kernel_source.str(),
+            grid,
+            threadgroup,
+            shape_infos,
+            ensure_row_contiguous,
+            init_value),
+        inputs);
+  };
 }
 
 } // namespace mlx::core::fast
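A worked example of the name mangling in the lambda above, for an illustrative kernel named "myexp" with a single template argument {"T", float32} (and assuming `get_type_string(float32)` yields "float"):

    write_template(template_args)        ->  template_def = "<float>"
    regex "\\<|\\>|(, )" replaced by "_"  ->  "_float_"
    hash_key.pop_back()                  ->  hash_key = "_float"
    kernel_name                          ->  "custom_kernel_myexp_float"

so the explicit instantiation appended to the generated source is:

    template [[host_name("custom_kernel_myexp_float")]] [[kernel]]
        decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;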
mlx/fast.h (53 lines changed)
@@ -2,7 +2,6 @@
 
 #pragma once
 
-#include <map>
 #include <optional>
 
 #include "mlx/utils.h"
@@ -39,7 +38,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::optional<array>& mask = std::nullopt,
-    const std::optional<int>& memory_efficient_threshold = std::nullopt,
+    const std::optional<int> memory_efficient_threshold = std::nullopt,
     StreamOrDevice s = {});
 
 std::tuple<array, array, array> affine_quantize(
@@ -66,37 +65,25 @@ array affine_dequantize(
 
 typedef std::variant<int, bool, Dtype> TemplateArg;
 
-class MetalKernel {
- public:
-  MetalKernel(
-      const std::string& name,
-      const std::string& source,
-      const std::string& header = "",
-      bool ensure_row_contiguous = true,
-      bool atomic_outputs = false)
-      : name_(name),
-        source_(source),
-        header_(header),
-        ensure_row_contiguous_(ensure_row_contiguous),
-        atomic_outputs_(atomic_outputs) {}
+typedef std::function<std::vector<array>(
+    const std::vector<array>&,
+    const std::vector<std::vector<int>>&,
+    const std::vector<Dtype>&,
+    std::tuple<int, int, int>,
+    std::tuple<int, int, int>,
+    std::vector<std::pair<std::string, TemplateArg>>,
+    std::optional<float>,
+    bool,
+    StreamOrDevice)>
+    MetalKernelFunction;
 
-  std::map<std::string, array> operator()(
-      std::map<std::string, array>& inputs,
-      std::map<std::string, std::vector<int>> output_shapes,
-      std::map<std::string, Dtype> output_dtypes,
-      std::tuple<int, int, int> grid,
-      std::tuple<int, int, int> threadgroup,
-      std::optional<std::map<std::string, TemplateArg>> template_args =
-          std::nullopt,
-      std::optional<float> init_value = std::nullopt,
-      bool verbose = false,
-      StreamOrDevice s = {});
+MetalKernelFunction metal_kernel(
+    const std::string& name,
+    const std::vector<std::string>& input_names,
+    const std::vector<std::string>& output_names,
+    const std::string& source,
+    const std::string& header = "",
+    bool ensure_row_contiguous = true,
+    bool atomic_outputs = false);
 
- private:
-  std::string name_;
-  std::string source_;
-  std::string header_;
-  bool ensure_row_contiguous_;
-  bool atomic_outputs_;
-};
 } // namespace mlx::core::fast
mlx/ops.cpp

@@ -1425,8 +1425,8 @@ array where(
 array nan_to_num(
     const array& a,
     float nan /* = 0.0f */,
-    const std::optional<float>& posinf_ /* = std::nullopt */,
-    const std::optional<float>& neginf_ /* = std::nullopt */,
+    const std::optional<float> posinf_ /* = std::nullopt */,
+    const std::optional<float> neginf_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   Dtype dtype = a.dtype();
   if (!issubdtype(dtype, inexact)) {
mlx/ops.h

@@ -416,8 +416,8 @@ array where(
 array nan_to_num(
     const array& a,
     float nan = 0.0f,
-    const std::optional<float>& posinf = std::nullopt,
-    const std::optional<float>& neginf = std::nullopt,
+    const std::optional<float> posinf = std::nullopt,
+    const std::optional<float> neginf = std::nullopt,
     StreamOrDevice s = {});
 
 /** True if all elements in the array are true (or non-zero). **/
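This is the same by-value `std::optional` change as in `scaled_dot_product_attention` above, presumably to keep these signatures easy to wrap from the MLX C API named in the commit title. Call sites are unaffected; a hypothetical call, with `a` an existing float array:

    // Implicit float -> std::optional<float> conversion still applies:
    auto cleaned = nan_to_num(a, 0.0f, /* posinf = */ 1e6f, /* neginf = */ -1e6f);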