mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-29 21:11:16 +08:00
fix jit reduce (#1395)
This commit is contained in:
parent
969337345f
commit
41c603d48a
@ -337,35 +337,36 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
|
||||
MTL::ComputePipelineState* get_reduce_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& func_name,
|
||||
const std::string& op_name,
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
const array& out,
|
||||
int ndim /* = -1 */,
|
||||
int bm /* = -1 */,
|
||||
int bn /* = -1 */) {
|
||||
auto lib = d.get_library(kernel_name);
|
||||
if (lib == nullptr) {
|
||||
std::string op_type = op_name;
|
||||
op_type[0] = std::toupper(op_name[0]);
|
||||
bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
|
||||
std::ostringstream kernel_source;
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
std::vector<std::pair<std::string, std::string>> reduce_kernels = {
|
||||
{"all_reduce", "allReduce"},
|
||||
{"col_reduce_small", "colReduceSmall"},
|
||||
{"col_reduce_looped", "colReduceLooped"},
|
||||
{"row_reduce_small", "rowReduceSmall"},
|
||||
{"row_reduce_looped", "rowReduceLooped"},
|
||||
{"row_reduce_simple", "rowReduceSimple"}};
|
||||
std::string op = op_type + "<" + out_type + ">";
|
||||
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
|
||||
for (auto [func, name] : reduce_kernels) {
|
||||
if (bm >= 0) {
|
||||
kernel_source << get_template_definition(
|
||||
name + "_" + lib_name, func, in_type, out_type, op);
|
||||
kernel_name, func_name, in_type, out_type, op, ndim, bm, bn);
|
||||
} else if (ndim >= 0) {
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, func_name, in_type, out_type, op, ndim);
|
||||
} else {
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, func_name, in_type, out_type, op);
|
||||
}
|
||||
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
lib = d.get_library(kernel_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
auto st = d.get_kernel(kernel_name, lib);
|
||||
return st;
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
||||
|
@ -83,9 +83,13 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
|
||||
MTL::ComputePipelineState* get_reduce_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& func_name,
|
||||
const std::string& op_name,
|
||||
const array& in,
|
||||
const array& out);
|
||||
const array& out,
|
||||
int ndim = -1,
|
||||
int bm = -1,
|
||||
int bn = -1);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
||||
metal::Device& d,
|
||||
|
@ -82,9 +82,9 @@
|
||||
otype, \
|
||||
op)
|
||||
|
||||
#define instantiate_init_reduce(name, otype, op) \
|
||||
instantiate_kernel("init_reduce_" #name, \
|
||||
init_reduce, \
|
||||
#define instantiate_init_reduce(name, otype, op) \
|
||||
instantiate_kernel("init_reduce_" #name, \
|
||||
init_reduce, \
|
||||
otype, op)
|
||||
|
||||
#define instantiate_init_reduce_helper(name, tname, type, op) \
|
||||
@ -96,9 +96,9 @@ instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper
|
||||
instantiate_init_reduce(andbool_, bool, And<bool>)
|
||||
instantiate_init_reduce(orbool_, bool, Or<bool>)
|
||||
|
||||
#define instantiate_all_reduce(name, itype, otype, op) \
|
||||
instantiate_kernel("allReduce_" #name, \
|
||||
all_reduce, \
|
||||
#define instantiate_all_reduce(name, itype, otype, op) \
|
||||
instantiate_kernel("all_reduce_" #name, \
|
||||
all_reduce, \
|
||||
itype, otype, op)
|
||||
|
||||
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
|
||||
@ -114,16 +114,16 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
|
||||
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
|
||||
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("colReduceSmall_" #dim "_reduce_" #name, \
|
||||
col_reduce_small, \
|
||||
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
|
||||
col_reduce_small, \
|
||||
itype, otype, op, dim)
|
||||
|
||||
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
instantiate_kernel("colReduceLooped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_looped, \
|
||||
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_looped, \
|
||||
itype, otype, op, dim, bm, bn)
|
||||
|
||||
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
|
||||
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
|
||||
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
|
||||
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
|
||||
|
||||
@ -139,7 +139,7 @@ instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 3) \
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 4)
|
||||
|
||||
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
|
||||
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
|
||||
instantiate_col_reduce_general(name##tname, type, type, op<type>)
|
||||
|
||||
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
|
||||
@ -149,32 +149,32 @@ instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)
|
||||
|
||||
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("rowReduceSmall_" #dim "_reduce_" #name, \
|
||||
row_reduce_small, \
|
||||
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
|
||||
row_reduce_small, \
|
||||
itype, otype, op, dim)
|
||||
|
||||
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("rowReduceLooped_" #dim "_reduce_" #name, \
|
||||
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
|
||||
row_reduce_looped, \
|
||||
itype, otype, op, dim)
|
||||
|
||||
#define instantiate_row_reduce_general(name, itype, otype, op) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 0) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 1) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 2) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 3) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 4) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 0) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 1) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 2) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 3) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 4) \
|
||||
instantiate_kernel("rowReduceSimple_" #name, \
|
||||
row_reduce_simple, \
|
||||
#define instantiate_row_reduce_general(name, itype, otype, op) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 0) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 1) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 2) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 3) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 4) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 0) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 1) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 2) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 3) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 4) \
|
||||
instantiate_kernel("row_reduce_simple_" #name, \
|
||||
row_reduce_simple, \
|
||||
itype, otype, op)
|
||||
|
||||
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
|
||||
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
|
||||
instantiate_row_reduce_general(name##tname, type, type, op<type>)
|
||||
|
||||
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
|
||||
|
@ -4,7 +4,7 @@ template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIMS = 0,
|
||||
int NDIMS,
|
||||
int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void col_reduce_small(
|
||||
const device T* in [[buffer(0)]],
|
||||
@ -198,13 +198,7 @@ template <
|
||||
* totals with a loop.
|
||||
* 7. Write them to the output
|
||||
*/
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIMS = 0,
|
||||
int BM = 8,
|
||||
int BN = 128>
|
||||
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
[[kernel]] void col_reduce_looped(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
|
@ -193,7 +193,7 @@ template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIMS = 0,
|
||||
int NDIMS,
|
||||
int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_small(
|
||||
const device T* in [[buffer(0)]],
|
||||
@ -306,7 +306,7 @@ template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIMS = 0,
|
||||
int NDIMS,
|
||||
int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_looped(
|
||||
const device T* in [[buffer(0)]],
|
||||
|
@ -104,8 +104,12 @@ MTL::ComputePipelineState* get_reduce_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
const array&,
|
||||
const array&) {
|
||||
const array&,
|
||||
int,
|
||||
int,
|
||||
int) {
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
|
@ -255,8 +255,9 @@ void all_reduce_dispatch(
|
||||
std::vector<array>& copies) {
|
||||
// Set the kernel
|
||||
std::ostringstream kname;
|
||||
kname << "allReduce_" << op_name << type_to_name(in);
|
||||
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
|
||||
const std::string func_name = "all_reduce";
|
||||
kname << func_name << "_" << op_name << type_to_name(in);
|
||||
auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
size_t in_size = in.size();
|
||||
@ -309,9 +310,9 @@ void all_reduce_dispatch(
|
||||
|
||||
// 2nd pass
|
||||
std::ostringstream kname_2nd_pass;
|
||||
kname_2nd_pass << "allReduce_" << op_name << type_to_name(intermediate);
|
||||
auto kernel_2nd_pass =
|
||||
get_reduce_kernel(d, kname_2nd_pass.str(), op_name, intermediate, out);
|
||||
kname_2nd_pass << func_name << "_" << op_name << type_to_name(intermediate);
|
||||
auto kernel_2nd_pass = get_reduce_kernel(
|
||||
d, kname_2nd_pass.str(), func_name, op_name, intermediate, out);
|
||||
compute_encoder->setComputePipelineState(kernel_2nd_pass);
|
||||
size_t intermediate_size = n_rows;
|
||||
grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
|
||||
@ -335,8 +336,10 @@ void row_reduce_small(
|
||||
// Set the kernel
|
||||
std::ostringstream kname;
|
||||
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
|
||||
kname << "rowReduceSmall_" << n << "_reduce_" << op_name << type_to_name(in);
|
||||
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
|
||||
const std::string func_name = "row_reduce_small";
|
||||
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
|
||||
auto kernel =
|
||||
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Figure out the grid dims
|
||||
@ -370,8 +373,9 @@ void row_reduce_simple(
|
||||
const Stream& s) {
|
||||
// Set the kernel
|
||||
std::ostringstream kname;
|
||||
kname << "rowReduceSimple_" << op_name << type_to_name(in);
|
||||
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
|
||||
const std::string func_name = "row_reduce_simple";
|
||||
kname << func_name << "_" << op_name << type_to_name(in);
|
||||
auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Figure out the grid dims
|
||||
@ -407,8 +411,10 @@ void row_reduce_looped(
|
||||
// Set the kernel
|
||||
std::ostringstream kname;
|
||||
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
|
||||
kname << "rowReduceLooped_" << n << "_reduce_" << op_name << type_to_name(in);
|
||||
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
|
||||
const std::string func_name = "row_reduce_looped";
|
||||
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
|
||||
auto kernel =
|
||||
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Figure out the grid
|
||||
@ -497,8 +503,10 @@ void strided_reduce_small(
|
||||
// Set the kernel
|
||||
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
|
||||
std::ostringstream kname;
|
||||
kname << "colReduceSmall_" << n << "_reduce_" << op_name << type_to_name(in);
|
||||
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
|
||||
const std::string func_name = "col_reduce_small";
|
||||
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
|
||||
auto kernel =
|
||||
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Launch
|
||||
@ -535,9 +543,11 @@ void strided_reduce_looped(
|
||||
// Set the kernel
|
||||
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
|
||||
std::ostringstream kname;
|
||||
kname << "colReduceLooped_" << n << "_" << BM << "_" << BN << "_reduce_"
|
||||
const std::string func_name = "col_reduce_looped";
|
||||
kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
|
||||
<< op_name << type_to_name(in);
|
||||
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
|
||||
auto kernel =
|
||||
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Launch
|
||||
|
Loading…
Reference in New Issue
Block a user