mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Reduce specializations (#1607)
* start of reduce specializations * fix all reduce * fix many dims * fix * non-jit tests clear * cleanup instantiations * cpu merges * change dim specializations * optimize * fix jit * fix jit * use higher precision for integer sum+prod * fixes
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/jit/arange.h"
|
||||
#include "mlx/backend/metal/jit/gemv_masked.h"
|
||||
@@ -338,17 +337,17 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
|
||||
const std::string& kernel_name,
|
||||
const std::string& func_name,
|
||||
const std::string& op_name,
|
||||
const array& out) {
|
||||
const Dtype& out_type) {
|
||||
auto lib = d.get_library(kernel_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
std::string op_type = op_name;
|
||||
op_type[0] = std::toupper(op_name[0]);
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
std::string op = op_type + "<" + out_type + ">";
|
||||
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, func_name, out_type, op);
|
||||
return kernel_source.str();
|
||||
auto out_t = get_type_string(out_type);
|
||||
std::string op = op_type + "<" + out_t + ">";
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::reduce_utils();
|
||||
kernel_source += metal::reduce();
|
||||
kernel_source += get_template_definition(kernel_name, func_name, out_t, op);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
@@ -358,30 +357,31 @@ MTL::ComputePipelineState* get_reduce_kernel(
|
||||
const std::string& kernel_name,
|
||||
const std::string& func_name,
|
||||
const std::string& op_name,
|
||||
const array& in,
|
||||
const array& out,
|
||||
const Dtype& in_type,
|
||||
const Dtype& out_type,
|
||||
const std::string& idx_t,
|
||||
int ndim /* = -1 */,
|
||||
int bm /* = -1 */,
|
||||
int bn /* = -1 */) {
|
||||
auto lib = d.get_library(kernel_name, [&]() {
|
||||
std::string op_type = op_name;
|
||||
op_type[0] = std::toupper(op_name[0]);
|
||||
std::ostringstream kernel_source;
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
std::string op = op_type + "<" + out_type + ">";
|
||||
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
|
||||
auto in_t = get_type_string(in_type);
|
||||
auto out_t = get_type_string(out_type);
|
||||
std::string op = op_type + "<" + out_t + ">";
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::reduce_utils(), metal::reduce());
|
||||
if (bm >= 0) {
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, func_name, in_type, out_type, op, ndim, bm, bn);
|
||||
kernel_source += get_template_definition(
|
||||
kernel_name, func_name, in_t, out_t, op, idx_t, ndim, bm, bn);
|
||||
} else if (ndim >= 0) {
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, func_name, in_type, out_type, op, ndim);
|
||||
kernel_source += get_template_definition(
|
||||
kernel_name, func_name, in_t, out_t, op, idx_t, ndim);
|
||||
} else {
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, func_name, in_type, out_type, op);
|
||||
kernel_source += get_template_definition(
|
||||
kernel_name, func_name, in_t, out_t, op, idx_t);
|
||||
}
|
||||
return kernel_source.str();
|
||||
return kernel_source;
|
||||
});
|
||||
auto st = d.get_kernel(kernel_name, lib);
|
||||
return st;
|
||||
|
||||
Reference in New Issue
Block a user