Improve metal elementwise kernels (#2247)

* improve metal elementwise kernels

* update compiled and copy kernels

* fix JIT kernel generation
This commit is contained in:
Awni Hannun 2025-06-06 11:37:40 -07:00 committed by GitHub
parent a5ac9244c4
commit c6a20b427a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 412 additions and 174 deletions

View File

@ -31,13 +31,13 @@ std::string get_kernel_name(
kname = "ss"; kname = "ss";
break; break;
case BinaryOpType::ScalarVector: case BinaryOpType::ScalarVector:
kname = (large ? "sv2" : "sv"); kname = "sv";
break; break;
case BinaryOpType::VectorScalar: case BinaryOpType::VectorScalar:
kname = (large ? "vs2" : "vs"); kname = "vs";
break; break;
case BinaryOpType::VectorVector: case BinaryOpType::VectorVector:
kname = (large ? "vv2" : "vv"); kname = "vv";
break; break;
case BinaryOpType::General: case BinaryOpType::General:
kname = "g"; kname = "g";
@ -51,6 +51,13 @@ std::string get_kernel_name(
} }
break; break;
} }
if (bopt != BinaryOpType::General && bopt != BinaryOpType::ScalarScalar) {
if (large) {
kname += "2";
} else if (work_per_thread > 1) {
kname += "n";
}
}
concatenate(kname, "_", op, type_to_name(a)); concatenate(kname, "_", op, type_to_name(a));
return kname; return kname;
} }
@ -90,7 +97,7 @@ void binary_op_gpu_inplace(
work_per_thread = large ? 4 : 2; work_per_thread = large ? 4 : 2;
} else { } else {
large = out.data_size() > UINT32_MAX; large = out.data_size() > UINT32_MAX;
work_per_thread = get_work_per_thread(a.dtype()); work_per_thread = get_work_per_thread(a.dtype(), out.data_size());
} }
std::string kernel_name = std::string kernel_name =
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread); get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);

View File

@ -278,7 +278,21 @@ void Compiled::eval_gpu(
/* ndim = */ 0, /* ndim = */ 0,
/* dynamic_dims = */ false, /* dynamic_dims = */ false,
/* use_big_index = */ false, /* use_big_index = */ false,
/* work_per_thread = */ 1);
if (work_per_thread > 1) {
build_kernel(
kernel,
kernel_lib_ + "_contiguous_n",
inputs_,
outputs_,
tape_,
is_constant_,
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ work_per_thread); /* work_per_thread = */ work_per_thread);
}
build_kernel( build_kernel(
kernel, kernel,
kernel_lib_ + "_contiguous_large", kernel_lib_ + "_contiguous_large",
@ -358,12 +372,20 @@ void Compiled::eval_gpu(
int ndim = shape.size(); int ndim = shape.size();
bool dynamic = ndim >= 8; bool dynamic = ndim >= 8;
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
int work_per_thread = 1;
if (!contiguous) { if (!contiguous) {
if (dynamic) { if (dynamic) {
kernel_name += "dynamic"; kernel_name += "dynamic";
} else { } else {
kernel_name += std::to_string(shape.size()); kernel_name += std::to_string(shape.size());
} }
work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
} else {
work_per_thread =
get_work_per_thread(outputs[0].dtype(), outputs[0].data_size());
if (work_per_thread > 1 && !large) {
kernel_name += "_n";
}
} }
if (large) { if (large) {
kernel_name += "_large"; kernel_name += "_large";
@ -420,7 +442,6 @@ void Compiled::eval_gpu(
// Launch the kernel // Launch the kernel
if (contiguous) { if (contiguous) {
int work_per_thread = get_work_per_thread(outputs[0].dtype());
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread); size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
MTL::Size group_dims( MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
@ -433,7 +454,6 @@ void Compiled::eval_gpu(
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = outputs[0].size() / (dim0 * dim1); size_t rest = outputs[0].size() / (dim0 * dim1);
int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
dim0 = (dim0 + work_per_thread - 1) / work_per_thread; dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
int pow2; int pow2;

View File

@ -55,10 +55,10 @@ void copy_gpu_inplace(
std::string kernel_name; std::string kernel_name;
switch (ctype) { switch (ctype) {
case CopyType::Scalar: case CopyType::Scalar:
kernel_name = (large ? "s2" : "s"); kernel_name = large ? "s2" : "s";
break; break;
case CopyType::Vector: case CopyType::Vector:
kernel_name = (large ? "v2" : "v"); kernel_name = large ? "v2" : "v";
break; break;
case CopyType::General: case CopyType::General:
kernel_name = "g"; kernel_name = "g";
@ -85,7 +85,10 @@ void copy_gpu_inplace(
} }
} }
} else { } else {
work_per_thread = get_work_per_thread(in.dtype()); work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
if (work_per_thread > 1) {
kernel_name += "n";
}
} }
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out)); concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out) auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
@ -170,9 +173,10 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
} }
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
bool large = out.data_size() > UINT32_MAX; bool large = out.data_size() > UINT32_MAX;
int work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" + std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? "sn" : "s");
type_to_name(val) + type_to_name(out); concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out));
auto kernel = get_copy_kernel(d, kernel_name, val, out); auto kernel = get_copy_kernel(d, kernel_name, val, out);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
@ -180,7 +184,6 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
compute_encoder.set_input_array(val, 0); compute_encoder.set_input_array(val, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
int work_per_thread = get_work_per_thread(val.dtype());
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
size_t nthreads = ceildiv(out.data_size(), work_per_thread); size_t nthreads = ceildiv(out.data_size(), work_per_thread);
if (thread_group_size > nthreads) { if (thread_group_size > nthreads) {

View File

@ -41,7 +41,11 @@ MTL::ComputePipelineState* get_unary_kernel(
std::string kernel_source = metal::utils(); std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::unary_ops(), metal::unary()); concatenate(kernel_source, metal::unary_ops(), metal::unary());
kernel_source += kernel_source +=
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op); get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1);
if (get_work_per_thread(in_type) > 1) {
kernel_source +=
get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op);
}
kernel_source += kernel_source +=
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op); get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source += get_template_definition( kernel_source += get_template_definition(
@ -59,11 +63,8 @@ void append_binary_kernels(
Dtype out_type, Dtype out_type,
const std::string op, const std::string op,
std::string& kernel_source) { std::string& kernel_source) {
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{ const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
{"ss", "binary_ss"}, {"ss", "binary_ss"},
{"vs", "binary_vs"},
{"sv", "binary_sv"},
{"vv", "binary_vv"},
{"vs2", "binary_vs2"}, {"vs2", "binary_vs2"},
{"sv2", "binary_sv2"}, {"sv2", "binary_sv2"},
{"vv2", "binary_vv2"}, {"vv2", "binary_vv2"},
@ -78,6 +79,22 @@ void append_binary_kernels(
kernel_source += kernel_source +=
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op); get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
} }
kernel_source += get_template_definition(
"vs_" + lib_name, "binary_vs", in_t, out_t, op, 1);
kernel_source += get_template_definition(
"sv_" + lib_name, "binary_sv", in_t, out_t, op, 1);
kernel_source += get_template_definition(
"vv_" + lib_name, "binary_vv", in_t, out_t, op, 1);
if (get_work_per_thread(in_type) > 1) {
kernel_source += get_template_definition(
"vsn_" + lib_name, "binary_vs", in_t, out_t, op);
kernel_source += get_template_definition(
"svn_" + lib_name, "binary_sv", in_t, out_t, op);
kernel_source += get_template_definition(
"vvn_" + lib_name, "binary_vv", in_t, out_t, op);
}
kernel_source += get_template_definition( kernel_source += get_template_definition(
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int"); "g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
kernel_source += get_template_definition( kernel_source += get_template_definition(
@ -133,8 +150,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
auto t_str = get_type_string(type); auto t_str = get_type_string(type);
std::string kernel_source = metal::utils(); std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::ternary_ops(), metal::ternary()); concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{ const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{
{"v", "ternary_v"},
{"v2", "ternary_v2"}, {"v2", "ternary_v2"},
{"g1large", "ternary_g_nd1"}, {"g1large", "ternary_g_nd1"},
{"g2large", "ternary_g_nd2"}, {"g2large", "ternary_g_nd2"},
@ -144,6 +160,13 @@ MTL::ComputePipelineState* get_ternary_kernel(
kernel_source += kernel_source +=
get_template_definition(name + "_" + lib_name, func, t_str, op); get_template_definition(name + "_" + lib_name, func, t_str, op);
} }
if (get_work_per_thread(type) > 1) {
kernel_source +=
get_template_definition("vn_" + lib_name, "ternary_v", t_str, op);
}
kernel_source +=
get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
kernel_source += get_template_definition( kernel_source += get_template_definition(
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int"); "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
kernel_source += get_template_definition( kernel_source += get_template_definition(
@ -170,15 +193,22 @@ MTL::ComputePipelineState* get_copy_kernel(
kernel_source += metal::copy(); kernel_source += metal::copy();
auto in_type = get_type_string(in.dtype()); auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype()); auto out_type = get_type_string(out.dtype());
kernel_source += kernel_source += get_template_definition(
get_template_definition("s_" + lib_name, "copy_s", in_type, out_type); "s_" + lib_name, "copy_s", in_type, out_type, 1);
kernel_source += kernel_source +=
get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type); get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type);
kernel_source += kernel_source += get_template_definition(
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type); "v_" + lib_name, "copy_v", in_type, out_type, 1);
kernel_source += kernel_source +=
get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type); get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type);
if (get_work_per_thread(out.dtype()) > 1) {
kernel_source += get_template_definition(
"sn_" + lib_name, "copy_s", in_type, out_type);
kernel_source += get_template_definition(
"vn_" + lib_name, "copy_v", in_type, out_type);
}
kernel_source += get_template_definition( kernel_source += get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int"); "g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
kernel_source += get_template_definition( kernel_source += get_template_definition(

View File

@ -17,9 +17,15 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
c[index + i] = Op()(a[0], b[index + i]); c[index + i] = Op()(a[0], b[index + i]);
} }
} else {
for (int i = 0; i < N; ++i) {
c[index + i] = Op()(a[0], b[index + i]);
}
}
} }
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
@ -30,9 +36,15 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
c[index + i] = Op()(a[index + i], b[0]); c[index + i] = Op()(a[index + i], b[0]);
} }
} else {
for (int i = 0; i < N; ++i) {
c[index + i] = Op()(a[index + i], b[0]);
}
}
} }
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
@ -43,9 +55,15 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
c[index + i] = Op()(a[index + i], b[index + i]); c[index + i] = Op()(a[index + i], b[index + i]);
} }
} else {
for (int i = 0; i < N; ++i) {
c[index + i] = Op()(a[index + i], b[index + i]);
}
}
} }
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
@ -57,9 +75,15 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
c[offset + i] = Op()(a[0], b[offset + i]); c[offset + i] = Op()(a[0], b[offset + i]);
} }
} else {
for (int i = 0; i < N; ++i) {
c[offset + i] = Op()(a[0], b[offset + i]);
}
}
} }
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
@ -71,9 +95,15 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[0]); c[offset + i] = Op()(a[offset + i], b[0]);
} }
} else {
for (int i = 0; i < N; ++i) {
c[offset + i] = Op()(a[offset + i], b[0]);
}
}
} }
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
@ -85,9 +115,15 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[offset + i]); c[offset + i] = Op()(a[offset + i], b[offset + i]);
} }
} else {
for (int i = 0; i < N; ++i) {
c[offset + i] = Op()(a[offset + i], b[offset + i]);
}
}
} }
template <typename T, typename U, typename Op, typename IdxT = int64_t> template <typename T, typename U, typename Op, typename IdxT = int64_t>

View File

@ -9,11 +9,16 @@
#include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h" #include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary_all(op, tname, itype, otype) \ #define instantiate_binary_work_per_thread(op, tname, itype, otype) \
instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) \
#define instantiate_binary_base(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
@ -26,15 +31,19 @@
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \ instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
// Instantiate every binary kernel for one (op, type) pair: the base set plus
// the "n"-suffixed work-per-thread variants (svn_/vsn_/vvn_). Types that
// should not get the "n" variants use instantiate_binary_base directly.
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_binary_base(op, tname, itype, otype) \
instantiate_binary_work_per_thread(op, tname, itype, otype)
#define instantiate_binary_integer(op) \ #define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \ instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \ instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \ instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \ instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_all(op, int64, int64_t, int64_t) instantiate_binary_base(op, int64, int64_t, int64_t)
#define instantiate_binary_float(op) \ #define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \ instantiate_binary_all(op, float16, half, half) \
@ -44,7 +53,7 @@
#define instantiate_binary_types(op) \ #define instantiate_binary_types(op) \
instantiate_binary_all(op, bool_, bool, bool) \ instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_integer(op) \ instantiate_binary_integer(op) \
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \ instantiate_binary_base(op, complex64, complex64_t, complex64_t)\
instantiate_binary_float(op) instantiate_binary_float(op)
#define instantiate_binary_types_bool(op) \ #define instantiate_binary_types_bool(op) \
@ -52,15 +61,15 @@
instantiate_binary_all(op, uint8, uint8_t, bool) \ instantiate_binary_all(op, uint8, uint8_t, bool) \
instantiate_binary_all(op, uint16, uint16_t, bool) \ instantiate_binary_all(op, uint16, uint16_t, bool) \
instantiate_binary_all(op, uint32, uint32_t, bool) \ instantiate_binary_all(op, uint32, uint32_t, bool) \
instantiate_binary_all(op, uint64, uint64_t, bool) \ instantiate_binary_base(op, uint64, uint64_t, bool) \
instantiate_binary_all(op, int8, int8_t, bool) \ instantiate_binary_all(op, int8, int8_t, bool) \
instantiate_binary_all(op, int16, int16_t, bool) \ instantiate_binary_all(op, int16, int16_t, bool) \
instantiate_binary_all(op, int32, int32_t, bool) \ instantiate_binary_all(op, int32, int32_t, bool) \
instantiate_binary_all(op, int64, int64_t, bool) \ instantiate_binary_base(op, int64, int64_t, bool) \
instantiate_binary_all(op, float16, half, bool) \ instantiate_binary_all(op, float16, half, bool) \
instantiate_binary_all(op, float32, float, bool) \ instantiate_binary_all(op, float32, float, bool) \
instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \ instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \
instantiate_binary_all(op, complex64, complex64_t, bool) instantiate_binary_base(op, complex64, complex64_t, bool)
instantiate_binary_types(Add) instantiate_binary_types(Add)
instantiate_binary_types(Divide) instantiate_binary_types(Divide)
@ -71,7 +80,7 @@ instantiate_binary_types_bool(Less)
instantiate_binary_types_bool(LessEqual) instantiate_binary_types_bool(LessEqual)
instantiate_binary_types_bool(NotEqual) instantiate_binary_types_bool(NotEqual)
instantiate_binary_float(LogAddExp) instantiate_binary_float(LogAddExp)
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t) instantiate_binary_base(LogAddExp, complex64, complex64_t, complex64_t)
instantiate_binary_types(Maximum) instantiate_binary_types(Maximum)
instantiate_binary_types(Minimum) instantiate_binary_types(Minimum)
instantiate_binary_types(Multiply) instantiate_binary_types(Multiply)
@ -84,7 +93,7 @@ instantiate_binary_float(ArcTan2)
instantiate_binary_all(NaNEqual, float16, half, bool) instantiate_binary_all(NaNEqual, float16, half, bool)
instantiate_binary_all(NaNEqual, float32, float, bool) instantiate_binary_all(NaNEqual, float32, float, bool)
instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool) instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool)
instantiate_binary_all(NaNEqual, complex64, complex64_t, bool) instantiate_binary_base(NaNEqual, complex64, complex64_t, bool)
instantiate_binary_all(LogicalOr, bool_, bool, bool) instantiate_binary_all(LogicalOr, bool_, bool, bool)
instantiate_binary_all(LogicalAnd, bool_, bool, bool) instantiate_binary_all(LogicalAnd, bool_, bool, bool)

View File

@ -21,11 +21,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
auto out = Op()(a[0], b[index + i]); auto out = Op()(a[0], b[index + i]);
c[index + i] = out[0]; c[index + i] = out[0];
d[index + i] = out[1]; d[index + i] = out[1];
} }
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[0], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
} }
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
@ -37,11 +45,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
auto out = Op()(a[index + i], b[0]); auto out = Op()(a[index + i], b[0]);
c[index + i] = out[0]; c[index + i] = out[0];
d[index + i] = out[1]; d[index + i] = out[1];
} }
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[index + i], b[0]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
} }
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
@ -53,11 +69,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
auto out = Op()(a[index + i], b[index + i]); auto out = Op()(a[index + i], b[index + i]);
c[index + i] = out[0]; c[index + i] = out[0];
d[index + i] = out[1]; d[index + i] = out[1];
} }
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[index + i], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
} }
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
@ -69,12 +93,20 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size, constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
auto out = Op()(a[0], b[offset + i]); auto out = Op()(a[0], b[offset + i]);
c[offset + i] = out[0]; c[offset + i] = out[0];
d[offset + i] = out[1]; d[offset + i] = out[1];
} }
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[0], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
} }
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
@ -86,12 +118,20 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size, constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
auto out = Op()(a[offset + i], b[0]); auto out = Op()(a[offset + i], b[0]);
c[offset + i] = out[0]; c[offset + i] = out[0];
d[offset + i] = out[1]; d[offset + i] = out[1];
} }
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[offset + i], b[0]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
} }
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
@ -103,12 +143,20 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size, constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
auto out = Op()(a[offset + i], b[offset + i]); auto out = Op()(a[offset + i], b[offset + i]);
c[offset + i] = out[0]; c[offset + i] = out[0];
d[offset + i] = out[1]; d[offset + i] = out[1];
} }
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[offset + i], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
} }
template <typename T, typename U, typename Op, typename IdxT = int64_t> template <typename T, typename U, typename Op, typename IdxT = int64_t>

View File

@ -7,11 +7,16 @@
#include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h" #include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary_all(op, tname, itype, otype) \ #define instantiate_binary_work_per_thread(op, tname, itype, otype) \
instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op)
#define instantiate_binary_base(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
@ -24,6 +29,10 @@
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \ instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
// Instantiate every two-output binary kernel for one (op, type) pair: the
// base set plus the "n"-suffixed work-per-thread variants (svn_/vsn_/vvn_).
// Types that should not get the "n" variants use instantiate_binary_base.
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_binary_base(op, tname, itype, otype) \
instantiate_binary_work_per_thread(op, tname, itype, otype)
#define instantiate_binary_float(op) \ #define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \ instantiate_binary_all(op, float16, half, half) \
instantiate_binary_all(op, float32, float, float) \ instantiate_binary_all(op, float32, float, float) \
@ -34,12 +43,12 @@
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \ instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \ instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \ instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \ instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_all(op, int64, int64_t, int64_t) \ instantiate_binary_base(op, int64, int64_t, int64_t) \
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \ instantiate_binary_base(op, complex64, complex64_t, complex64_t) \
instantiate_binary_float(op) instantiate_binary_float(op)
instantiate_binary_types(DivMod) // clang-format on instantiate_binary_types(DivMod) // clang-format on

View File

@ -1,53 +1,77 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
// Scalar-broadcast copy: writes `src[0]`, cast to `U`, into `dst`.
// Each thread covers up to N consecutive elements starting at `index * N`;
// N defaults to WorkPerThread<U>::n (keyed on the output type).
// NOTE(review): when N == 1 the write is unconditional, so this assumes the
// grid is sized such that every thread's first element is in range -- confirm
// at the dispatch site.
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_s(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  index *= N;
  if (N > 1 && index + N > size) {
    // Tail chunk: fewer than N elements remain, so bound every write.
    for (int i = 0; index + i < size; ++i) {
      dst[index + i] = static_cast<U>(src[0]);
    }
  } else {
    // Full chunk: fixed trip count, no per-element bounds check.
    for (int i = 0; i < N; ++i) {
      dst[index + i] = static_cast<U>(src[0]);
    }
  }
}
// Contiguous element-wise copy: dst[k] = static_cast<U>(src[k]).
// Each thread covers up to N consecutive elements starting at `index * N`;
// N defaults to WorkPerThread<U>::n (keyed on the output type).
// NOTE(review): when N == 1 the access is unconditional, so this assumes the
// grid is sized such that every thread's first element is in range -- confirm
// at the dispatch site.
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_v(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  index *= N;
  if (N > 1 && index + N > size) {
    // Tail chunk: fewer than N elements remain, so bound every access.
    for (int i = 0; index + i < size; ++i) {
      dst[index + i] = static_cast<U>(src[index + i]);
    }
  } else {
    // Full chunk: fixed trip count, no per-element bounds check.
    for (int i = 0; i < N; ++i) {
      dst[index + i] = static_cast<U>(src[index + i]);
    }
  }
}
// Scalar-broadcast copy with 64-bit indexing over a 2D grid: writes `src[0]`,
// cast to `U`, into `dst`. The linear thread id is index.x + grid_dim.x *
// index.y, and each thread covers up to N consecutive elements from
// N * (linear id).
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_s2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // index.y is widened to int64_t so the flattened offset is computed in 64 bits.
  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
  if (N > 1 && offset + N > size) {
    // Tail chunk: fewer than N elements remain, so bound every write.
    for (int i = 0; offset + i < size; ++i) {
      dst[offset + i] = static_cast<U>(src[0]);
    }
  } else {
    // Full chunk: fixed trip count, no per-element bounds check.
    for (int i = 0; i < N; ++i) {
      dst[offset + i] = static_cast<U>(src[0]);
    }
  }
}
// Contiguous element-wise copy with 64-bit indexing over a 2D grid:
// dst[k] = static_cast<U>(src[k]). The linear thread id is
// index.x + grid_dim.x * index.y, and each thread covers up to N consecutive
// elements from N * (linear id).
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_v2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // index.y is widened to int64_t so the flattened offset is computed in 64 bits.
  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
  if (N > 1 && offset + N > size) {
    // Tail chunk: fewer than N elements remain, so bound every access.
    for (int i = 0; offset + i < size; ++i) {
      dst[offset + i] = static_cast<U>(src[offset + i]);
    }
  } else {
    // Full chunk: fixed trip count, no per-element bounds check.
    for (int i = 0; i < N; ++i) {
      dst[offset + i] = static_cast<U>(src[offset + i]);
    }
  }
}
template <typename T, typename U, typename IdxT = int64_t> template <typename T, typename U, typename IdxT = int64_t>

View File

@ -4,9 +4,13 @@
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/copy.h" #include "mlx/backend/metal/kernels/copy.h"
// Instantiates the "n"-suffixed contiguous copy kernels (sn_copy*, vn_copy*).
// No explicit N is passed, so copy_s / copy_v use their default template
// argument for the per-thread element count.
#define instantiate_copy_work_per_thread(tname, itype, otype) \
  instantiate_kernel("sn_copy" #tname, copy_s, itype, otype) \
  instantiate_kernel("vn_copy" #tname, copy_v, itype, otype)
#define instantiate_copy_base(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype, 1) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype, 1) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \ instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \ instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \ instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
@ -18,6 +22,10 @@
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \ instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4)
// Instantiates every copy kernel variant for one (input, output) type pair:
// the base set from instantiate_copy_base plus the "n"-suffixed kernels
// (sn_copy*, vn_copy*) from instantiate_copy_work_per_thread.
#define instantiate_copy_all(tname, itype, otype) \
instantiate_copy_base(tname, itype, otype) \
instantiate_copy_work_per_thread(tname, itype, otype)
#define instantiate_copy_same(tname, type) \ #define instantiate_copy_same(tname, type) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \ instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \ instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \
@ -42,15 +50,15 @@
instantiate_copy_all(itname ##uint8, itype, uint8_t) \ instantiate_copy_all(itname ##uint8, itype, uint8_t) \
instantiate_copy_all(itname ##uint16, itype, uint16_t) \ instantiate_copy_all(itname ##uint16, itype, uint16_t) \
instantiate_copy_all(itname ##uint32, itype, uint32_t) \ instantiate_copy_all(itname ##uint32, itype, uint32_t) \
instantiate_copy_all(itname ##uint64, itype, uint64_t) \ instantiate_copy_base(itname ##uint64, itype, uint64_t) \
instantiate_copy_all(itname ##int8, itype, int8_t) \ instantiate_copy_all(itname ##int8, itype, int8_t) \
instantiate_copy_all(itname ##int16, itype, int16_t) \ instantiate_copy_all(itname ##int16, itype, int16_t) \
instantiate_copy_all(itname ##int32, itype, int32_t) \ instantiate_copy_all(itname ##int32, itype, int32_t) \
instantiate_copy_all(itname ##int64, itype, int64_t) \ instantiate_copy_base(itname ##int64, itype, int64_t) \
instantiate_copy_all(itname ##float16, itype, half) \ instantiate_copy_all(itname ##float16, itype, half) \
instantiate_copy_all(itname ##float32, itype, float) \ instantiate_copy_all(itname ##float32, itype, float) \
instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \ instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
instantiate_copy_all(itname ##complex64, itype, complex64_t) instantiate_copy_base(itname ##complex64, itype, complex64_t)
instantiate_copy_itype(bool_, bool) instantiate_copy_itype(bool_, bool)
instantiate_copy_itype(uint8, uint8_t) instantiate_copy_itype(uint8, uint8_t)

View File

@ -9,9 +9,15 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
} }
} else {
for (int i = 0; i < N; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
}
}
} }
template <typename T, typename Op, int N = WorkPerThread<T>::n> template <typename T, typename Op, int N = WorkPerThread<T>::n>
@ -23,10 +29,16 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size, constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
} }
} else {
for (int i = 0; i < N; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
}
}
} }
template <typename T, typename Op, typename IdxT = int64_t> template <typename T, typename Op, typename IdxT = int64_t>

View File

@ -8,8 +8,8 @@
#include "mlx/backend/metal/kernels/ternary_ops.h" #include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h" #include "mlx/backend/metal/kernels/ternary.h"
#define instantiate_ternary_all(op, tname, type) \ #define instantiate_ternary_base(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \ instantiate_kernel("v_" #op #tname, ternary_v, type, op, 1) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \ instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \ instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \ instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
@ -20,19 +20,23 @@
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \ instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \ instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
// Instantiates every ternary kernel variant for one (op, type) pair: the
// "vn_" kernel (ternary_v with its default per-thread element count) plus the
// base set from instantiate_ternary_base.
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("vn_" #op #tname, ternary_v, type, op) \
instantiate_ternary_base(op, tname, type)
#define instantiate_ternary_types(op) \ #define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \ instantiate_ternary_all(op, bool_, bool) \
instantiate_ternary_all(op, uint8, uint8_t) \ instantiate_ternary_all(op, uint8, uint8_t) \
instantiate_ternary_all(op, uint16, uint16_t) \ instantiate_ternary_all(op, uint16, uint16_t) \
instantiate_ternary_all(op, uint32, uint32_t) \ instantiate_ternary_all(op, uint32, uint32_t) \
instantiate_ternary_all(op, uint64, uint64_t) \ instantiate_ternary_base(op, uint64, uint64_t) \
instantiate_ternary_all(op, int8, int8_t) \ instantiate_ternary_all(op, int8, int8_t) \
instantiate_ternary_all(op, int16, int16_t) \ instantiate_ternary_all(op, int16, int16_t) \
instantiate_ternary_all(op, int32, int32_t) \ instantiate_ternary_all(op, int32, int32_t) \
instantiate_ternary_all(op, int64, int64_t) \ instantiate_ternary_base(op, int64, int64_t) \
instantiate_ternary_all(op, float16, half) \ instantiate_ternary_all(op, float16, half) \
instantiate_ternary_all(op, float32, float) \ instantiate_ternary_all(op, float32, float) \
instantiate_ternary_all(op, bfloat16, bfloat16_t) \ instantiate_ternary_all(op, bfloat16, bfloat16_t) \
instantiate_ternary_all(op, complex64, complex64_t) // clang-format on instantiate_ternary_base(op, complex64, complex64_t) // clang-format on
instantiate_ternary_types(Select) instantiate_ternary_types(Select)

View File

@ -7,9 +7,15 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
out[index + i] = Op()(in[index + i]); out[index + i] = Op()(in[index + i]);
} }
} else {
for (int i = 0; i < N; ++i) {
out[index + i] = Op()(in[index + i]);
}
}
} }
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
@ -19,10 +25,16 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size, constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
out[offset + i] = Op()(in[offset + i]); out[offset + i] = Op()(in[offset + i]);
} }
} else {
for (int i = 0; i < N; ++i) {
out[offset + i] = Op()(in[offset + i]);
}
}
} }
template < template <

View File

@ -5,17 +5,27 @@
#include "mlx/backend/metal/kernels/unary_ops.h" #include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h" #include "mlx/backend/metal/kernels/unary.h"
// Instantiates the "vn_" contiguous unary kernel. No explicit N is passed, so
// unary_v uses its default template argument for the per-thread element count.
#define instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) \
  instantiate_kernel("vn_" #op #in_tname #out_tname, unary_v, in_type, out_type, op)
// Instantiates the unary kernels needed for every type:
//   v_        contiguous, one element per thread (N fixed to 1)
//   v2_       contiguous, unary_v2 variant
//   gn1_      unary_g with work-per-thread 1 and int indexing
//   gn4large_ unary_g with work-per-thread 4 and the default index type
#define instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \
  instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op, 1) \
  instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
  instantiate_kernel( \
      "gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
  instantiate_kernel( \
      "gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
// Instantiates every unary kernel variant for one (op, in type, out type)
// combination: the base set from instantiate_unary_base plus the "vn_" kernel
// from instantiate_unary_work_per_thread.
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \
instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type)
// Shorthand for instantiate_unary_all when input and output types are the same.
#define instantiate_unary_all_same(op, tname, type) \
  instantiate_unary_all(op, tname, tname, type, type)
// Shorthand for instantiate_unary_base when input and output types are the
// same; unlike instantiate_unary_all_same it does not emit the "vn_" kernel.
#define instantiate_unary_base_same(op, tname, type) \
instantiate_unary_base(op, tname, tname, type, type)
#define instantiate_unary_float(op) \ #define instantiate_unary_float(op) \
instantiate_unary_all_same(op, float16, half) \ instantiate_unary_all_same(op, float16, half) \
instantiate_unary_all_same(op, float32, float) \ instantiate_unary_all_same(op, float32, float) \
@ -25,11 +35,11 @@
instantiate_unary_all_same(op, uint8, uint8_t) \ instantiate_unary_all_same(op, uint8, uint8_t) \
instantiate_unary_all_same(op, uint16, uint16_t) \ instantiate_unary_all_same(op, uint16, uint16_t) \
instantiate_unary_all_same(op, uint32, uint32_t) \ instantiate_unary_all_same(op, uint32, uint32_t) \
instantiate_unary_all_same(op, uint64, uint64_t) \ instantiate_unary_base_same(op, uint64, uint64_t) \
instantiate_unary_all_same(op, int8, int8_t) \ instantiate_unary_all_same(op, int8, int8_t) \
instantiate_unary_all_same(op, int16, int16_t) \ instantiate_unary_all_same(op, int16, int16_t) \
instantiate_unary_all_same(op, int32, int32_t) \ instantiate_unary_all_same(op, int32, int32_t) \
instantiate_unary_all_same(op, int64, int64_t) instantiate_unary_base_same(op, int64, int64_t)
#define instantiate_unary_types(op) \ #define instantiate_unary_types(op) \
instantiate_unary_all_same(op, bool_, bool) \ instantiate_unary_all_same(op, bool_, bool) \
@ -68,29 +78,29 @@ instantiate_unary_float(Tanh)
instantiate_unary_float(Round)
instantiate_unary_int(BitwiseInvert)

// complex64 ops (including Real/Imag, which read complex64) only get the base
// kernels: no "vn_" variant is instantiated for them.
// NOTE(review): presumably because the host never selects a work-per-thread
// greater than 1 for 8-byte inputs -- confirm against get_work_per_thread.
instantiate_unary_base_same(Abs, complex64, complex64_t)
instantiate_unary_base_same(ArcCos, complex64, complex64_t)
instantiate_unary_base_same(ArcSin, complex64, complex64_t)
instantiate_unary_base_same(ArcTan, complex64, complex64_t)
instantiate_unary_base_same(Conjugate, complex64, complex64_t)
instantiate_unary_base_same(Cos, complex64, complex64_t)
instantiate_unary_base_same(Cosh, complex64, complex64_t)
instantiate_unary_base_same(Exp, complex64, complex64_t)
instantiate_unary_base_same(Log, complex64, complex64_t)
instantiate_unary_base_same(Log1p, complex64, complex64_t)
instantiate_unary_base_same(Log2, complex64, complex64_t)
instantiate_unary_base_same(Log10, complex64, complex64_t)
instantiate_unary_base_same(Negative, complex64, complex64_t)
instantiate_unary_base_same(Sign, complex64, complex64_t)
instantiate_unary_base_same(Sin, complex64, complex64_t)
instantiate_unary_base_same(Sinh, complex64, complex64_t)
instantiate_unary_base_same(Square, complex64, complex64_t)
instantiate_unary_base_same(Sqrt, complex64, complex64_t)
instantiate_unary_base_same(Rsqrt, complex64, complex64_t)
instantiate_unary_base_same(Tan, complex64, complex64_t)
instantiate_unary_base_same(Tanh, complex64, complex64_t)
instantiate_unary_base_same(Round, complex64, complex64_t)
instantiate_unary_base(Real, complex64, float32, complex64_t, float)
instantiate_unary_base(Imag, complex64, float32, complex64_t, float)

instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on

View File

@ -45,7 +45,7 @@ void ternary_op_gpu_inplace(
work_per_thread = large ? 4 : 2; work_per_thread = large ? 4 : 2;
} else { } else {
large = out.data_size() > INT32_MAX; large = out.data_size() > INT32_MAX;
work_per_thread = get_work_per_thread(b.dtype()); work_per_thread = get_work_per_thread(b.dtype(), out.data_size());
} }
std::string kernel_name; std::string kernel_name;
if (topt == TernaryOpType::General) { if (topt == TernaryOpType::General) {
@ -60,6 +60,8 @@ void ternary_op_gpu_inplace(
} }
} else if (large) { } else if (large) {
kernel_name = "v2"; kernel_name = "v2";
} else if (work_per_thread > 1) {
kernel_name = "vn";
} else { } else {
kernel_name = "v"; kernel_name = "v";
} }

View File

@ -43,8 +43,8 @@ void unary_op_gpu_inplace(
int work_per_thread; int work_per_thread;
std::string kernel_name; std::string kernel_name;
if (contig) { if (contig) {
work_per_thread = get_work_per_thread(in.dtype()); work_per_thread = get_work_per_thread(in.dtype(), in.data_size());
kernel_name = (large ? "v2" : "v"); kernel_name = (large ? "v2" : (work_per_thread > 1 ? "vn" : "v"));
} else { } else {
work_per_thread = large ? 4 : 1; work_per_thread = large ? 4 : 1;
kernel_name = "gn" + std::to_string(work_per_thread); kernel_name = "gn" + std::to_string(work_per_thread);

View File

@ -72,6 +72,10 @@ void concatenate(std::string& acc, T first, Args... args) {
// Returns how many elements each thread should process for `dtype`:
// 8 / dtype.size(), but never less than 1.
// NOTE(review): presumably dtype.size() is the element size in bytes, making
// this "elements per 8 bytes" -- confirm against Dtype.
inline int get_work_per_thread(Dtype dtype) {
  return std::max(1, 8 / dtype.size());
}
// Size-aware variant: arrays with fewer than 2^16 elements always use one
// element per thread; larger arrays use 8 / dtype.size(), but never less than 1.
inline int get_work_per_thread(Dtype dtype, size_t size) {
  constexpr size_t min_multi_element_size = 1 << 16;
  if (size < min_multi_element_size) {
    return 1;
  }
  return std::max(1, 8 / dtype.size());
}
inline size_t ceildiv(size_t n, size_t m) { inline size_t ceildiv(size_t n, size_t m) {
return (n + m - 1) / m; return (n + m - 1) / m;