mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-16 05:41:14 +08:00
improve metal elementwise kernels
This commit is contained in:
parent
a5ac9244c4
commit
d0ebd18d7d
@ -31,13 +31,13 @@ std::string get_kernel_name(
|
||||
kname = "ss";
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname = (large ? "sv2" : "sv");
|
||||
kname = "sv";
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname = (large ? "vs2" : "vs");
|
||||
kname = "vs";
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
kname = (large ? "vv2" : "vv");
|
||||
kname = "vv";
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
kname = "g";
|
||||
@ -51,6 +51,13 @@ std::string get_kernel_name(
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (bopt != BinaryOpType::General && bopt != BinaryOpType::ScalarScalar) {
|
||||
if (large) {
|
||||
kname += "2";
|
||||
} else if (work_per_thread > 1) {
|
||||
kname += "n";
|
||||
}
|
||||
}
|
||||
concatenate(kname, "_", op, type_to_name(a));
|
||||
return kname;
|
||||
}
|
||||
@ -90,7 +97,7 @@ void binary_op_gpu_inplace(
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
large = out.data_size() > UINT32_MAX;
|
||||
work_per_thread = get_work_per_thread(a.dtype());
|
||||
work_per_thread = get_work_per_thread(a.dtype(), out.data_size());
|
||||
}
|
||||
std::string kernel_name =
|
||||
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
|
||||
|
@ -41,7 +41,9 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::unary_ops(), metal::unary());
|
||||
kernel_source +=
|
||||
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
|
||||
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1);
|
||||
kernel_source +=
|
||||
get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op);
|
||||
kernel_source +=
|
||||
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
|
||||
kernel_source += get_template_definition(
|
||||
@ -61,9 +63,9 @@ void append_binary_kernels(
|
||||
std::string& kernel_source) {
|
||||
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
|
||||
{"ss", "binary_ss"},
|
||||
{"vs", "binary_vs"},
|
||||
{"sv", "binary_sv"},
|
||||
{"vv", "binary_vv"},
|
||||
{"vsn", "binary_vs"},
|
||||
{"svn", "binary_sv"},
|
||||
{"vvn", "binary_vv"},
|
||||
{"vs2", "binary_vs2"},
|
||||
{"sv2", "binary_sv2"},
|
||||
{"vv2", "binary_vv2"},
|
||||
@ -78,6 +80,12 @@ void append_binary_kernels(
|
||||
kernel_source +=
|
||||
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
|
||||
}
|
||||
// Base (one element per thread, N = 1) variants of the contiguous kernels.
// get_kernel_name() yields the bare "vs" / "sv" / "vv" prefix when neither the
// "2" (large) nor the "n" (work_per_thread > 1) suffix applies, so each of the
// three prefixes needs exactly one definition here. Previously "sv_" was
// emitted twice and "vv_" was never emitted, leaving the JIT library without
// the vector-vector base kernel.
kernel_source += get_template_definition(
    "vs_" + lib_name, "binary_vs", in_t, out_t, op, 1);
kernel_source += get_template_definition(
    "sv_" + lib_name, "binary_sv", in_t, out_t, op, 1);
kernel_source += get_template_definition(
    "vv_" + lib_name, "binary_vv", in_t, out_t, op, 1);
|
||||
kernel_source += get_template_definition(
|
||||
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
@ -134,7 +142,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
|
||||
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
|
||||
{"v", "ternary_v"},
|
||||
{"vn", "ternary_v"},
|
||||
{"v2", "ternary_v2"},
|
||||
{"g1large", "ternary_g_nd1"},
|
||||
{"g2large", "ternary_g_nd2"},
|
||||
@ -144,6 +152,8 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
kernel_source +=
|
||||
get_template_definition(name + "_" + lib_name, func, t_str, op);
|
||||
}
|
||||
kernel_source +=
|
||||
get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
|
||||
kernel_source += get_template_definition(
|
||||
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
|
@ -17,8 +17,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[0], b[index + i]);
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
c[index + i] = Op()(a[0], b[index + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[index + i] = Op()(a[0], b[index + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -30,8 +36,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[0]);
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[0]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -43,8 +55,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[index + i]);
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[index + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[index + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -57,8 +75,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[0], b[offset + i]);
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
c[offset + i] = Op()(a[0], b[offset + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[offset + i] = Op()(a[0], b[offset + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -71,8 +95,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[0]);
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[0]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -85,8 +115,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[offset + i]);
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[offset + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[offset + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -9,11 +9,16 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
#define instantiate_binary_work_per_thread(op, tname, itype, otype) \
|
||||
instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) \
|
||||
|
||||
#define instantiate_binary_base(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
@ -26,15 +31,19 @@
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_integer(op) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
|
||||
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
|
||||
instantiate_binary_all(op, int8, int8_t, int8_t) \
|
||||
instantiate_binary_all(op, int16, int16_t, int16_t) \
|
||||
instantiate_binary_all(op, int32, int32_t, int32_t) \
|
||||
instantiate_binary_all(op, int64, int64_t, int64_t)
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_binary_base(op, tname, itype, otype) \
|
||||
instantiate_binary_work_per_thread(op, tname, itype, otype)
|
||||
|
||||
#define instantiate_binary_integer(op) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
|
||||
instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
|
||||
instantiate_binary_all(op, int8, int8_t, int8_t) \
|
||||
instantiate_binary_all(op, int16, int16_t, int16_t) \
|
||||
instantiate_binary_all(op, int32, int32_t, int32_t) \
|
||||
instantiate_binary_base(op, int64, int64_t, int64_t)
|
||||
|
||||
#define instantiate_binary_float(op) \
|
||||
instantiate_binary_all(op, float16, half, half) \
|
||||
@ -44,7 +53,7 @@
|
||||
#define instantiate_binary_types(op) \
|
||||
instantiate_binary_all(op, bool_, bool, bool) \
|
||||
instantiate_binary_integer(op) \
|
||||
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
|
||||
instantiate_binary_base(op, complex64, complex64_t, complex64_t)\
|
||||
instantiate_binary_float(op)
|
||||
|
||||
#define instantiate_binary_types_bool(op) \
|
||||
@ -52,15 +61,15 @@
|
||||
instantiate_binary_all(op, uint8, uint8_t, bool) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, bool) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, bool) \
|
||||
instantiate_binary_all(op, uint64, uint64_t, bool) \
|
||||
instantiate_binary_base(op, uint64, uint64_t, bool) \
|
||||
instantiate_binary_all(op, int8, int8_t, bool) \
|
||||
instantiate_binary_all(op, int16, int16_t, bool) \
|
||||
instantiate_binary_all(op, int32, int32_t, bool) \
|
||||
instantiate_binary_all(op, int64, int64_t, bool) \
|
||||
instantiate_binary_base(op, int64, int64_t, bool) \
|
||||
instantiate_binary_all(op, float16, half, bool) \
|
||||
instantiate_binary_all(op, float32, float, bool) \
|
||||
instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \
|
||||
instantiate_binary_all(op, complex64, complex64_t, bool)
|
||||
instantiate_binary_base(op, complex64, complex64_t, bool)
|
||||
|
||||
instantiate_binary_types(Add)
|
||||
instantiate_binary_types(Divide)
|
||||
@ -71,7 +80,7 @@ instantiate_binary_types_bool(Less)
|
||||
instantiate_binary_types_bool(LessEqual)
|
||||
instantiate_binary_types_bool(NotEqual)
|
||||
instantiate_binary_float(LogAddExp)
|
||||
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
|
||||
instantiate_binary_base(LogAddExp, complex64, complex64_t, complex64_t)
|
||||
instantiate_binary_types(Maximum)
|
||||
instantiate_binary_types(Minimum)
|
||||
instantiate_binary_types(Multiply)
|
||||
@ -84,7 +93,7 @@ instantiate_binary_float(ArcTan2)
|
||||
instantiate_binary_all(NaNEqual, float16, half, bool)
|
||||
instantiate_binary_all(NaNEqual, float32, float, bool)
|
||||
instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool)
|
||||
instantiate_binary_all(NaNEqual, complex64, complex64_t, bool)
|
||||
instantiate_binary_base(NaNEqual, complex64, complex64_t, bool)
|
||||
|
||||
instantiate_binary_all(LogicalOr, bool_, bool, bool)
|
||||
instantiate_binary_all(LogicalAnd, bool_, bool, bool)
|
||||
|
@ -21,10 +21,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
auto out = Op()(a[0], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
auto out = Op()(a[0], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
auto out = Op()(a[0], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -37,10 +45,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
auto out = Op()(a[index + i], b[0]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
auto out = Op()(a[index + i], b[0]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
auto out = Op()(a[index + i], b[0]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -53,10 +69,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
auto out = Op()(a[index + i], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
auto out = Op()(a[index + i], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
auto out = Op()(a[index + i], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -69,11 +93,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
auto out = Op()(a[0], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
auto out = Op()(a[0], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
auto out = Op()(a[0], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -86,11 +118,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
auto out = Op()(a[offset + i], b[0]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
auto out = Op()(a[offset + i], b[0]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
auto out = Op()(a[offset + i], b[0]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -103,11 +143,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
auto out = Op()(a[offset + i], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
auto out = Op()(a[offset + i], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
auto out = Op()(a[offset + i], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,11 +7,16 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary_two.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
#define instantiate_binary_work_per_thread(op, tname, itype, otype) \
|
||||
instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_base(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
@ -24,22 +29,26 @@
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_binary_base(op, tname, itype, otype) \
|
||||
instantiate_binary_work_per_thread(op, tname, itype, otype)
|
||||
|
||||
#define instantiate_binary_float(op) \
|
||||
instantiate_binary_all(op, float16, half, half) \
|
||||
instantiate_binary_all(op, float32, float, float) \
|
||||
instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)
|
||||
|
||||
#define instantiate_binary_types(op) \
|
||||
instantiate_binary_all(op, bool_, bool, bool) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
|
||||
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
|
||||
instantiate_binary_all(op, int8, int8_t, int8_t) \
|
||||
instantiate_binary_all(op, int16, int16_t, int16_t) \
|
||||
instantiate_binary_all(op, int32, int32_t, int32_t) \
|
||||
instantiate_binary_all(op, int64, int64_t, int64_t) \
|
||||
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
|
||||
#define instantiate_binary_types(op) \
|
||||
instantiate_binary_all(op, bool_, bool, bool) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
|
||||
instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
|
||||
instantiate_binary_all(op, int8, int8_t, int8_t) \
|
||||
instantiate_binary_all(op, int16, int16_t, int16_t) \
|
||||
instantiate_binary_all(op, int32, int32_t, int32_t) \
|
||||
instantiate_binary_base(op, int64, int64_t, int64_t) \
|
||||
instantiate_binary_base(op, complex64, complex64_t, complex64_t) \
|
||||
instantiate_binary_float(op)
|
||||
|
||||
instantiate_binary_types(DivMod) // clang-format on
|
||||
|
@ -9,8 +9,14 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -23,9 +29,15 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -8,8 +8,8 @@
|
||||
#include "mlx/backend/metal/kernels/ternary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/ternary.h"
|
||||
|
||||
#define instantiate_ternary_all(op, tname, type) \
|
||||
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
||||
#define instantiate_ternary_base(op, tname, type) \
|
||||
instantiate_kernel("v_" #op #tname, ternary_v, type, op, 1) \
|
||||
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \
|
||||
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
|
||||
@ -20,19 +20,23 @@
|
||||
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
|
||||
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
|
||||
|
||||
#define instantiate_ternary_all(op, tname, type) \
|
||||
instantiate_kernel("vn_" #op #tname, ternary_v, type, op) \
|
||||
instantiate_ternary_base(op, tname, type)
|
||||
|
||||
#define instantiate_ternary_types(op) \
|
||||
instantiate_ternary_all(op, bool_, bool) \
|
||||
instantiate_ternary_all(op, uint8, uint8_t) \
|
||||
instantiate_ternary_all(op, uint16, uint16_t) \
|
||||
instantiate_ternary_all(op, uint32, uint32_t) \
|
||||
instantiate_ternary_all(op, uint64, uint64_t) \
|
||||
instantiate_ternary_base(op, uint64, uint64_t) \
|
||||
instantiate_ternary_all(op, int8, int8_t) \
|
||||
instantiate_ternary_all(op, int16, int16_t) \
|
||||
instantiate_ternary_all(op, int32, int32_t) \
|
||||
instantiate_ternary_all(op, int64, int64_t) \
|
||||
instantiate_ternary_base(op, int64, int64_t) \
|
||||
instantiate_ternary_all(op, float16, half) \
|
||||
instantiate_ternary_all(op, float32, float) \
|
||||
instantiate_ternary_all(op, bfloat16, bfloat16_t) \
|
||||
instantiate_ternary_all(op, complex64, complex64_t) // clang-format on
|
||||
instantiate_ternary_base(op, complex64, complex64_t) // clang-format on
|
||||
|
||||
instantiate_ternary_types(Select)
|
||||
|
@ -7,8 +7,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
out[index + i] = Op()(in[index + i]);
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
out[index + i] = Op()(in[index + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
out[index + i] = Op()(in[index + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,9 +25,15 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
out[offset + i] = Op()(in[offset + i]);
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
out[offset + i] = Op()(in[offset + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
out[offset + i] = Op()(in[offset + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -5,31 +5,41 @@
|
||||
#include "mlx/backend/metal/kernels/unary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/unary.h"
|
||||
|
||||
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
|
||||
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
|
||||
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
|
||||
instantiate_kernel( \
|
||||
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
|
||||
instantiate_kernel( \
|
||||
#define instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) \
|
||||
instantiate_kernel("vn_" #op #in_tname #out_tname, unary_v, in_type, out_type, op)
|
||||
|
||||
#define instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \
|
||||
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op, 1) \
|
||||
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
|
||||
instantiate_kernel( \
|
||||
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
|
||||
instantiate_kernel( \
|
||||
"gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
|
||||
|
||||
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
|
||||
instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \
|
||||
instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type)
|
||||
|
||||
#define instantiate_unary_all_same(op, tname, type) \
|
||||
instantiate_unary_all(op, tname, tname, type, type)
|
||||
|
||||
#define instantiate_unary_base_same(op, tname, type) \
|
||||
instantiate_unary_base(op, tname, tname, type, type)
|
||||
|
||||
#define instantiate_unary_float(op) \
|
||||
instantiate_unary_all_same(op, float16, half) \
|
||||
instantiate_unary_all_same(op, float32, float) \
|
||||
instantiate_unary_all_same(op, bfloat16, bfloat16_t)
|
||||
|
||||
#define instantiate_unary_int(op) \
|
||||
instantiate_unary_all_same(op, uint8, uint8_t) \
|
||||
instantiate_unary_all_same(op, uint16, uint16_t) \
|
||||
instantiate_unary_all_same(op, uint32, uint32_t) \
|
||||
instantiate_unary_all_same(op, uint64, uint64_t) \
|
||||
instantiate_unary_all_same(op, int8, int8_t) \
|
||||
instantiate_unary_all_same(op, int16, int16_t) \
|
||||
instantiate_unary_all_same(op, int32, int32_t) \
|
||||
instantiate_unary_all_same(op, int64, int64_t)
|
||||
#define instantiate_unary_int(op) \
|
||||
instantiate_unary_all_same(op, uint8, uint8_t) \
|
||||
instantiate_unary_all_same(op, uint16, uint16_t) \
|
||||
instantiate_unary_all_same(op, uint32, uint32_t) \
|
||||
instantiate_unary_base_same(op, uint64, uint64_t) \
|
||||
instantiate_unary_all_same(op, int8, int8_t) \
|
||||
instantiate_unary_all_same(op, int16, int16_t) \
|
||||
instantiate_unary_all_same(op, int32, int32_t) \
|
||||
instantiate_unary_base_same(op, int64, int64_t)
|
||||
|
||||
#define instantiate_unary_types(op) \
|
||||
instantiate_unary_all_same(op, bool_, bool) \
|
||||
@ -68,29 +78,29 @@ instantiate_unary_float(Tanh)
|
||||
instantiate_unary_float(Round)
|
||||
instantiate_unary_int(BitwiseInvert)
|
||||
|
||||
instantiate_unary_all_same(Abs, complex64, complex64_t)
|
||||
instantiate_unary_all_same(ArcCos, complex64, complex64_t)
|
||||
instantiate_unary_all_same(ArcSin, complex64, complex64_t)
|
||||
instantiate_unary_all_same(ArcTan, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Cos, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Cosh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Exp, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log1p, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log2, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log10, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Negative, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sign, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sin, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sinh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Square, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sqrt, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Rsqrt, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Tan, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Tanh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Round, complex64, complex64_t)
|
||||
instantiate_unary_all(Real, complex64, float32, complex64_t, float)
|
||||
instantiate_unary_all(Imag, complex64, float32, complex64_t, float)
|
||||
instantiate_unary_base_same(Abs, complex64, complex64_t)
|
||||
instantiate_unary_base_same(ArcCos, complex64, complex64_t)
|
||||
instantiate_unary_base_same(ArcSin, complex64, complex64_t)
|
||||
instantiate_unary_base_same(ArcTan, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Conjugate, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Cos, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Cosh, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Exp, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Log, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Log1p, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Log2, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Log10, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Negative, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Sign, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Sin, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Sinh, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Square, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Sqrt, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Rsqrt, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Tan, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Tanh, complex64, complex64_t)
|
||||
instantiate_unary_base_same(Round, complex64, complex64_t)
|
||||
instantiate_unary_base(Real, complex64, float32, complex64_t, float)
|
||||
instantiate_unary_base(Imag, complex64, float32, complex64_t, float)
|
||||
|
||||
instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on
|
||||
|
@ -45,7 +45,7 @@ void ternary_op_gpu_inplace(
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
large = out.data_size() > INT32_MAX;
|
||||
work_per_thread = get_work_per_thread(b.dtype());
|
||||
work_per_thread = get_work_per_thread(b.dtype(), out.data_size());
|
||||
}
|
||||
std::string kernel_name;
|
||||
if (topt == TernaryOpType::General) {
|
||||
@ -60,6 +60,8 @@ void ternary_op_gpu_inplace(
|
||||
}
|
||||
} else if (large) {
|
||||
kernel_name = "v2";
|
||||
} else if (work_per_thread > 1) {
|
||||
kernel_name = "vn";
|
||||
} else {
|
||||
kernel_name = "v";
|
||||
}
|
||||
|
@ -43,8 +43,8 @@ void unary_op_gpu_inplace(
|
||||
int work_per_thread;
|
||||
std::string kernel_name;
|
||||
if (contig) {
|
||||
work_per_thread = get_work_per_thread(in.dtype());
|
||||
kernel_name = (large ? "v2" : "v");
|
||||
work_per_thread = get_work_per_thread(in.dtype(), in.data_size());
|
||||
kernel_name = (large ? "v2" : (work_per_thread > 1 ? "vn" : "v"));
|
||||
} else {
|
||||
work_per_thread = large ? 4 : 1;
|
||||
kernel_name = "gn" + std::to_string(work_per_thread);
|
||||
|
@ -72,6 +72,10 @@ void concatenate(std::string& acc, T first, Args... args) {
|
||||
inline int get_work_per_thread(Dtype dtype) {
|
||||
return std::max(1, 8 / dtype.size());
|
||||
}
|
||||
inline int get_work_per_thread(Dtype dtype, size_t size) {
|
||||
constexpr size_t wpt_threshold = 1 << 16;
|
||||
return size < wpt_threshold ? 1 : std::max(1, 8 / dtype.size());
|
||||
}
|
||||
|
||||
inline size_t ceildiv(size_t n, size_t m) {
|
||||
return (n + m - 1) / m;
|
||||
|
Loading…
Reference in New Issue
Block a user