Faster indexing math in a few kernels (#1589)

* wip: faster compiled kernels

* faster general unary with uint specialization

* index type in compiled, unary, binary, ternary, copy

* fix jit

* jit fix

* specialize gather + scatter

* nit in docs
This commit is contained in:
Awni Hannun 2024-11-18 19:52:00 -08:00 committed by GitHub
parent bf481e8e5d
commit 2419edd5b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 630 additions and 484 deletions

View File

@ -184,8 +184,8 @@ Let's time these two different versions:
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
vectorized version takes only ``0.025`` seconds, more than ten times faster.
On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
vectorized version takes only ``0.024`` seconds, more than 200 times faster.
Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.

View File

@ -279,7 +279,7 @@ void Compiled::eval_cpu(
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
bool contiguous = compiled_check_contiguity(inputs, shape);
auto contiguous = compiled_check_contiguity(inputs, shape);
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;

View File

@ -22,37 +22,37 @@ std::string get_kernel_name(
BinaryOpType bopt,
const std::string& op,
const array& a,
bool use_2d,
bool large,
int ndim,
int work_per_thread) {
std::ostringstream kname;
std::string kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
kname = "ss";
break;
case BinaryOpType::ScalarVector:
kname << (use_2d ? "sv2" : "sv");
kname = (large ? "sv2" : "sv");
break;
case BinaryOpType::VectorScalar:
kname << (use_2d ? "vs2" : "vs");
kname = (large ? "vs2" : "vs");
break;
case BinaryOpType::VectorVector:
kname << (use_2d ? "vv2" : "vv");
kname = (large ? "vv2" : "vv");
break;
case BinaryOpType::General:
kname << "g";
kname = "g";
if (ndim <= 3) {
kname << ndim;
kname += std::to_string(ndim);
} else {
kname << "n";
if (work_per_thread > 1) {
kname << work_per_thread;
concatenate(kname, "n", std::to_string(work_per_thread));
}
if (large) {
kname += "large";
}
break;
}
kname << "_" << op << type_to_name(a);
return kname.str();
concatenate(kname, "_", op, type_to_name(a));
return kname;
}
void binary_op_gpu_inplace(
@ -81,11 +81,16 @@ void binary_op_gpu_inplace(
};
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
bool use_2d = out.data_size() > UINT32_MAX;
bool large = out.data_size() > UINT32_MAX;
auto ndim = shape.size();
int work_per_thread = (bopt == BinaryOpType::General) ? 4 : 1;
int work_per_thread;
if (bopt == BinaryOpType::General) {
work_per_thread = large ? 4 : 2;
} else {
work_per_thread = 1;
}
std::string kernel_name =
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
auto& d = metal::device(s.device);
auto kernel = outputs.size() == 2
@ -141,7 +146,7 @@ void binary_op_gpu_inplace(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

View File

@ -1,5 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
#include <fmt/format.h>
#include <iostream> //TODO
#include <sstream>
#include "mlx/backend/common/compiled.h"
@ -11,12 +12,12 @@
#include "mlx/primitives.h"
#include "mlx/utils.h"
using namespace fmt::literals;
namespace mlx::core {
constexpr int WORK_PER_THREAD = 4;
inline void build_kernel(
std::ostream& os,
std::string& os,
const std::string& kernel_name,
const std::vector<array>& inputs,
const std::vector<array>& outputs,
@ -41,8 +42,8 @@ inline void build_kernel(
int cnt = 0;
// Start the kernel
os << "[[host_name(\"" << kernel_name << "\")]]\n"
<< "[[kernel]] void " << kernel_name << "(\n";
os += fmt::format(
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
// Add the input arguments
for (auto& x : inputs) {
@ -54,51 +55,61 @@ inline void build_kernel(
}
// Scalars and contiguous need no strides
if (is_scalar(x) || contiguous) {
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]],\n";
} else {
if (!is_scalar(x) && !contiguous) {
add_indices = true;
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]],\n";
}
os += fmt::format(
" device const {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()),
xname,
cnt++);
}
if (add_indices) {
os << " constant const size_t* in_strides [[buffer(" << cnt++
<< ")]],\n";
os += fmt::format(
" constant const size_t* in_strides [[buffer({0})]],\n", cnt++);
}
// Add the output arguments
for (auto& x : outputs) {
os << " device " << get_type_string(x.dtype()) << "* "
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]],\n";
os += fmt::format(
" device {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()),
namer.get_name(x),
cnt++);
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " constant const size_t* output_strides [[buffer(" << cnt++
<< ")]],\n"
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],\n";
os += fmt::format(
" constant const size_t* output_strides [[buffer({0})]],\n", cnt++);
os += fmt::format(
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
}
if (dynamic_dims) {
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],\n";
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
}
// The thread index in the whole grid
os << " uint3 pos [[thread_position_in_grid]],\n"
<< " uint3 grid [[threads_per_grid]]) {\n";
os += " uint3 pos [[thread_position_in_grid]],\n";
os += " uint3 grid [[threads_per_grid]]) {\n";
if (use_big_index) {
std::string idx_type = use_big_index ? "size_t" : "uint";
if (contiguous && use_big_index) {
// This is only used for contiguous kernels which don't have
// a third grid dimension
os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
os += " size_t index = pos.x + grid.x * size_t(pos.y);\n";
} else if (work_per_thread > 1) {
os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n"
<< " int xshape = output_shape["
<< (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
<< " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
os += fmt::format(
" int xshape = output_shape[{0}];\n",
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
os += fmt::format(
" {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
} else {
os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
os += fmt::format(
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
}
// Read constant / contiguous inputs in tmps
@ -109,16 +120,19 @@ inline void build_kernel(
if (is_constant(x)) {
auto type_str = get_type_string(x.dtype());
os << " auto tmp_" << xname << " = static_cast<"
<< get_type_string(x.dtype()) << ">(";
print_constant(os, x);
os << ");\n";
std::ostringstream ss;
print_constant(ss, x);
os += fmt::format(
" auto tmp_{0} = static_cast<{1}>({2});\n",
xname,
get_type_string(x.dtype()),
ss.str());
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];\n";
os += fmt::format(
" {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
} else if (contiguous) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index];\n";
os += fmt::format(
" {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
} else {
nc_inputs.push_back(x);
}
@ -127,83 +141,98 @@ inline void build_kernel(
// Initialize the indices for non-contiguous inputs
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
if (ndim == 1) {
int offset = i * ndim;
os << " size_t index_" << xname << " = elem_to_loc_1(pos.x, "
<< "in_strides[" << offset << "]);\n";
os += fmt::format(
"elem_to_loc_1<size_t, uint>(pos.x, in_strides[{0}]);\n", offset);
} else if (ndim == 2) {
int offset = i * ndim;
os << " size_t index_" << xname << " = elem_to_loc_2({pos.x, pos.y}, "
<< "in_strides + " << offset << ");\n";
os += fmt::format(
"elem_to_loc_2<size_t, {0}>({{pos.x, pos.y}}, in_strides + {1});\n",
idx_type,
offset);
} else if (ndim == 3) {
int offset = i * ndim;
os << " size_t index_" << xname << " = elem_to_loc_3(pos, "
<< "in_strides + " << offset << ");\n";
os += fmt::format(
"elem_to_loc_3<size_t, {0}>(pos, in_strides + {1});\n",
idx_type,
offset);
} else if (!dynamic_dims) {
int offset = i * ndim;
os << " size_t index_" << xname << " = N_ * pos.x * in_strides["
<< offset + ndim - 1 << "]"
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
int offset = (i + 1) * ndim;
os += fmt::format(
"N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
idx_type,
offset - 1,
offset - 2);
} else {
os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * "
<< i << " + ndim - 1]"
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
os += fmt::format(
"N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
idx_type,
i);
}
}
if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
os << " uint zpos = pos.z;\n";
os += " uint zpos = pos.z;\n";
if (dynamic_dims) {
os << " for (int d = ndim - 3; d >= 0; --d) {\n";
os += " for (int d = ndim - 3; d >= 0; --d) {\n";
} else {
os << " for (int d = " << ndim - 3 << "; d >= 0; --d) {\n";
os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
}
os << " uint l = zpos % output_shape[d];\n";
os += " uint l = zpos % output_shape[d];\n";
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
os << " index_" << xname << " += ";
os += fmt::format(" index_{0} += ", xname);
if (dynamic_dims) {
os << "l * in_strides[" << i << " * ndim + d];\n";
os +=
fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
} else {
os << "l * in_strides[" << i * ndim << " + d];\n";
os +=
fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
}
}
os << " zpos /= output_shape[d];\n }\n";
os += " zpos /= output_shape[d];\n }\n";
}
// Open per-thread loop
if (work_per_thread > 1) {
os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
os +=
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
}
// Read non-contiguous inputs into tmps
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index_" << xname << "];\n";
os += fmt::format(
" {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
}
// Actually write the computation
for (auto& x : tape) {
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
<< " = ";
os += fmt::format(
" {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
if (is_static_cast(x.primitive())) {
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");\n";
os += fmt::format(
"static_cast<{0}>(tmp_{1});\n",
get_type_string(x.dtype()),
namer.get_name(x.inputs()[0]));
} else {
x.primitive().print(os);
os << "()(";
std::ostringstream ss;
x.primitive().print(ss);
os += ss.str();
os += "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
}
os << "tmp_" << namer.get_name(x.inputs().back()) << ");\n";
os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
<< ";\n";
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
}
// Increment indices and close per thread loop
if (work_per_thread > 1) {
@ -211,18 +240,18 @@ inline void build_kernel(
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
if (!dynamic_dims) {
os << " index_" << xname << " += "
<< "in_strides[" << i * ndim + ndim - 1 << "];\n";
os += fmt::format(
" index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
} else {
os << " index_" << xname << " += "
<< "in_strides[" << i << " * ndim + ndim - 1];\n";
os += fmt::format(
" index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
}
}
os << " index++;\n }\n";
os += " index++;\n }\n";
}
// Finish the kernel
os << "}\n";
os += "}\n";
if (cnt > 31) {
std::ostringstream msg;
@ -246,9 +275,9 @@ void Compiled::eval_gpu(
auto& s = stream();
auto& d = metal::device(s.device);
auto lib = d.get_library(kernel_lib_, [&]() {
std::ostringstream kernel;
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
<< metal::ternary_ops();
std::string kernel = metal::utils();
concatenate(
kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
build_kernel(
kernel,
kernel_lib_ + "_contiguous",
@ -261,7 +290,7 @@ void Compiled::eval_gpu(
/* dynamic_dims = */ false);
build_kernel(
kernel,
kernel_lib_ + "_contiguous_big",
kernel_lib_ + "_contiguous_large",
inputs_,
outputs_,
tape_,
@ -282,7 +311,21 @@ void Compiled::eval_gpu(
/* ndim = */ i,
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ i > 3 ? WORK_PER_THREAD : 1);
/* work_per_thread = */ i > 3 ? 2 : 1);
if (i > 1) {
build_kernel(
kernel,
kernel_lib_ + "_strided_" + std::to_string(i) + "_large",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ i,
/* dynamic_dims = */ false,
/* use_big_index = */ true,
/* work_per_thread = */ i > 3 ? 4 : 1);
}
}
build_kernel(
kernel,
@ -295,13 +338,25 @@ void Compiled::eval_gpu(
/* ndim = */ 0,
/* dynamic_dims = */ true,
/* use_big_index = */ false,
/* work_per_thread = */ WORK_PER_THREAD);
return kernel.str();
/* work_per_thread = */ 2);
build_kernel(
kernel,
kernel_lib_ + "_strided_dynamic_large",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ 0,
/* dynamic_dims = */ true,
/* use_big_index = */ true,
/* work_per_thread = */ 4);
return kernel;
});
// Figure out which kernel we are using
auto& output_shape = outputs[0].shape();
bool contiguous = compiled_check_contiguity(inputs, output_shape);
auto contiguous = compiled_check_contiguity(inputs, output_shape);
// Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
@ -349,13 +404,19 @@ void Compiled::eval_gpu(
collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
}
bool use_2d = false;
bool large;
if (contiguous) {
size_t max_size = 0;
for (auto& in : inputs) {
max_size = std::max(max_size, in.data_size());
}
use_2d = (max_size > UINT32_MAX);
large = (max_size > UINT32_MAX);
} else {
size_t max_size = 0;
for (auto& o : outputs) {
max_size = std::max(max_size, o.size());
}
large = (max_size > UINT32_MAX);
}
// Get the kernel from the lib
@ -368,8 +429,9 @@ void Compiled::eval_gpu(
} else {
kernel_name += std::to_string(shape.size());
}
} else if (use_2d) {
kernel_name += "_big";
}
if (large) {
kernel_name += "_large";
}
auto kernel = d.get_kernel(kernel_name, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
@ -422,7 +484,7 @@ void Compiled::eval_gpu(
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
MTL::Size grid_dims = use_2d
MTL::Size grid_dims = large
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
@ -430,7 +492,7 @@ void Compiled::eval_gpu(
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = outputs[0].size() / (dim0 * dim1);
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
int pow2;

View File

@ -74,40 +74,42 @@ void copy_gpu_inplace(
};
auto [shape, strides_in_, strides_out_] = maybe_collapse();
int ndim = shape.size();
bool use_2d = out.data_size() > UINT32_MAX;
bool large;
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
// Allow for negative strides
large = out.data_size() > INT32_MAX;
} else {
large = out.data_size() > UINT32_MAX;
}
auto& d = metal::device(s.device);
int work_per_thread = 1;
std::string kernel_name;
{
std::ostringstream kname;
switch (ctype) {
case CopyType::Scalar:
kname << (use_2d ? "s2" : "s");
kernel_name = (large ? "s2" : "s");
break;
case CopyType::Vector:
kname << (use_2d ? "v2" : "v");
kernel_name = (large ? "v2" : "v");
break;
case CopyType::General:
kname << "g";
kernel_name = "g";
break;
case CopyType::GeneralGeneral:
kname << "gg";
kernel_name = "gg";
break;
}
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << shape.size();
kernel_name += std::to_string(shape.size());
} else {
work_per_thread = 4;
kname << "n4";
work_per_thread = large ? 4 : 2;
concatenate(kernel_name, "n", std::to_string(work_per_thread));
}
if (large) {
kernel_name += "large";
}
}
kname << "_copy";
kname << type_to_name(in) << type_to_name(out);
kernel_name = kname.str();
}
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
auto kernel = get_copy_kernel(d, kernel_name, in, out);
auto& compute_encoder = d.get_command_encoder(s.index);
@ -159,7 +161,7 @@ void copy_gpu_inplace(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
@ -193,9 +195,9 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
bool use_2d = out.data_size() > UINT32_MAX;
bool large = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);
std::string kernel_name = std::string(use_2d ? "s2" : "s") + "_copy" +
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
type_to_name(val) + type_to_name(out);
auto kernel = get_copy_kernel(d, kernel_name, val, out);
auto& compute_encoder = d.get_command_encoder(s.index);
@ -210,7 +212,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

View File

@ -53,27 +53,31 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
std::string lib_name;
std::string kernel_name;
bool large_index = nidx && inputs[1].size() > UINT32_MAX;
bool large_src = src.size() > UINT32_MAX;
bool large_out = out.size() > UINT32_MAX;
bool large = large_index || large_src || large_out;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
{
std::ostringstream kname;
kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
<< "_" << idx_ndim;
lib_name = kname.str();
kernel_name = lib_name;
}
std::string kernel_name = fmt::format(
"gather{0}{1}_{2}_{3}_{4}",
type_to_name(out),
idx_type_name,
nidx,
idx_ndim,
large ? "size_t" : "uint");
std::string lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gather();
std::string kernel_source = metal::utils();
kernel_source += metal::gather();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str =
nidx ? get_type_string(inputs[1].dtype()) : "bool";
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
// Index dimension specializations
kernel_source << fmt::format(
kernel_source += fmt::format(
gather_kernels,
type_to_name(out) + idx_type_name,
out_type_str,
@ -81,8 +85,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
nidx,
idx_args,
idx_arr,
idx_ndim);
return kernel_source.str();
idx_ndim,
large ? "size_t" : "uint");
return kernel_source;
});
auto& compute_encoder = d.get_command_encoder(s.index);
@ -209,8 +214,6 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
nwork = 32;
}
std::string lib_name;
std::string kernel_name;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
std::string op_name;
switch (reduce_type_) {
@ -231,18 +234,24 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
break;
}
auto upd_contig = upd.flags().row_contiguous;
{
std::ostringstream kname;
kname << "scatter" << type_to_name(out) << idx_type_name;
kname << "_" << op_name << "_" << nidx << "_"
<< (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork;
lib_name = kname.str();
kernel_name = kname.str();
}
bool large_out = out.size() > UINT32_MAX;
bool large_idx = nidx && (inputs[1].size() > UINT32_MAX);
bool large_upd = upd.size() > UINT32_MAX;
bool large = large_out || large_idx || large_upd;
std::string kernel_name = fmt::format(
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
type_to_name(out),
idx_type_name,
op_name,
nidx,
upd_contig ? "updc_true" : "updc_false",
nwork,
large ? "size_t" : "uint");
std::string lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
<< metal::scatter();
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::reduce_utils(), metal::scatter());
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str =
@ -270,7 +279,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
}
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
kernel_source << fmt::format(
kernel_source += fmt::format(
scatter_kernels,
type_to_name(out) + idx_type_name + "_" + op_name,
out_type_str,
@ -280,8 +289,9 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_args,
idx_arr,
upd_contig,
nwork);
return kernel_source.str();
nwork,
large ? "size_t" : "uint");
return kernel_source;
});
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
constexpr std::string_view gather_kernels = R"(
[[kernel]] void gather{0}_{3}_{6}(
[[kernel]] void gather{0}_{3}_{6}_{7}(
const device {1}* src [[buffer(0)]],
device {1}* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
@ -19,7 +19,7 @@ constexpr std::string_view gather_kernels = R"(
Indices<{2}, {3}> idxs{{
{{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
return gather_impl<{1}, {2}, {3}, {6}>(
return gather_impl<{1}, {2}, {3}, {6}, {7}>(
src,
out,
src_shape,
@ -34,7 +34,7 @@ constexpr std::string_view gather_kernels = R"(
)";
constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}(
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
@ -54,7 +54,7 @@ constexpr std::string_view scatter_kernels = R"(
uint2 gid [[thread_position_in_grid]]) {{
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}>(
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>(
updates,
out,
upd_shape,

View File

@ -46,25 +46,27 @@ MTL::ComputePipelineState* get_unary_kernel(
auto lib = d.get_library(lib_name, [&]() {
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
kernel_source << get_template_definition(
"v_" + lib_name, "unary_v", in_t, out_t, op);
kernel_source << get_template_definition(
"v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source << get_template_definition(
"gn4_" + lib_name, "unary_g", in_t, out_t, op, 4);
return kernel_source.str();
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::unary_ops(), metal::unary());
kernel_source +=
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
kernel_source +=
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source += get_template_definition(
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
void add_binary_kernels(
void append_binary_kernels(
const std::string lib_name,
Dtype in_type,
Dtype out_type,
const std::string op,
std::ostringstream& kernel_source) {
std::string& kernel_source) {
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
{"ss", "binary_ss"},
{"vs", "binary_vs"},
@ -74,26 +76,24 @@ void add_binary_kernels(
{"sv2", "binary_sv2"},
{"vv2", "binary_vv2"},
{"g1", "binary_g_nd1"},
{"g2", "binary_g_nd2"},
{"g3", "binary_g_nd3"},
{"g2large", "binary_g_nd2"},
{"g3large", "binary_g_nd3"},
}};
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
for (auto& [name, func] : kernel_types) {
std::string template_def;
template_def = get_template_definition(
name + "_" + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
kernel_source << template_def;
kernel_source +=
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
}
kernel_source << get_template_definition(
"gn4_" + lib_name,
"binary_g",
get_type_string(in_type),
get_type_string(out_type),
op,
4);
kernel_source += get_template_definition(
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint");
kernel_source += get_template_definition(
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint");
kernel_source += get_template_definition(
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
}
MTL::ComputePipelineState* get_binary_kernel(
@ -104,10 +104,11 @@ MTL::ComputePipelineState* get_binary_kernel(
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source.str();
std::string kernel_source;
kernel_source = metal::utils();
concatenate(kernel_source, metal::binary_ops(), metal::binary());
append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
@ -120,11 +121,10 @@ MTL::ComputePipelineState* get_binary_two_kernel(
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops()
<< metal::binary_two();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source.str();
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::binary_ops(), metal::binary_two());
append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
@ -136,24 +136,29 @@ MTL::ComputePipelineState* get_ternary_kernel(
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto t_str = get_type_string(type);
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
{"v", "ternary_v"},
{"v2", "ternary_v2"},
{"g1", "ternary_g_nd1"},
{"g2", "ternary_g_nd2"},
{"g3", "ternary_g_nd3"},
{"g2large", "ternary_g_nd2"},
{"g3large", "ternary_g_nd3"},
}};
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
for (auto& [name, func] : kernel_types) {
std::string template_def;
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op);
kernel_source << template_def;
kernel_source +=
get_template_definition(name + "_" + lib_name, func, t_str, op);
}
kernel_source << get_template_definition(
"gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
return kernel_source.str();
kernel_source += get_template_definition(
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint");
kernel_source += get_template_definition(
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint");
kernel_source += get_template_definition(
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "ternary_g", t_str, op, 4);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
@ -165,31 +170,43 @@ MTL::ComputePipelineState* get_copy_kernel(
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
std::string kernel_source = metal::utils();
kernel_source += metal::copy();
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source << metal::utils() << metal::copy()
<< get_template_definition(
"s_" + lib_name, "copy_s", in_type, out_type)
<< get_template_definition(
"v_" + lib_name, "copy_v", in_type, out_type)
<< get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
<< get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
<< get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
<< get_template_definition(
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
<< get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
<< get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
<< get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
<< get_template_definition(
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
return kernel_source.str();
kernel_source +=
get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
kernel_source +=
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
kernel_source += get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type);
kernel_source += get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int");
kernel_source += get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int");
kernel_source += get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type);
kernel_source += get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int");
kernel_source += get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int");
kernel_source += get_template_definition(
"ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int");
kernel_source += get_template_definition(
"g2large_" + lib_name, "copy_g_nd2", in_type, out_type);
kernel_source += get_template_definition(
"g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
kernel_source += get_template_definition(
"gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
kernel_source += get_template_definition(
"gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
kernel_source += get_template_definition(
"gg3large_" + lib_name, "copy_gg_nd3", in_type, out_type);
kernel_source += get_template_definition(
"ggn4large_" + lib_name, "copy_gg", in_type, out_type, 4);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}

View File

@ -77,12 +77,12 @@ template <typename T, typename U, typename Op>
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
@ -91,13 +91,13 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
@ -106,14 +106,18 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, int N = 1>
template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
@ -124,13 +128,12 @@ template <typename T, typename U, typename Op, int N = 1>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
IdxT a_xstride = a_strides[ndim - 1];
IdxT b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
c[out_idx++] = Op()(a[idx.x], b[idx.y]);
idx.x += a_xstride;

View File

@ -17,10 +17,13 @@
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \

View File

@ -99,14 +99,14 @@ template <typename T, typename U, typename Op>
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
auto out = Op()(a[a_idx], b[b_idx]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
@ -116,15 +116,15 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
@ -134,16 +134,20 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op, int N = 1>
template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
@ -155,13 +159,12 @@ template <typename T, typename U, typename Op, int N = 1>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
IdxT a_xstride = a_strides[ndim - 1];
IdxT b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];

View File

@ -15,10 +15,13 @@
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \

View File

@ -42,36 +42,36 @@ template <typename T, typename U>
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
IdxT dst_idx =
index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int N = 1>
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_g(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
@ -80,17 +80,16 @@ template <typename T, typename U, int N = 1>
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(
auto src_idx = elem_to_loc<int64_t, IdxT>(
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
if (N == 1) {
int64_t dst_idx =
index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
IdxT dst_idx =
index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
return;
}
auto xshape = src_shape[ndim - 1];
int64_t dst_idx =
N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
auto src_xstride = src_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[dst_idx + i] = static_cast<U>(src[src_idx]);
@ -105,36 +104,36 @@ template <typename T, typename U>
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
auto dst_idx = elem_to_loc_1(index, dst_stride);
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
auto dst_idx = elem_to_loc_1<int64_t, int>(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
auto dst_idx = elem_to_loc_2(index, dst_strides);
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_2<int64_t, IdxT>(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
auto dst_idx = elem_to_loc_3(index, dst_strides);
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_3<int64_t, IdxT>(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int N = 1>
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_gg(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
@ -143,7 +142,7 @@ template <typename T, typename U, int N = 1>
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto idx = elem_to_loc_2_nd(
auto idx = elem_to_loc_2_nd<int64_t, IdxT>(
{N * index.x, index.y, index.z},
src_shape,
src_strides,
@ -153,8 +152,8 @@ template <typename T, typename U, int N = 1>
dst[idx.y] = static_cast<U>(src[idx.x]);
return;
}
auto src_xstride = src_strides[ndim - 1];
auto dst_xstride = dst_strides[ndim - 1];
IdxT src_xstride = src_strides[ndim - 1];
IdxT dst_xstride = dst_strides[ndim - 1];
auto xshape = src_shape[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[idx.y] = static_cast<U>(src[idx.x]);

View File

@ -10,13 +10,19 @@
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \

View File

@ -4,7 +4,7 @@
#include "mlx/backend/metal/kernels/indexing.h"
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
METAL_FUNC void gather_impl(
const device T* src [[buffer(0)]],
device T* out [[buffer(1)]],
@ -16,18 +16,18 @@ METAL_FUNC void gather_impl(
const thread Indices<IdxT, NIDX>& indices,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
size_t src_idx = 0;
LocT src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
size_t idx_loc;
LocT idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = index.x * indices.strides[indices.ndim * i];
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
} else {
idx_loc = index.x * indices.strides[indices.ndim * i];
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
idx_loc += indices.row_contiguous[i]
? index.y
: elem_to_loc(
: elem_to_loc<size_t, LocT>(
index.y,
&indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1],
@ -35,17 +35,17 @@ METAL_FUNC void gather_impl(
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
src_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(src_strides[ax]);
}
auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim);
auto src_offset =
elem_to_loc<size_t, LocT>(index.z, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.z;
LocT out_idx = index.z;
if (IDX_NDIM == 1) {
out_idx += static_cast<size_t>(grid_dim.z) * index.x;
out_idx += static_cast<LocT>(grid_dim.z) * index.x;
} else if (IDX_NDIM >= 2) {
out_idx +=
grid_dim.z * (index.x * static_cast<size_t>(grid_dim.y) + index.y);
out_idx += grid_dim.z * (index.x * static_cast<LocT>(grid_dim.y) + index.y);
}
out[out_idx] = src[src_offset + src_idx];
}

View File

@ -14,7 +14,7 @@ struct Indices {
};
template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
if (is_unsigned_v<IdxT>) {
return idx;
} else {

View File

@ -10,7 +10,8 @@ template <
typename Op,
int NIDX,
bool UPD_ROW_CONTIG,
int NWORK>
int NWORK,
typename LocT>
METAL_FUNC void scatter_impl(
const device T* updates,
device mlx_atomic<T>* out,
@ -28,29 +29,31 @@ METAL_FUNC void scatter_impl(
Op op;
auto ind_idx = gid.y * NWORK;
size_t out_offset = 0;
LocT out_offset = 0;
if (upd_size > 1) {
out_offset =
elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
out_offset = elem_to_loc<size_t, LocT>(
gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
}
for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
size_t out_idx = out_offset;
LocT out_idx = out_offset;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = indices.row_contiguous[i]
? ind_idx
: elem_to_loc(
: elem_to_loc<size_t, LocT>(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
out_idx +=
static_cast<LocT>(idx_val) * static_cast<LocT>(out_strides[ax]);
}
auto upd_idx = ind_idx * upd_size + gid.x;
auto upd_idx = ind_idx * static_cast<LocT>(upd_size) + gid.x;
if constexpr (!UPD_ROW_CONTIG) {
upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
upd_idx =
elem_to_loc<size_t, LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);
}
op.atomic_update(out, updates[upd_idx], out_idx);
}

View File

@ -32,13 +32,13 @@ template <typename T, typename Op>
constant const size_t& b_strides,
constant const size_t& c_strides,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_strides);
auto b_idx = elem_to_loc_1(index, b_strides);
auto c_idx = elem_to_loc_1(index, c_strides);
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_strides);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_strides);
auto c_idx = elem_to_loc_1<size_t, uint>(index, c_strides);
d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op>
template <typename T, typename Op, typename IdxT = size_t>
[[kernel]] void ternary_g_nd2(
device const bool* a,
device const T* b,
@ -49,14 +49,14 @@ template <typename T, typename Op>
constant const size_t c_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
auto c_idx = elem_to_loc_2(index, c_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_2<size_t, IdxT>(index, c_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op>
template <typename T, typename Op, typename IdxT = size_t>
[[kernel]] void ternary_g_nd3(
device const bool* a,
device const T* b,
@ -67,15 +67,14 @@ template <typename T, typename Op>
constant const size_t c_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
auto c_idx = elem_to_loc_3(index, c_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_3<size_t, IdxT>(index, c_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op, int N = 1>
template <typename T, typename Op, int N = 1, typename IdxT = size_t>
[[kernel]] void ternary_g(
device const bool* a,
device const T* b,
@ -88,7 +87,7 @@ template <typename T, typename Op, int N = 1>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_3_nd(
auto idx = elem_to_loc_3_nd<IdxT>(
{N * index.x, index.y, index.z},
shape,
a_strides,
@ -96,11 +95,10 @@ template <typename T, typename Op, int N = 1>
c_strides,
ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
auto c_xstride = c_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
IdxT a_xstride = a_strides[ndim - 1];
IdxT b_xstride = b_strides[ndim - 1];
IdxT c_xstride = c_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]);
idx.x += a_xstride;

View File

@ -11,10 +11,13 @@
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn4_" #op #tname, ternary_g, type, op, 4) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op)
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
#define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \

View File

@ -18,7 +18,12 @@ template <typename T, typename U, typename Op>
out[offset] = Op()(in[offset]);
}
template <typename T, typename U, typename Op, int N = 1>
template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
[[kernel]] void unary_g(
device const T* in,
device U* out,
@ -27,12 +32,11 @@ template <typename T, typename U, typename Op, int N = 1>
device const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx =
elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
auto idx = elem_to_loc<size_t, IdxT>(
{N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
auto xshape = in_shape[ndim - 1];
auto xstride = in_strides[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
IdxT xstride = in_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
out[out_idx++] = Op()(in[idx]);
idx += xstride;

View File

@ -8,8 +8,10 @@
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel("gn4_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
instantiate_kernel( \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, uint) \
instantiate_kernel( \
"gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
#define instantiate_unary_all_same(op, tname, type) \
instantiate_unary_all(op, tname, tname, type, type)

View File

@ -89,44 +89,45 @@ struct Limits<complex64_t> {
///////////////////////////////////////////////////////////////////////////////
// Single Array with generic dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc(
uint elem,
constant const int* shape,
constant const stride_t* strides,
constant const StrideT* strides,
int ndim) {
stride_t loc = 0;
IdxT loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
stride_t elem,
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc(
StrideT elem,
constant const int* shape,
constant const stride_t* strides,
constant const StrideT* strides,
int ndim) {
stride_t loc = 0;
IdxT loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
// Non templated version to handle arbitrary dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc(
uint3 elem,
constant const int* shape,
constant const stride_t* strides,
constant const StrideT* strides,
int ndim) {
stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
IdxT loc =
elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
for (int d = ndim - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
loc += (elem.z % shape[d]) * IdxT(strides[d]);
elem.z /= shape[d];
}
return loc;
@ -135,61 +136,65 @@ METAL_FUNC stride_t elem_to_loc(
///////////////////////////////////////////////////////////////////////////////
// Single Array with fixed N dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
return elem * stride;
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const StrideT& stride) {
return elem * IdxT(stride);
}
template <typename stride_t>
METAL_FUNC stride_t
elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
return elem.x * strides[1] + elem.y * strides[0];
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const StrideT strides[2]) {
return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
}
template <typename stride_t>
METAL_FUNC stride_t
elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) {
return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
elem.z * IdxT(strides[0]);
}
///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with generic dims
template <typename stride_t>
METAL_FUNC ulong2 elem_to_loc_2_nd(
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,
constant const stride_t* a_strides,
constant const stride_t* b_strides,
constant const StrideT* a_strides,
constant const StrideT* b_strides,
int ndim) {
ulong2 loc = {
ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
vec<IdxT, 2> loc = {
IdxT(
elem.x * IdxT(a_strides[ndim - 1]) +
IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
IdxT(
elem.x * IdxT(b_strides[ndim - 1]) +
elem.y * IdxT(b_strides[ndim - 2]))};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
loc.x += l * IdxT(a_strides[d]);
loc.y += l * IdxT(b_strides[d]);
elem.z /= shape[d];
}
return loc;
}
METAL_FUNC ulong3 elem_to_loc_3_nd(
template <typename IdxT = size_t>
METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
int ndim) {
ulong3 loc = {
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2],
elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]};
vec<IdxT, 3> loc = {
elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]),
elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
loc.z += l * c_strides[d];
loc.x += l * IdxT(a_strides[d]);
loc.y += l * IdxT(b_strides[d]);
loc.z += l * IdxT(c_strides[d]);
elem.z /= shape[d];
}
return loc;

View File

@ -36,27 +36,31 @@ void ternary_op_gpu_inplace(
};
auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();
bool use_2d = out.data_size() > UINT_MAX;
bool large = out.data_size() > UINT_MAX;
auto ndim = shape.size();
int work_per_thread = (topt == TernaryOpType::General) ? 4 : 1;
std::string kernel_name;
{
std::ostringstream kname;
int work_per_thread;
if (topt == TernaryOpType::General) {
kname << "g";
if (shape.size() <= 3) {
kname << shape.size();
} else if (work_per_thread > 1) {
kname << "n" << work_per_thread;
}
} else if (use_2d) {
kname << "v2";
work_per_thread = large ? 4 : 2;
} else {
kname << "v";
work_per_thread = 1;
}
kname << "_" << op << type_to_name(b);
kernel_name = kname.str();
std::string kernel_name;
if (topt == TernaryOpType::General) {
kernel_name = "g";
if (shape.size() <= 3) {
kernel_name += std::to_string(shape.size());
} else if (work_per_thread > 1) {
concatenate(kernel_name, "n", std::to_string(work_per_thread));
}
if (large) {
kernel_name += "large";
}
} else if (large) {
kernel_name = "v2";
} else {
kernel_name = "v";
}
concatenate(kernel_name, "_", op, type_to_name(b));
auto& d = metal::device(s.device);
@ -107,7 +111,7 @@ void ternary_op_gpu_inplace(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

View File

@ -35,16 +35,19 @@ void unary_op_gpu_inplace(
};
auto [shape, strides] = maybe_collapse();
int ndim = shape.size();
int work_per_thread = !contig ? 4 : 1;
size_t nthreads = contig ? in.data_size() : in.size();
bool use_2d = nthreads > UINT32_MAX;
bool large = nthreads > UINT32_MAX;
int work_per_thread = !contig && large ? 4 : 1;
std::string kernel_name;
if (contig) {
kernel_name = (use_2d ? "v2" : "v");
kernel_name = (large ? "v2" : "v");
} else {
kernel_name = (work_per_thread == 4 ? "gn4" : "g");
kernel_name = "gn" + std::to_string(work_per_thread);
if (large) {
kernel_name += "_large";
}
kernel_name += "_" + op + type_to_name(in) + type_to_name(out);
}
concatenate(kernel_name, "_", op, type_to_name(in), type_to_name(out));
auto kernel = get_unary_kernel(d, kernel_name, in.dtype(), out.dtype(), op);
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@ -73,7 +76,7 @@ void unary_op_gpu_inplace(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

View File

@ -61,4 +61,15 @@ inline void debug_set_primitive_buffer_label(
std::string get_primitive_string(Primitive* primitive);
template <typename T>
void concatenate(std::string& acc, T first) {
acc += first;
}
template <typename T, typename... Args>
void concatenate(std::string& acc, T first, Args... args) {
acc += first;
concatenate(acc, args...);
}
} // namespace mlx::core