fix jit reduce (#1395)

This commit is contained in:
Awni Hannun 2024-09-04 14:03:10 -07:00 committed by GitHub
parent 969337345f
commit 41c603d48a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 88 additions and 75 deletions

View File

@@ -337,35 +337,36 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
const array& out,
int ndim /* = -1 */,
int bm /* = -1 */,
int bn /* = -1 */) {
auto lib = d.get_library(kernel_name);
if (lib == nullptr) {
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
std::vector<std::pair<std::string, std::string>> reduce_kernels = {
{"all_reduce", "allReduce"},
{"col_reduce_small", "colReduceSmall"},
{"col_reduce_looped", "colReduceLooped"},
{"row_reduce_small", "rowReduceSmall"},
{"row_reduce_looped", "rowReduceLooped"},
{"row_reduce_simple", "rowReduceSimple"}};
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
for (auto [func, name] : reduce_kernels) {
if (bm >= 0) {
kernel_source << get_template_definition(
name + "_" + lib_name, func, in_type, out_type, op);
kernel_name, func_name, in_type, out_type, op, ndim, bm, bn);
} else if (ndim >= 0) {
kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op, ndim);
} else {
kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op);
}
lib = d.get_library(lib_name, kernel_source.str());
lib = d.get_library(kernel_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
auto st = d.get_kernel(kernel_name, lib);
return st;
}
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(

View File

@@ -83,9 +83,13 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const array& in,
const array& out);
const array& out,
int ndim = -1,
int bm = -1,
int bn = -1);
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
metal::Device& d,

View File

@@ -82,9 +82,9 @@
otype, \
op)
#define instantiate_init_reduce(name, otype, op) \
instantiate_kernel("init_reduce_" #name, \
init_reduce, \
#define instantiate_init_reduce(name, otype, op) \
instantiate_kernel("init_reduce_" #name, \
init_reduce, \
otype, op)
#define instantiate_init_reduce_helper(name, tname, type, op) \
@@ -96,9 +96,9 @@ instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper
instantiate_init_reduce(andbool_, bool, And<bool>)
instantiate_init_reduce(orbool_, bool, Or<bool>)
#define instantiate_all_reduce(name, itype, otype, op) \
instantiate_kernel("allReduce_" #name, \
all_reduce, \
#define instantiate_all_reduce(name, itype, otype, op) \
instantiate_kernel("all_reduce_" #name, \
all_reduce, \
itype, otype, op)
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
@@ -114,16 +114,16 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("colReduceSmall_" #dim "_reduce_" #name, \
col_reduce_small, \
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, dim)
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("colReduceLooped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, dim, bm, bn)
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
@@ -139,7 +139,7 @@ instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_col_reduce_looped(name, itype, otype, op, 3) \
instantiate_col_reduce_looped(name, itype, otype, op, 4)
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_general(name##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
@@ -149,32 +149,32 @@ instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("rowReduceSmall_" #dim "_reduce_" #name, \
row_reduce_small, \
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, dim)
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
instantiate_kernel("rowReduceLooped_" #dim "_reduce_" #name, \
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, dim)
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op, 0) \
instantiate_row_reduce_small(name, itype, otype, op, 1) \
instantiate_row_reduce_small(name, itype, otype, op, 2) \
instantiate_row_reduce_small(name, itype, otype, op, 3) \
instantiate_row_reduce_small(name, itype, otype, op, 4) \
instantiate_row_reduce_looped(name, itype, otype, op, 0) \
instantiate_row_reduce_looped(name, itype, otype, op, 1) \
instantiate_row_reduce_looped(name, itype, otype, op, 2) \
instantiate_row_reduce_looped(name, itype, otype, op, 3) \
instantiate_row_reduce_looped(name, itype, otype, op, 4) \
instantiate_kernel("rowReduceSimple_" #name, \
row_reduce_simple, \
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op, 0) \
instantiate_row_reduce_small(name, itype, otype, op, 1) \
instantiate_row_reduce_small(name, itype, otype, op, 2) \
instantiate_row_reduce_small(name, itype, otype, op, 3) \
instantiate_row_reduce_small(name, itype, otype, op, 4) \
instantiate_row_reduce_looped(name, itype, otype, op, 0) \
instantiate_row_reduce_looped(name, itype, otype, op, 1) \
instantiate_row_reduce_looped(name, itype, otype, op, 2) \
instantiate_row_reduce_looped(name, itype, otype, op, 3) \
instantiate_row_reduce_looped(name, itype, otype, op, 4) \
instantiate_kernel("row_reduce_simple_" #name, \
row_reduce_simple, \
itype, otype, op)
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
instantiate_row_reduce_general(name##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)

View File

@@ -4,7 +4,7 @@ template <
typename T,
typename U,
typename Op,
int NDIMS = 0,
int NDIMS,
int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_small(
const device T* in [[buffer(0)]],
@@ -198,13 +198,7 @@ template <
* totals with a loop.
* 7. Write them to the output
*/
template <
typename T,
typename U,
typename Op,
int NDIMS = 0,
int BM = 8,
int BN = 128>
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
[[kernel]] void col_reduce_looped(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],

View File

@@ -193,7 +193,7 @@ template <
typename T,
typename U,
typename Op,
int NDIMS = 0,
int NDIMS,
int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_small(
const device T* in [[buffer(0)]],
@@ -306,7 +306,7 @@ template <
typename T,
typename U,
typename Op,
int NDIMS = 0,
int NDIMS,
int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_looped(
const device T* in [[buffer(0)]],

View File

@@ -104,8 +104,12 @@ MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string&,
const std::string&,
const array&,
const array&) {
const array&,
int,
int,
int) {
return d.get_kernel(kernel_name);
}

View File

@@ -255,8 +255,9 @@ void all_reduce_dispatch(
std::vector<array>& copies) {
// Set the kernel
std::ostringstream kname;
kname << "allReduce_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
const std::string func_name = "all_reduce";
kname << func_name << "_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out);
compute_encoder->setComputePipelineState(kernel);
size_t in_size = in.size();
@@ -309,9 +310,9 @@ void all_reduce_dispatch(
// 2nd pass
std::ostringstream kname_2nd_pass;
kname_2nd_pass << "allReduce_" << op_name << type_to_name(intermediate);
auto kernel_2nd_pass =
get_reduce_kernel(d, kname_2nd_pass.str(), op_name, intermediate, out);
kname_2nd_pass << func_name << "_" << op_name << type_to_name(intermediate);
auto kernel_2nd_pass = get_reduce_kernel(
d, kname_2nd_pass.str(), func_name, op_name, intermediate, out);
compute_encoder->setComputePipelineState(kernel_2nd_pass);
size_t intermediate_size = n_rows;
grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
@@ -335,8 +336,10 @@ void row_reduce_small(
// Set the kernel
std::ostringstream kname;
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
kname << "rowReduceSmall_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
const std::string func_name = "row_reduce_small";
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
// Figure out the grid dims
@@ -370,8 +373,9 @@ void row_reduce_simple(
const Stream& s) {
// Set the kernel
std::ostringstream kname;
kname << "rowReduceSimple_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
const std::string func_name = "row_reduce_simple";
kname << func_name << "_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out);
compute_encoder->setComputePipelineState(kernel);
// Figure out the grid dims
@@ -407,8 +411,10 @@ void row_reduce_looped(
// Set the kernel
std::ostringstream kname;
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
kname << "rowReduceLooped_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
const std::string func_name = "row_reduce_looped";
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
// Figure out the grid
@@ -497,8 +503,10 @@ void strided_reduce_small(
// Set the kernel
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
kname << "colReduceSmall_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
const std::string func_name = "col_reduce_small";
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
// Launch
@@ -535,9 +543,11 @@ void strided_reduce_looped(
// Set the kernel
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
kname << "colReduceLooped_" << n << "_" << BM << "_" << BN << "_reduce_"
const std::string func_name = "col_reduce_looped";
kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
<< op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
compute_encoder->setComputePipelineState(kernel);
// Launch