mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-06 02:38:18 +08:00
Faster triu, tril, where with scalar (#2644)
@@ -11,6 +11,8 @@ namespace mlx::core {
 enum class TernaryOpType {
   ScalarScalarScalar,
   VectorVectorVector,
+  VectorVectorScalar,
+  VectorScalarVector,
   General,
 };
 
@@ -25,6 +27,14 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
       (a.flags().col_contiguous && b.flags().col_contiguous &&
        c.flags().col_contiguous)) {
     topt = TernaryOpType::VectorVectorVector;
+  } else if (
+      b.data_size() == 1 && a.flags().row_contiguous &&
+      c.flags().row_contiguous) {
+    topt = TernaryOpType::VectorScalarVector;
+  } else if (
+      c.data_size() == 1 && a.flags().row_contiguous &&
+      b.flags().row_contiguous) {
+    topt = TernaryOpType::VectorVectorScalar;
   } else {
     topt = TernaryOpType::General;
   }
@@ -59,6 +69,8 @@ inline void set_ternary_op_output_data(
           b.flags());
     }
     break;
+  case TernaryOpType::VectorVectorScalar:
+  case TernaryOpType::VectorScalarVector:
   case TernaryOpType::General:
     // Try to donate an input which is row_contiguous
     if (!((a.flags().row_contiguous && maybe_donate(a)) ||
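
Why the two new types exist: a call like `where(cond, x, scalar)` used to fall through to `General`, because a scalar broadcast to the output shape is not row-contiguous, forcing the fully strided kernel. Classifying it as `VectorVectorScalar`/`VectorScalarVector` keeps a contiguous fast path and, per the `set_ternary_op_output_data` hunk, still lets a row-contiguous input donate its buffer to the output. A standalone sketch of the classification order, using a hypothetical `Arr` stand-in for `mlx::core::array` and ignoring the column-contiguous case:

```cpp
#include <cassert>
#include <cstddef>

enum class TernaryOpType {
  ScalarScalarScalar,
  VectorVectorVector,
  VectorVectorScalar,
  VectorScalarVector,
  General,
};

struct Arr {            // hypothetical stand-in for mlx::core::array
  size_t data_size;     // stored elements; 1 for a (broadcast) scalar
  bool row_contiguous;  // false for a scalar broadcast to a larger shape
};

TernaryOpType classify(const Arr& a, const Arr& b, const Arr& c) {
  if (a.data_size == 1 && b.data_size == 1 && c.data_size == 1) {
    return TernaryOpType::ScalarScalarScalar;
  } else if (a.row_contiguous && b.row_contiguous && c.row_contiguous) {
    return TernaryOpType::VectorVectorVector;
  } else if (b.data_size == 1 && a.row_contiguous && c.row_contiguous) {
    return TernaryOpType::VectorScalarVector;
  } else if (c.data_size == 1 && a.row_contiguous && b.row_contiguous) {
    return TernaryOpType::VectorVectorScalar;
  }
  return TernaryOpType::General;
}

int main() {
  Arr cond{1024, true}, x{1024, true}, zero{1, false};
  // where(cond, x, 0): previously General, now a vector-scalar fast path.
  assert(classify(cond, x, zero) == TernaryOpType::VectorVectorScalar);
  assert(classify(cond, zero, x) == TernaryOpType::VectorScalarVector);
}
```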
@@ -156,7 +156,25 @@ void ternary_op_gpu_inplace(
   using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
 
   auto topt = get_ternary_op_type(a, b, c);
-  if (topt == TernaryOpType::General) {
+  if (topt == TernaryOpType::VectorVectorVector ||
+      topt == TernaryOpType::ScalarScalarScalar) {
+    dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
+      using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
+      constexpr int N_READS = 16 / sizeof(DType);
+      auto [num_blocks, block_dims] = get_launch_args(
+          out.data_size(), out.shape(), out.strides(), large(), N_READS);
+      encoder.add_kernel_node(
+          cu::ternary_v<Op, DType, IdxT, N_READS>,
+          num_blocks,
+          block_dims,
+          0,
+          a.data<bool>(),
+          b.data<DType>(),
+          c.data<DType>(),
+          out.data<DType>(),
+          out.data_size());
+    });
+  } else {
     dispatch_bool(
         a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
             c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
@@ -225,23 +243,6 @@ void ternary_op_gpu_inplace(
               ndim);
         }
       });
-  } else {
-    dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
-      using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-      constexpr int N_READS = 16 / sizeof(DType);
-      auto [num_blocks, block_dims] = get_launch_args(
-          out.data_size(), out.shape(), out.strides(), large(), N_READS);
-      encoder.add_kernel_node(
-          cu::ternary_v<Op, DType, IdxT, N_READS>,
-          num_blocks,
-          block_dims,
-          0,
-          a.data<bool>(),
-          b.data<DType>(),
-          c.data<DType>(),
-          out.data<DType>(),
-          out.data_size());
-    });
   }
   });
 }
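
Note the flipped structure on the CUDA side: the contiguous `ternary_v` dispatch is unchanged, just moved up from the old `else` branch deleted below, and now runs only for `VectorVectorVector`/`ScalarScalarScalar`. Everything else, including the two new vector-scalar types, takes the strided general path, which already handles a broadcast scalar via zero strides. The `dispatch_bool` idiom both branches rely on lifts a runtime flag into a compile-time constant; a minimal re-implementation of the idea (MLX's own helper may differ in detail):

```cpp
#include <cstdint>
#include <type_traits>

// Lift a runtime bool into a compile-time constant so the callback can
// make type-level decisions.
template <typename F>
void dispatch_bool(bool cond, F&& f) {
  if (cond) {
    f(std::integral_constant<bool, true>{});
  } else {
    f(std::integral_constant<bool, false>{});
  }
}

// Mirrors the hunk: pick 32-bit indexing when the output fits, 64-bit
// otherwise, so the common case avoids 64-bit index arithmetic on device.
void launch(uint64_t data_size) {
  dispatch_bool(data_size > UINT32_MAX, [&](auto large) {
    using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
    static_assert(sizeof(IdxT) >= 4);  // kernel instantiation goes here
  });
}
```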
@@ -144,8 +144,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
   auto t_str = get_type_string(type);
   std::string kernel_source = metal::utils();
   concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
-  const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{
-      {"v2", "ternary_v2"},
+  const std::array<std::pair<std::string, std::string>, 3> kernel_types = {{
       {"g1large", "ternary_g_nd1"},
       {"g2large", "ternary_g_nd2"},
       {"g3large", "ternary_g_nd3"},
@@ -154,13 +153,29 @@ MTL::ComputePipelineState* get_ternary_kernel(
     kernel_source +=
         get_template_definition(name + "_" + lib_name, func, t_str, op);
   }
+
+  kernel_source += get_template_definition(
+      "v2_" + lib_name, "ternary_v2", t_str, op, false, false);
+  kernel_source += get_template_definition(
+      "sv2_" + lib_name, "ternary_v2", t_str, op, true, false);
+  kernel_source += get_template_definition(
+      "vs2_" + lib_name, "ternary_v2", t_str, op, false, true);
+
   if (get_work_per_thread(type) > 1) {
-    kernel_source +=
-        get_template_definition("vn_" + lib_name, "ternary_v", t_str, op);
+    kernel_source += get_template_definition(
+        "vn_" + lib_name, "ternary_v", t_str, op, false, false);
+    kernel_source += get_template_definition(
+        "svn_" + lib_name, "ternary_v", t_str, op, true, false);
+    kernel_source += get_template_definition(
+        "vsn_" + lib_name, "ternary_v", t_str, op, false, true);
   }
 
-  kernel_source +=
-      get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
+  kernel_source += get_template_definition(
+      "v_" + lib_name, "ternary_v", t_str, op, false, false, 1);
+  kernel_source += get_template_definition(
+      "sv_" + lib_name, "ternary_v", t_str, op, true, false, 1);
+  kernel_source += get_template_definition(
+      "vs_" + lib_name, "ternary_v", t_str, op, false, true, 1);
   kernel_source += get_template_definition(
       "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
   kernel_source += get_template_definition(
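
With `ternary_v`/`ternary_v2` now taking `BSCALAR`/`CSCALAR` template flags, the JIT path can no longer instantiate them from the generic `kernel_types` loop, so the `v2` entry leaves the table and each variant is emitted explicitly: three `ternary_v2` names (`v2`, `sv2`, `vs2`), three work-per-thread names (`vn`, `svn`, `vsn`) when `get_work_per_thread(type) > 1`, and three `N = 1` names (`v`, `sv`, `vs`). Roughly, each call emits an explicit instantiation tagged with the Metal host name the dispatcher will look up; a simplified, hypothetical rendering of what one definition expands to (the real variadic helper lives in MLX's JIT utilities and may format differently):

```cpp
#include <string>

std::string template_def_sketch(
    const std::string& name,   // e.g. "sv2_" + lib_name
    const std::string& func,   // e.g. "ternary_v2"
    const std::string& t_str,  // e.g. "float"
    const std::string& op,     // e.g. the ternary op's struct name
    bool bscalar,
    bool cscalar) {
  std::string args = t_str + ", " + op + ", " +
      (bscalar ? "true" : "false") + ", " + (cscalar ? "true" : "false");
  return "template [[host_name(\"" + name + "\")]] [[kernel]] decltype(" +
      func + "<" + args + ">) " + func + "<" + args + ">;\n";
}
```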
@@ -1,6 +1,11 @@
 // Copyright © 2024 Apple Inc.
 
-template <typename T, typename Op, int N = WorkPerThread<T>::n>
+template <
+    typename T,
+    typename Op,
+    bool BSCALAR,
+    bool CSCALAR,
+    int N = WorkPerThread<T>::n>
 [[kernel]] void ternary_v(
     device const bool* a,
     device const T* b,
@@ -11,16 +16,25 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
   index *= N;
   if (N > 1 && index + N > size) {
     for (int i = 0; index + i < size; ++i) {
-      d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
+      auto bidx = BSCALAR ? 0 : index + i;
+      auto cidx = CSCALAR ? 0 : index + i;
+      d[index + i] = Op()(a[index + i], b[bidx], c[cidx]);
     }
   } else {
     for (int i = 0; i < N; ++i) {
-      d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
+      auto bidx = BSCALAR ? 0 : index + i;
+      auto cidx = CSCALAR ? 0 : index + i;
+      d[index + i] = Op()(a[index + i], b[bidx], c[cidx]);
     }
   }
 }
 
-template <typename T, typename Op, int N = WorkPerThread<T>::n>
+template <
+    typename T,
+    typename Op,
+    bool BSCALAR,
+    bool CSCALAR,
+    int N = WorkPerThread<T>::n>
 [[kernel]] void ternary_v2(
     device const bool* a,
     device const T* b,
@@ -32,11 +46,15 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
   int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
   if (N > 1 && offset + N > size) {
     for (int i = 0; offset + i < size; ++i) {
-      d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
+      auto bidx = BSCALAR ? 0 : offset + i;
+      auto cidx = CSCALAR ? 0 : offset + i;
+      d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]);
     }
   } else {
     for (int i = 0; i < N; ++i) {
-      d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
+      auto bidx = BSCALAR ? 0 : offset + i;
+      auto cidx = CSCALAR ? 0 : offset + i;
+      d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]);
     }
   }
 }
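
The kernel change itself is small: when `BSCALAR` or `CSCALAR` is set, the read index for that operand is pinned to 0. Since the flags are compile-time constants, the ternaries fold away and the scalar load hoists out of the loop, while the non-scalar instantiations compile to exactly the code they replaced. A plain C++ analogue, with the `device` address-space qualifiers and thread indexing dropped and `Select` playing the role of the `where` op:

```cpp
#include <cstddef>

struct Select {  // pick b where the condition holds, else c
  template <typename T>
  T operator()(bool a, T b, T c) const {
    return a ? b : c;
  }
};

template <typename T, typename Op, bool BSCALAR, bool CSCALAR>
void ternary_v_ref(const bool* a, const T* b, const T* c, T* d, size_t size) {
  for (size_t i = 0; i < size; ++i) {
    size_t bidx = BSCALAR ? 0 : i;  // folds to 0 when b is a scalar
    size_t cidx = CSCALAR ? 0 : i;  // folds to 0 when c is a scalar
    d[i] = Op()(a[i], b[bidx], c[cidx]);
  }
}

// where(mask, x, 0) with a scalar third operand:
//   float zero = 0.0f;
//   ternary_v_ref<float, Select, false, true>(mask, x, &zero, out, n);
```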
@@ -9,8 +9,12 @@
 #include "mlx/backend/metal/kernels/ternary.h"
 
 #define instantiate_ternary_base(op, tname, type)                            \
-  instantiate_kernel("v_" #op #tname, ternary_v, type, op, 1)                \
-  instantiate_kernel("v2_" #op #tname, ternary_v2, type, op)                 \
+  instantiate_kernel("v_" #op #tname, ternary_v, type, op, false, false, 1)  \
+  instantiate_kernel("v2_" #op #tname, ternary_v2, type, op, false, false)   \
+  instantiate_kernel("vs_" #op #tname, ternary_v, type, op, false, true, 1)  \
+  instantiate_kernel("vs2_" #op #tname, ternary_v2, type, op, false, true)   \
+  instantiate_kernel("sv_" #op #tname, ternary_v, type, op, true, false, 1)  \
+  instantiate_kernel("sv2_" #op #tname, ternary_v2, type, op, true, false)   \
   instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int)         \
   instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int)         \
   instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int)         \
@@ -21,7 +25,9 @@
   instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4)         \
 
 #define instantiate_ternary_all(op, tname, type)                             \
-  instantiate_kernel("vn_" #op #tname, ternary_v, type, op)                  \
+  instantiate_kernel("vn_" #op #tname, ternary_v, type, op, false, false)    \
+  instantiate_kernel("vsn_" #op #tname, ternary_v, type, op, false, true)    \
+  instantiate_kernel("svn_" #op #tname, ternary_v, type, op, true, false)    \
   instantiate_ternary_base(op, tname, type)
 
 #define instantiate_ternary_types(op) \
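
Together with the unchanged `g*` entries, each op/type pair now gets nine contiguous entry points, on a simple naming grid: `v*` reads both value operands per element (`false, false`), `sv*` pins `b` to a scalar (`true, false`), and `vs*` pins `c` (`false, true`). The suffix picks the indexing variant: no suffix for the single-element `ternary_v` (`N = 1`), `n` for the work-per-thread `ternary_v`, and `2` for the large-array `ternary_v2`. The host-side selection in the next hunk reconstructs these names.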
@@ -58,12 +58,19 @@ void ternary_op_gpu_inplace(
     if (large) {
       kernel_name += "large";
     }
-  } else if (large) {
-    kernel_name = "v2";
-  } else if (work_per_thread > 1) {
-    kernel_name = "vn";
   } else {
-    kernel_name = "v";
+    if (topt == TernaryOpType::VectorScalarVector) {
+      kernel_name = "sv";
+    } else if (topt == TernaryOpType::VectorVectorScalar) {
+      kernel_name = "vs";
+    } else {
+      kernel_name = "v";
+    }
+    if (large) {
+      kernel_name += "2";
+    } else if (work_per_thread > 1) {
+      kernel_name += "n";
+    }
   }
   concatenate(kernel_name, "_", op, type_to_name(b));
 
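
Name selection composes the same grid: base from the operand layout, suffix from the indexing variant. A self-contained sketch of just this logic, with the enum trimmed to the relevant members:

```cpp
#include <string>

enum class TernaryOpType { VectorVectorScalar, VectorScalarVector, Other };

std::string pick_kernel_name(
    TernaryOpType topt, bool large, int work_per_thread) {
  std::string name;
  if (topt == TernaryOpType::VectorScalarVector) {
    name = "sv";  // b is the scalar
  } else if (topt == TernaryOpType::VectorVectorScalar) {
    name = "vs";  // c is the scalar
  } else {
    name = "v";
  }
  if (large) {
    name += "2";  // 64-bit ternary_v2 grid
  } else if (work_per_thread > 1) {
    name += "n";  // multi-element-per-thread ternary_v
  }
  return name;  // e.g. "vs2" for where(cond, x, scalar) on a large array
}
```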
@@ -348,16 +348,16 @@ array tril(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {
   if (x.ndim() < 2) {
     throw std::invalid_argument("[tril] array must be at least 2-D");
   }
-  auto mask = tri(x.shape(-2), x.shape(-1), k, x.dtype(), s);
-  return where(mask, x, zeros_like(x, s), s);
+  auto mask = tri(x.shape(-2), x.shape(-1), k, bool_, s);
+  return where(mask, x, array(0, x.dtype()), s);
 }
 
 array triu(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {
   if (x.ndim() < 2) {
     throw std::invalid_argument("[triu] array must be at least 2-D");
   }
-  auto mask = tri(x.shape(-2), x.shape(-1), k - 1, x.dtype(), s);
-  return where(mask, zeros_like(x, s), x, s);
+  auto mask = tri(x.shape(-2), x.shape(-1), k - 1, bool_, s);
+  return where(mask, array(0, x.dtype()), x, s);
 }
 
 array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
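
This is where the speedup for `triu`/`tril` comes from: the mask is built as `bool_` instead of being cast to `x.dtype()`, and the zero branch is a 0-d scalar rather than a fully materialized `zeros_like(x, s)`, so the `where` call classifies as one of the new vector-scalar types and a dense temporary allocation disappears. Illustrative usage against the MLX C++ API, assuming the usual `mlx/mlx.h` umbrella header:

```cpp
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto x = ones({1024, 1024});  // float32 by default
  // Both now lower to where(bool_mask, x, array(0, x.dtype())),
  // i.e. a VectorVectorScalar / VectorScalarVector ternary.
  auto lower = tril(x);
  auto upper = triu(x, 1);
  eval({lower, upper});
}
```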