Mirror of https://github.com/ml-explore/mlx.git
Faster indexing math in a few kernels (#1589)
* wip: faster compiled kernels
* faster general unary with uint specialization
* index type in compiled, unary, binary, ternary, copy
* fix jit
* jit fix
* specialize gather + scatter
* nit in docs
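The common thread in these changes is that kernels specialize their index arithmetic on the output size: the ordinary variants use 32-bit index math, and a separate "large" variant falls back to 64-bit offsets only when the flat size can exceed UINT_MAX. Below is a minimal plain-C++ sketch of that idea; the helper names and signatures are illustrative assumptions, not the actual MLX Metal sources.

// Illustrative sketch only (hypothetical names): the same strided-offset
// computation instantiated for a 32-bit and a 64-bit index type. Host code
// would pick the 64-bit ("large") instantiation only when the output's flat
// size can overflow uint32_t, as in the diff below.
#include <cstdint>
#include <limits>
#include <vector>

template <typename IdxT>
IdxT elem_to_loc_sketch(
    IdxT elem,
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides) {
  IdxT loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += static_cast<IdxT>(elem % static_cast<IdxT>(shape[i])) *
        static_cast<IdxT>(strides[i]);
    elem /= static_cast<IdxT>(shape[i]);
  }
  return loc;
}

// Mirrors the out.data_size() > UINT_MAX check in the diff below.
inline bool needs_large_kernel(uint64_t out_size) {
  return out_size > std::numeric_limits<uint32_t>::max();
}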
@@ -36,27 +36,31 @@ void ternary_op_gpu_inplace(
   };
 
   auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();
 
-  bool use_2d = out.data_size() > UINT_MAX;
+  bool large = out.data_size() > UINT_MAX;
   auto ndim = shape.size();
-  int work_per_thread = (topt == TernaryOpType::General) ? 4 : 1;
-  std::string kernel_name;
-  {
-    std::ostringstream kname;
-    if (topt == TernaryOpType::General) {
-      kname << "g";
-      if (shape.size() <= 3) {
-        kname << shape.size();
-      } else if (work_per_thread > 1) {
-        kname << "n" << work_per_thread;
-      }
-    } else if (use_2d) {
-      kname << "v2";
-    } else {
-      kname << "v";
-    }
-    kname << "_" << op << type_to_name(b);
-    kernel_name = kname.str();
-  }
+  int work_per_thread;
+  if (topt == TernaryOpType::General) {
+    work_per_thread = large ? 4 : 2;
+  } else {
+    work_per_thread = 1;
+  }
+  std::string kernel_name;
+  if (topt == TernaryOpType::General) {
+    kernel_name = "g";
+    if (shape.size() <= 3) {
+      kernel_name += std::to_string(shape.size());
+    } else if (work_per_thread > 1) {
+      concatenate(kernel_name, "n", std::to_string(work_per_thread));
+    }
+    if (large) {
+      kernel_name += "large";
+    }
+  } else if (large) {
+    kernel_name = "v2";
+  } else {
+    kernel_name = "v";
+  }
+  concatenate(kernel_name, "_", op, type_to_name(b));
 
   auto& d = metal::device(s.device);
@@ -107,8 +111,8 @@ void ternary_op_gpu_inplace(
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
-                                 : MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
+                                : MTL::Size(nthreads, 1, 1);
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }
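A likely reading of the second hunk: per-axis thread positions in Metal shading language are 32-bit uints, so a single 1-D grid cannot cheaply address more than UINT_MAX elements; the "large"/"v2" path instead dispatches a 2-D grid, from which a kernel can rebuild a 64-bit linear index. The helper below is a rough, hypothetical stand-in for that factoring, not MLX's actual get_2d_grid_dims.

// Hypothetical stand-in: factor a flat element count into an (x, y) grid so
// each axis fits in 32 bits; a kernel can then rebuild the linear index as
// uint64_t(index.y) * grid_x + index.x. Any overshoot (x * y > total) would
// need a bounds check in the kernel. Assumes total <= UINT32_MAX * UINT32_MAX,
// which holds for any realistic array.
#include <cstdint>
#include <utility>

inline std::pair<uint32_t, uint32_t> split_into_2d_grid(uint64_t total) {
  if (total == 0) {
    return {1, 1}; // avoid a zero-sized grid axis
  }
  const uint64_t max_axis = UINT32_MAX;
  uint64_t x = total <= max_axis ? total : max_axis;
  uint64_t y = (total + x - 1) / x; // ceil division covers the remainder
  return {static_cast<uint32_t>(x), static_cast<uint32_t>(y)};
}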