Faster indexing math in a few kernels (#1589)

* wip: faster compiled kernels

* faster general unary with uint specialization

* index type in compiled, unary, binary, ternary, copy

* fix jit

* jit fix

* specialize gather + scatter

* nit in docs
This commit is contained in:
Awni Hannun
2024-11-18 19:52:00 -08:00
committed by GitHub
parent bf481e8e5d
commit 2419edd5b2
25 changed files with 630 additions and 484 deletions

View File

@@ -36,27 +36,31 @@ void ternary_op_gpu_inplace(
};
auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();
bool use_2d = out.data_size() > UINT_MAX;
bool large = out.data_size() > UINT_MAX;
auto ndim = shape.size();
int work_per_thread = (topt == TernaryOpType::General) ? 4 : 1;
std::string kernel_name;
{
std::ostringstream kname;
if (topt == TernaryOpType::General) {
kname << "g";
if (shape.size() <= 3) {
kname << shape.size();
} else if (work_per_thread > 1) {
kname << "n" << work_per_thread;
}
} else if (use_2d) {
kname << "v2";
} else {
kname << "v";
}
kname << "_" << op << type_to_name(b);
kernel_name = kname.str();
int work_per_thread;
if (topt == TernaryOpType::General) {
work_per_thread = large ? 4 : 2;
} else {
work_per_thread = 1;
}
std::string kernel_name;
if (topt == TernaryOpType::General) {
kernel_name = "g";
if (shape.size() <= 3) {
kernel_name += std::to_string(shape.size());
} else if (work_per_thread > 1) {
concatenate(kernel_name, "n", std::to_string(work_per_thread));
}
if (large) {
kernel_name += "large";
}
} else if (large) {
kernel_name = "v2";
} else {
kernel_name = "v";
}
concatenate(kernel_name, "_", op, type_to_name(b));
auto& d = metal::device(s.device);
@@ -107,8 +111,8 @@ void ternary_op_gpu_inplace(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}