Simplifications for MLX C (#1396)

* simplifications for MLX C

* use vectors instead of map

* update examples
This commit is contained in:
Awni Hannun
2024-09-06 19:16:50 -07:00
committed by GitHub
parent 7cca1727af
commit ba3e913c7a
7 changed files with 334 additions and 331 deletions

View File

@@ -515,7 +515,7 @@ array scaled_dot_product_attention(
const array& values,
const float scale,
const std::optional<array>& mask,
const std::optional<int>& memory_efficient_threshold,
const std::optional<int> memory_efficient_threshold,
StreamOrDevice s) {
for (const auto& tensor : {queries, keys, values}) {
if (tensor.ndim() != 4) {
@@ -916,47 +916,23 @@ array affine_dequantize(
return fallback({w, scales, biases})[0];
}
void validate_output_shapes(
std::map<std::string, std::vector<int>> output_shapes,
std::map<std::string, Dtype> output_dtypes) {
// Make sure output shapes and dtypes have the same keys
bool validated = true;
if (output_shapes.size() == 0) {
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
if (output_shapes.size() != output_dtypes.size()) {
validated = false;
} else {
for (const auto& kv : output_shapes) {
if (output_dtypes.find(kv.first) == output_dtypes.end()) {
validated = false;
break;
}
}
}
if (!validated) {
throw std::invalid_argument(
"[metal_kernel] `output_shapes` and `output_dtypes` must have the same keys.");
}
}
void write_signature(
std::string func_name,
std::string& source,
std::map<std::string, array>& inputs,
std::map<std::string, std::vector<int>>& output_shapes,
std::map<std::string, Dtype>& output_dtypes,
std::optional<std::map<std::string, TemplateArg>> template_args,
const std::string& source,
const std::vector<std::string>& input_names,
const std::vector<array>& inputs,
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
std::vector<CustomKernelShapeInfo>& shape_infos,
bool atomic_outputs,
std::ostringstream& kernel_source) {
// Auto-generate a function signature based on `template_args`
// and the dtype/shape of the arrays passed as `inputs`.
if (template_args && template_args.value().size() > 0) {
if (!template_args.empty()) {
kernel_source << "template <";
int i = 0;
for (const auto& [name, arg] : template_args.value()) {
for (const auto& [name, arg] : template_args) {
std::string param_type;
if (std::holds_alternative<int>(arg)) {
param_type = "int";
@@ -1008,7 +984,9 @@ void write_signature(
int index = 0;
constexpr int max_constant_array_size = 8;
// Add inputs
for (const auto& [name, arr] : inputs) {
for (int i = 0; i < inputs.size(); ++i) {
const auto& name = input_names[i];
const auto& arr = inputs[i];
auto dtype = get_type_string(arr.dtype());
bool is_constant =
arr.is_available() && arr.size() < max_constant_array_size;
@@ -1042,7 +1020,9 @@ void write_signature(
shape_infos.push_back(shape_info);
}
// Add outputs
for (const auto& [name, dtype] : output_dtypes) {
for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source << " device ";
auto type_string = get_type_string(dtype);
if (atomic_outputs) {
@@ -1051,7 +1031,7 @@ void write_signature(
kernel_source << type_string;
}
kernel_source << "* " << name << " [[buffer(" << index << ")]]";
if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) {
if (index < inputs.size() + output_names.size() - 1 || attrs.size() > 0) {
kernel_source << "," << std::endl;
} else {
kernel_source << ") {" << std::endl;
@@ -1073,7 +1053,8 @@ void write_signature(
kernel_source << "}" << std::endl;
}
std::string write_template(std::map<std::string, TemplateArg>& template_args) {
std::string write_template(
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
std::ostringstream template_def;
template_def << "<";
int i = 0;
@@ -1094,107 +1075,115 @@ std::string write_template(std::map<std::string, TemplateArg>& template_args) {
return template_def.str();
}
std::map<std::string, array> MetalKernel::operator()(
std::map<std::string, array>& inputs,
std::map<std::string, std::vector<int>> output_shapes,
std::map<std::string, Dtype> output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::optional<std::map<std::string, TemplateArg>> template_args,
std::optional<float> init_value,
bool verbose,
StreamOrDevice s_) {
validate_output_shapes(output_shapes, output_dtypes);
auto s = to_stream(s_);
if (s.device != Device::gpu) {
MetalKernelFunction metal_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header /* = "" */,
bool ensure_row_contiguous /* = true */,
bool atomic_outputs /* = false */) {
if (output_names.empty()) {
throw std::invalid_argument(
"[metal_kernel] MetalKernel only works on GPU.");
"[metal_kernel] Must specify at least one output.");
}
std::ostringstream func_name;
return [=](const std::vector<array>& inputs,
const std::vector<std::vector<int>>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::vector<std::pair<std::string, TemplateArg>>&
template_args = {},
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s_ = {}) {
if (inputs.size() != input_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `inputs` to have size "
<< input_names.size() << " but got size " << inputs.size() << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
if (output_shapes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_shapes` to have size "
<< output_names.size() << " but got size " << output_shapes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
if (output_dtypes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_dtypes` to have size "
<< output_names.size() << " but got size " << output_dtypes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
std::string template_def = "";
bool needs_template = template_args && template_args.value().size() > 0;
std::string hash_key = "";
if (needs_template) {
std::regex disallowed_chars("\\<|\\>|(, )");
template_def = write_template(template_args.value());
hash_key = std::regex_replace(template_def, disallowed_chars, "_");
hash_key.pop_back();
}
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
}
func_name << "custom_kernel_" << name_ << hash_key;
std::string kernel_name = func_name.str();
std::ostringstream func_name;
std::ostringstream kernel_source;
kernel_source << header_ << std::endl;
std::string template_def = "";
std::string hash_key = "";
if (!template_args.empty()) {
std::regex disallowed_chars("\\<|\\>|(, )");
template_def = write_template(template_args);
hash_key = std::regex_replace(template_def, disallowed_chars, "_");
hash_key.pop_back();
}
std::vector<CustomKernelShapeInfo> shape_infos;
write_signature(
func_name.str(),
source_,
inputs,
output_shapes,
output_dtypes,
template_args,
shape_infos,
atomic_outputs_,
kernel_source);
func_name << "custom_kernel_" << name << hash_key;
std::string kernel_name = func_name.str();
if (needs_template) {
template_def = func_name.str() + template_def;
kernel_source << std::endl
<< "template [[host_name(\"" << kernel_name
<< "\")]] [[kernel]] decltype(" << template_def << ") "
<< template_def << ";" << std::endl;
}
std::ostringstream kernel_source;
kernel_source << header << std::endl;
if (verbose) {
std::cout << "Generated source code for `" << name_ << "`:" << std::endl
<< "```" << std::endl
<< kernel_source.str() << std::endl
<< "```" << std::endl;
}
std::vector<CustomKernelShapeInfo> shape_infos;
write_signature(
func_name.str(),
source,
input_names,
inputs,
output_names,
output_dtypes,
template_args,
shape_infos,
atomic_outputs,
kernel_source);
std::vector<array> in_arrs;
for (const auto& kv : inputs) {
in_arrs.push_back(kv.second);
}
if (!template_args.empty()) {
template_def = func_name.str() + template_def;
kernel_source << std::endl
<< "template [[host_name(\"" << kernel_name
<< "\")]] [[kernel]] decltype(" << template_def << ") "
<< template_def << ";" << std::endl;
}
std::vector<std::string> out_keys;
std::vector<std::vector<int>> out_shapes;
for (const auto& [name, shape] : output_shapes) {
out_keys.push_back(name);
out_shapes.push_back(shape);
}
if (verbose) {
std::cout << "Generated source code for `" << name << "`:" << std::endl
<< "```" << std::endl
<< kernel_source.str() << std::endl
<< "```" << std::endl;
}
std::vector<Dtype> out_dtypes;
for (const auto& kv : output_dtypes) {
out_dtypes.push_back(kv.second);
}
std::map<std::string, array> outputs;
auto outputs_vec = array::make_arrays(
out_shapes,
out_dtypes,
std::make_shared<CustomKernel>(
s,
kernel_name,
kernel_source.str(),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous_,
init_value),
in_arrs);
int i = 0;
for (const auto& key : out_keys) {
outputs.insert({key, outputs_vec[i]});
i++;
}
return outputs;
return array::make_arrays(
output_shapes,
output_dtypes,
std::make_shared<CustomKernel>(
s,
kernel_name,
kernel_source.str(),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value),
inputs);
};
}
} // namespace mlx::core::fast

View File

@@ -2,7 +2,6 @@
#pragma once
#include <map>
#include <optional>
#include "mlx/utils.h"
@@ -39,7 +38,7 @@ array scaled_dot_product_attention(
const array& values,
const float scale,
const std::optional<array>& mask = std::nullopt,
const std::optional<int>& memory_efficient_threshold = std::nullopt,
const std::optional<int> memory_efficient_threshold = std::nullopt,
StreamOrDevice s = {});
std::tuple<array, array, array> affine_quantize(
@@ -66,37 +65,25 @@ array affine_dequantize(
typedef std::variant<int, bool, Dtype> TemplateArg;
class MetalKernel {
public:
MetalKernel(
const std::string& name,
const std::string& source,
const std::string& header = "",
bool ensure_row_contiguous = true,
bool atomic_outputs = false)
: name_(name),
source_(source),
header_(header),
ensure_row_contiguous_(ensure_row_contiguous),
atomic_outputs_(atomic_outputs) {}
typedef std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<std::vector<int>>&,
const std::vector<Dtype>&,
std::tuple<int, int, int>,
std::tuple<int, int, int>,
std::vector<std::pair<std::string, TemplateArg>>,
std::optional<float>,
bool,
StreamOrDevice)>
MetalKernelFunction;
std::map<std::string, array> operator()(
std::map<std::string, array>& inputs,
std::map<std::string, std::vector<int>> output_shapes,
std::map<std::string, Dtype> output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::optional<std::map<std::string, TemplateArg>> template_args =
std::nullopt,
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s = {});
MetalKernelFunction metal_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header = "",
bool ensure_row_contiguous = true,
bool atomic_outputs = false);
private:
std::string name_;
std::string source_;
std::string header_;
bool ensure_row_contiguous_;
bool atomic_outputs_;
};
} // namespace mlx::core::fast

View File

@@ -1425,8 +1425,8 @@ array where(
array nan_to_num(
const array& a,
float nan /* = 0.0f */,
const std::optional<float>& posinf_ /* = std::nullopt */,
const std::optional<float>& neginf_ /* = std::nullopt */,
const std::optional<float> posinf_ /* = std::nullopt */,
const std::optional<float> neginf_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
Dtype dtype = a.dtype();
if (!issubdtype(dtype, inexact)) {

View File

@@ -416,8 +416,8 @@ array where(
array nan_to_num(
const array& a,
float nan = 0.0f,
const std::optional<float>& posinf = std::nullopt,
const std::optional<float>& neginf = std::nullopt,
const std::optional<float> posinf = std::nullopt,
const std::optional<float> neginf = std::nullopt,
StreamOrDevice s = {});
/** True if all elements in the array are true (or non-zero). **/