Reduce specializations (#1607)

* start of reduce specializations

* fix all reduce

* fix many dims

* fix

* non-jit tests clear

* cleanup instantiations

* cpu merges

* change dim specializations

* optimize

* fix jit

* fix jit

* use higher precision for integer sum+prod

* fixes
This commit is contained in:
Awni Hannun
2024-11-21 19:53:00 -08:00
committed by GitHub
parent dcca0d7477
commit 0c5eea226b
14 changed files with 733 additions and 406 deletions

View File

@@ -1,5 +1,4 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
@@ -338,17 +337,17 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const array& out) {
const Dtype& out_type) {
auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source;
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
auto out_type = get_type_string(out.dtype());
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
kernel_source << get_template_definition(
kernel_name, func_name, out_type, op);
return kernel_source.str();
auto out_t = get_type_string(out_type);
std::string op = op_type + "<" + out_t + ">";
std::string kernel_source = metal::utils();
kernel_source += metal::reduce_utils();
kernel_source += metal::reduce();
kernel_source += get_template_definition(kernel_name, func_name, out_t, op);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
@@ -358,30 +357,31 @@ MTL::ComputePipelineState* get_reduce_kernel(
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const array& in,
const array& out,
const Dtype& in_type,
const Dtype& out_type,
const std::string& idx_t,
int ndim /* = -1 */,
int bm /* = -1 */,
int bn /* = -1 */) {
auto lib = d.get_library(kernel_name, [&]() {
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
std::string op = op_type + "<" + out_t + ">";
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::reduce_utils(), metal::reduce());
if (bm >= 0) {
kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op, ndim, bm, bn);
kernel_source += get_template_definition(
kernel_name, func_name, in_t, out_t, op, idx_t, ndim, bm, bn);
} else if (ndim >= 0) {
kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op, ndim);
kernel_source += get_template_definition(
kernel_name, func_name, in_t, out_t, op, idx_t, ndim);
} else {
kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op);
kernel_source += get_template_definition(
kernel_name, func_name, in_t, out_t, op, idx_t);
}
return kernel_source.str();
return kernel_source;
});
auto st = d.get_kernel(kernel_name, lib);
return st;