mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-05 16:51:13 +08:00
Faster Metal unary and binary for general case (#1431)
* faster unary and binary for general case
* update ternary + jit fix
* fix jit
* unary work per thread
This commit is contained in:
parent
afc9c0ec1b
commit
4f9f9ebb6f
@ -19,14 +19,13 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 3;
|
|
||||||
|
|
||||||
std::string get_kernel_name(
|
std::string get_kernel_name(
|
||||||
BinaryOpType bopt,
|
BinaryOpType bopt,
|
||||||
const std::string& op,
|
const std::string& op,
|
||||||
const array& a,
|
const array& a,
|
||||||
bool use_2d,
|
bool use_2d,
|
||||||
int ndim) {
|
int ndim,
|
||||||
|
int work_per_thread) {
|
||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
switch (bopt) {
|
switch (bopt) {
|
||||||
case BinaryOpType::ScalarScalar:
|
case BinaryOpType::ScalarScalar:
|
||||||
@ -43,14 +42,17 @@ std::string get_kernel_name(
|
|||||||
break;
|
break;
|
||||||
case BinaryOpType::General:
|
case BinaryOpType::General:
|
||||||
kname << "g";
|
kname << "g";
|
||||||
if (ndim <= MAX_BINARY_SPECIALIZED_DIMS) {
|
if (ndim <= 3) {
|
||||||
kname << ndim;
|
kname << ndim;
|
||||||
} else {
|
} else {
|
||||||
kname << "n";
|
kname << "n";
|
||||||
|
if (work_per_thread > 1) {
|
||||||
|
kname << work_per_thread;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
kname << op << type_to_name(a);
|
kname << "_" << op << type_to_name(a);
|
||||||
return kname.str();
|
return kname.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -85,7 +87,11 @@ void binary_op_gpu_inplace(
|
|||||||
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
|
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
|
||||||
|
|
||||||
bool use_2d = out.data_size() > UINT32_MAX;
|
bool use_2d = out.data_size() > UINT32_MAX;
|
||||||
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
|
auto ndim = shape.size();
|
||||||
|
int work_per_thread =
|
||||||
|
(bopt == BinaryOpType::General && shape[ndim - 1] > 4) ? 4 : 1;
|
||||||
|
std::string kernel_name =
|
||||||
|
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
auto kernel = outputs.size() == 2
|
auto kernel = outputs.size() == 2
|
||||||
@ -110,7 +116,11 @@ void binary_op_gpu_inplace(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (bopt == BinaryOpType::General) {
|
if (bopt == BinaryOpType::General) {
|
||||||
auto ndim = shape.size();
|
// Launch up to 3D grid of threads
|
||||||
|
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||||
|
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||||
|
size_t rest = out.size() / (dim0 * dim1);
|
||||||
|
|
||||||
if (ndim > 3) {
|
if (ndim > 3) {
|
||||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
|
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
|
||||||
compute_encoder->setBytes(
|
compute_encoder->setBytes(
|
||||||
@ -118,6 +128,7 @@ void binary_op_gpu_inplace(
|
|||||||
compute_encoder->setBytes(
|
compute_encoder->setBytes(
|
||||||
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
|
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
|
||||||
compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
|
compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
|
||||||
|
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||||
} else {
|
} else {
|
||||||
// The shape is implicit in the grid for <= 3D
|
// The shape is implicit in the grid for <= 3D
|
||||||
compute_encoder->setBytes(
|
compute_encoder->setBytes(
|
||||||
@ -126,10 +137,6 @@ void binary_op_gpu_inplace(
|
|||||||
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
|
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch up to 3D grid of threads
|
|
||||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
|
||||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
|
||||||
size_t rest = out.size() / (dim0 * dim1);
|
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
if (thread_group_size != 1024) {
|
if (thread_group_size != 1024) {
|
||||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||||
|
@ -42,18 +42,19 @@ MTL::ComputePipelineState* get_unary_kernel(
|
|||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
Dtype out_type,
|
Dtype out_type,
|
||||||
const std::string op) {
|
const std::string op) {
|
||||||
std::string lib_name = kernel_name.substr(1);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name);
|
||||||
if (lib == nullptr) {
|
if (lib == nullptr) {
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
auto u_def = get_template_definition(
|
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
|
||||||
"v" + lib_name, "unary_v", get_type_string(out_type), op);
|
kernel_source << get_template_definition(
|
||||||
auto u2_def = get_template_definition(
|
"v_" + lib_name, "unary_v", get_type_string(out_type), op);
|
||||||
"v2" + lib_name, "unary_v2", get_type_string(out_type), op);
|
kernel_source << get_template_definition(
|
||||||
auto g_def = get_template_definition(
|
"v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
|
||||||
"g" + lib_name, "unary_g", get_type_string(out_type), op);
|
kernel_source << get_template_definition(
|
||||||
kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
|
"g_" + lib_name, "unary_g", get_type_string(out_type), op);
|
||||||
<< u_def << u2_def << g_def;
|
kernel_source << get_template_definition(
|
||||||
|
"gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
lib = d.get_library(lib_name, kernel_source.str());
|
||||||
}
|
}
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
@ -81,13 +82,20 @@ void add_binary_kernels(
|
|||||||
for (auto& [name, func] : kernel_types) {
|
for (auto& [name, func] : kernel_types) {
|
||||||
std::string template_def;
|
std::string template_def;
|
||||||
template_def = get_template_definition(
|
template_def = get_template_definition(
|
||||||
name + lib_name,
|
name + "_" + lib_name,
|
||||||
func,
|
func,
|
||||||
get_type_string(in_type),
|
get_type_string(in_type),
|
||||||
get_type_string(out_type),
|
get_type_string(out_type),
|
||||||
op);
|
op);
|
||||||
kernel_source << template_def;
|
kernel_source << template_def;
|
||||||
}
|
}
|
||||||
|
kernel_source << get_template_definition(
|
||||||
|
"gn4_" + lib_name,
|
||||||
|
"binary_g",
|
||||||
|
get_type_string(in_type),
|
||||||
|
get_type_string(out_type),
|
||||||
|
op,
|
||||||
|
4);
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::ComputePipelineState* get_binary_kernel(
|
MTL::ComputePipelineState* get_binary_kernel(
|
||||||
@ -96,7 +104,7 @@ MTL::ComputePipelineState* get_binary_kernel(
|
|||||||
Dtype in_type,
|
Dtype in_type,
|
||||||
Dtype out_type,
|
Dtype out_type,
|
||||||
const std::string op) {
|
const std::string op) {
|
||||||
std::string lib_name = kernel_name.substr(2);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name);
|
||||||
if (lib == nullptr) {
|
if (lib == nullptr) {
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
@ -113,7 +121,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
|
|||||||
Dtype in_type,
|
Dtype in_type,
|
||||||
Dtype out_type,
|
Dtype out_type,
|
||||||
const std::string op) {
|
const std::string op) {
|
||||||
std::string lib_name = kernel_name.substr(2);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name);
|
||||||
if (lib == nullptr) {
|
if (lib == nullptr) {
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
@ -149,6 +157,8 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
|||||||
name + "_" + lib_name, func, get_type_string(type), op);
|
name + "_" + lib_name, func, get_type_string(type), op);
|
||||||
kernel_source << template_def;
|
kernel_source << template_def;
|
||||||
}
|
}
|
||||||
|
kernel_source << get_template_definition(
|
||||||
|
"gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
lib = d.get_library(lib_name, kernel_source.str());
|
||||||
}
|
}
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
|
@ -113,7 +113,7 @@ template <typename T, typename U, typename Op>
|
|||||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U, typename Op>
|
template <typename T, typename U, typename Op, int N = 1>
|
||||||
[[kernel]] void binary_g(
|
[[kernel]] void binary_g(
|
||||||
device const T* a,
|
device const T* a,
|
||||||
device const T* b,
|
device const T* b,
|
||||||
@ -124,8 +124,16 @@ template <typename T, typename U, typename Op>
|
|||||||
constant const int& ndim,
|
constant const int& ndim,
|
||||||
uint3 index [[thread_position_in_grid]],
|
uint3 index [[thread_position_in_grid]],
|
||||||
uint3 grid_dim [[threads_per_grid]]) {
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
|
auto idx = elem_to_loc_2_nd(
|
||||||
|
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
|
||||||
|
auto xshape = shape[ndim - 1];
|
||||||
size_t out_idx =
|
size_t out_idx =
|
||||||
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
|
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
|
||||||
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
auto a_xstride = a_strides[ndim - 1];
|
||||||
|
auto b_xstride = b_strides[ndim - 1];
|
||||||
|
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||||
|
c[out_idx++] = Op()(a[idx.x], b[idx.y]);
|
||||||
|
idx.x += a_xstride;
|
||||||
|
idx.y += b_xstride;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -9,18 +9,19 @@
|
|||||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||||
#include "mlx/backend/metal/kernels/binary.h"
|
#include "mlx/backend/metal/kernels/binary.h"
|
||||||
|
|
||||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||||
instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \
|
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||||
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
|
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||||
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
|
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||||
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
|
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||||
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \
|
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||||
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \
|
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||||
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \
|
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||||
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
|
instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \
|
||||||
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
|
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||||
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
|
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||||
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
|
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||||
|
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
|
||||||
|
|
||||||
#define instantiate_binary_integer(op) \
|
#define instantiate_binary_integer(op) \
|
||||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||||
|
@ -143,7 +143,7 @@ template <typename T, typename U, typename Op>
|
|||||||
d[out_idx] = out[1];
|
d[out_idx] = out[1];
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U, typename Op>
|
template <typename T, typename U, typename Op, int N = 1>
|
||||||
[[kernel]] void binary_g(
|
[[kernel]] void binary_g(
|
||||||
device const T* a,
|
device const T* a,
|
||||||
device const T* b,
|
device const T* b,
|
||||||
@ -155,10 +155,18 @@ template <typename T, typename U, typename Op>
|
|||||||
constant const int& ndim,
|
constant const int& ndim,
|
||||||
uint3 index [[thread_position_in_grid]],
|
uint3 index [[thread_position_in_grid]],
|
||||||
uint3 grid_dim [[threads_per_grid]]) {
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
|
auto idx = elem_to_loc_2_nd(
|
||||||
|
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
|
||||||
|
auto xshape = shape[ndim - 1];
|
||||||
size_t out_idx =
|
size_t out_idx =
|
||||||
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
|
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
|
||||||
auto out = Op()(a[idx.x], b[idx.y]);
|
auto a_xstride = a_strides[ndim - 1];
|
||||||
c[out_idx] = out[0];
|
auto b_xstride = b_strides[ndim - 1];
|
||||||
d[out_idx] = out[1];
|
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||||
|
auto out = Op()(a[idx.x], b[idx.y]);
|
||||||
|
c[out_idx] = out[0];
|
||||||
|
d[out_idx++] = out[1];
|
||||||
|
idx.x += a_xstride;
|
||||||
|
idx.y += b_xstride;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,18 +7,19 @@
|
|||||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||||
#include "mlx/backend/metal/kernels/binary_two.h"
|
#include "mlx/backend/metal/kernels/binary_two.h"
|
||||||
|
|
||||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||||
instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \
|
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||||
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
|
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||||
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
|
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||||
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
|
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||||
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \
|
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||||
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \
|
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||||
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \
|
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||||
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
|
instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \
|
||||||
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
|
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||||
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
|
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||||
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
|
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||||
|
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
|
||||||
|
|
||||||
#define instantiate_binary_float(op) \
|
#define instantiate_binary_float(op) \
|
||||||
instantiate_binary_all(op, float16, half, half) \
|
instantiate_binary_all(op, float16, half, half) \
|
||||||
|
@ -75,7 +75,7 @@ template <typename T, typename Op>
|
|||||||
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
|
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename Op>
|
template <typename T, typename Op, int N = 1>
|
||||||
[[kernel]] void ternary_g(
|
[[kernel]] void ternary_g(
|
||||||
device const bool* a,
|
device const bool* a,
|
||||||
device const T* b,
|
device const T* b,
|
||||||
@ -88,9 +88,23 @@ template <typename T, typename Op>
|
|||||||
constant const int& ndim,
|
constant const int& ndim,
|
||||||
uint3 index [[thread_position_in_grid]],
|
uint3 index [[thread_position_in_grid]],
|
||||||
uint3 grid_dim [[threads_per_grid]]) {
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
auto idx =
|
auto idx = elem_to_loc_3_nd(
|
||||||
elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
|
{N * index.x, index.y, index.z},
|
||||||
|
shape,
|
||||||
|
a_strides,
|
||||||
|
b_strides,
|
||||||
|
c_strides,
|
||||||
|
ndim);
|
||||||
|
auto xshape = shape[ndim - 1];
|
||||||
size_t out_idx =
|
size_t out_idx =
|
||||||
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
|
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
|
||||||
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
|
auto a_xstride = a_strides[ndim - 1];
|
||||||
|
auto b_xstride = b_strides[ndim - 1];
|
||||||
|
auto c_xstride = c_strides[ndim - 1];
|
||||||
|
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||||
|
d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]);
|
||||||
|
idx.x += a_xstride;
|
||||||
|
idx.y += b_xstride;
|
||||||
|
idx.z += c_xstride;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
||||||
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
|
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
|
||||||
instantiate_kernel("g_" #op #tname, ternary_g, type, op) \
|
instantiate_kernel("g_" #op #tname, ternary_g, type, op) \
|
||||||
|
instantiate_kernel("gn4_" #op #tname, ternary_g, type, op, 4) \
|
||||||
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
|
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
|
||||||
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
|
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
|
||||||
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) \
|
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) \
|
||||||
|
@ -18,14 +18,23 @@ template <typename T, typename Op>
|
|||||||
out[offset] = Op()(in[offset]);
|
out[offset] = Op()(in[offset]);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename Op>
|
template <typename T, typename Op, int N = 1>
|
||||||
[[kernel]] void unary_g(
|
[[kernel]] void unary_g(
|
||||||
device const T* in,
|
device const T* in,
|
||||||
device T* out,
|
device T* out,
|
||||||
constant const int* in_shape,
|
constant const int* in_shape,
|
||||||
constant const size_t* in_strides,
|
constant const size_t* in_strides,
|
||||||
device const int& ndim,
|
device const int& ndim,
|
||||||
uint index [[thread_position_in_grid]]) {
|
uint3 index [[thread_position_in_grid]],
|
||||||
auto idx = elem_to_loc(index, in_shape, in_strides, ndim);
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
out[index] = Op()(in[idx]);
|
auto idx =
|
||||||
|
elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
|
||||||
|
auto xshape = in_shape[ndim - 1];
|
||||||
|
auto xstride = in_strides[ndim - 1];
|
||||||
|
size_t out_idx =
|
||||||
|
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
|
||||||
|
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||||
|
out[out_idx++] = Op()(in[idx]);
|
||||||
|
idx += xstride;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,10 +5,11 @@
|
|||||||
#include "mlx/backend/metal/kernels/unary_ops.h"
|
#include "mlx/backend/metal/kernels/unary_ops.h"
|
||||||
#include "mlx/backend/metal/kernels/unary.h"
|
#include "mlx/backend/metal/kernels/unary.h"
|
||||||
|
|
||||||
#define instantiate_unary_all(op, tname, type) \
|
#define instantiate_unary_all(op, tname, type) \
|
||||||
instantiate_kernel("v" #op #tname, unary_v, type, op) \
|
instantiate_kernel("v_" #op #tname, unary_v, type, op) \
|
||||||
instantiate_kernel("v2" #op #tname, unary_v2, type, op) \
|
instantiate_kernel("v2_" #op #tname, unary_v2, type, op) \
|
||||||
instantiate_kernel("g" #op #tname, unary_g, type, op)
|
instantiate_kernel("gn4_" #op #tname, unary_g, type, op, 4) \
|
||||||
|
instantiate_kernel("g_" #op #tname, unary_g, type, op)
|
||||||
|
|
||||||
#define instantiate_unary_float(op) \
|
#define instantiate_unary_float(op) \
|
||||||
instantiate_unary_all(op, float16, half) \
|
instantiate_unary_all(op, float16, half) \
|
||||||
|
@ -8,8 +8,6 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 3;
|
|
||||||
|
|
||||||
void ternary_op_gpu_inplace(
|
void ternary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
@ -43,13 +41,18 @@ void ternary_op_gpu_inplace(
|
|||||||
auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();
|
auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();
|
||||||
|
|
||||||
bool use_2d = out.data_size() > UINT_MAX;
|
bool use_2d = out.data_size() > UINT_MAX;
|
||||||
|
auto ndim = shape.size();
|
||||||
|
int work_per_thread =
|
||||||
|
(topt == TernaryOpType::General && shape[ndim - 1] > 4) ? 4 : 1;
|
||||||
std::string kernel_name;
|
std::string kernel_name;
|
||||||
{
|
{
|
||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
if (topt == TernaryOpType::General) {
|
if (topt == TernaryOpType::General) {
|
||||||
kname << "g";
|
kname << "g";
|
||||||
if (shape.size() <= MAX_TERNARY_SPECIALIZED_DIMS) {
|
if (shape.size() <= 3) {
|
||||||
kname << shape.size();
|
kname << shape.size();
|
||||||
|
} else if (work_per_thread > 1) {
|
||||||
|
kname << "n" << work_per_thread;
|
||||||
}
|
}
|
||||||
} else if (use_2d) {
|
} else if (use_2d) {
|
||||||
kname << "v2";
|
kname << "v2";
|
||||||
@ -75,16 +78,19 @@ void ternary_op_gpu_inplace(
|
|||||||
compute_encoder.set_output_array(out, 3);
|
compute_encoder.set_output_array(out, 3);
|
||||||
|
|
||||||
if (topt == TernaryOpType::General) {
|
if (topt == TernaryOpType::General) {
|
||||||
auto ndim = shape.size();
|
// Launch up to 3D grid of threads
|
||||||
|
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||||
|
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||||
|
size_t rest = out.size() / (dim0 * dim1);
|
||||||
|
|
||||||
if (ndim > 3) {
|
if (ndim > 3) {
|
||||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
|
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
|
||||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
|
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
|
||||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
|
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
|
||||||
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
|
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
|
||||||
|
|
||||||
if (ndim > MAX_TERNARY_SPECIALIZED_DIMS) {
|
compute_encoder->setBytes(&ndim, sizeof(int), 8);
|
||||||
compute_encoder->setBytes(&ndim, sizeof(int), 8);
|
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// The shape is implicit in the grid for <= 3D
|
// The shape is implicit in the grid for <= 3D
|
||||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
|
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
|
||||||
@ -92,10 +98,6 @@ void ternary_op_gpu_inplace(
|
|||||||
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
|
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch up to 3D grid of threads
|
|
||||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
|
||||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
|
||||||
size_t rest = out.size() / (dim0 * dim1);
|
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
if (thread_group_size != 1024) {
|
if (thread_group_size != 1024) {
|
||||||
throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
|
throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/kernels.h"
|
#include "mlx/backend/metal/kernels.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
@ -25,33 +26,60 @@ void unary_op_gpu_inplace(
|
|||||||
|
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
|
auto maybe_collapse = [contig, &in, &out]() {
|
||||||
|
if (!contig) {
|
||||||
|
auto [shape, strides] = collapse_contiguous_dims(
|
||||||
|
{in, out},
|
||||||
|
/* size_cap = */ INT32_MAX);
|
||||||
|
return std::make_pair(shape, strides[0]);
|
||||||
|
} else {
|
||||||
|
return std::make_pair(std::vector<int>{}, std::vector<size_t>{});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
auto [shape, strides] = maybe_collapse();
|
||||||
|
int ndim = shape.size();
|
||||||
|
int work_per_thread = (!contig && shape[ndim - 1] > 4) ? 4 : 1;
|
||||||
size_t nthreads = contig ? in.data_size() : in.size();
|
size_t nthreads = contig ? in.data_size() : in.size();
|
||||||
bool use_2d = nthreads > UINT32_MAX;
|
bool use_2d = nthreads > UINT32_MAX;
|
||||||
std::string kernel_name =
|
std::string kernel_name;
|
||||||
(contig ? (use_2d ? "v2" : "v") : "g") + op + type_to_name(out);
|
if (contig) {
|
||||||
|
kernel_name = (use_2d ? "v2" : "v");
|
||||||
|
} else {
|
||||||
|
kernel_name = (work_per_thread == 4 ? "gn4" : "g");
|
||||||
|
}
|
||||||
|
kernel_name += "_" + op + type_to_name(out);
|
||||||
auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op);
|
auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op);
|
||||||
|
|
||||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides())
|
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides())
|
||||||
: MTL::Size(nthreads, 1, 1);
|
: MTL::Size(nthreads, 1, 1);
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
if (thread_group_size > nthreads) {
|
|
||||||
thread_group_size = nthreads;
|
|
||||||
}
|
|
||||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
|
||||||
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
compute_encoder.set_input_array(
|
compute_encoder.set_input_array(
|
||||||
in.data_shared_ptr() == nullptr ? out : in, 0);
|
in.data_shared_ptr() == nullptr ? out : in, 0);
|
||||||
compute_encoder.set_output_array(out, 1);
|
compute_encoder.set_output_array(out, 1);
|
||||||
if (!contig) {
|
if (!contig) {
|
||||||
compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
|
// Launch up to 3D grid of threads
|
||||||
compute_encoder->setBytes(
|
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||||
in.strides().data(), in.ndim() * sizeof(size_t), 3);
|
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||||
int ndim = in.ndim();
|
size_t rest = out.size() / (dim0 * dim1);
|
||||||
|
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
|
||||||
|
compute_encoder->setBytes(strides.data(), ndim * sizeof(size_t), 3);
|
||||||
compute_encoder->setBytes(&ndim, sizeof(int), 4);
|
compute_encoder->setBytes(&ndim, sizeof(int), 4);
|
||||||
|
if (thread_group_size != 1024) {
|
||||||
|
throw std::runtime_error("[Metal::unary] Must use 1024 sized block");
|
||||||
|
}
|
||||||
|
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||||
|
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||||
|
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||||
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
|
} else {
|
||||||
|
if (thread_group_size > nthreads) {
|
||||||
|
thread_group_size = nthreads;
|
||||||
|
}
|
||||||
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void unary_op_gpu(
|
void unary_op_gpu(
|
||||||
|
Loading…
Reference in New Issue
Block a user