Mirror of https://github.com/ml-explore/mlx.git
Faster indexing math in a few kernels (#1589)

* wip: faster compiled kernels
* faster general unary with uint specialization
* index type in compiled, unary, binary, ternary, copy
* fix jit
* jit fix
* specialize gather + scatter
* nit in docs
Parent: bf481e8e5d
Commit: 2419edd5b2
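The theme of this change: kernels whose arrays fit in 32 bits of index space can do their address arithmetic in `uint` instead of `size_t`, and only genuinely large arrays pay for 64-bit math. A "large" flag selects between two compiled variants of each kernel, distinguished by a name suffix. A minimal sketch of the dispatch rule (an illustrative helper, not code from this commit):

    #include <cstdint>
    #include <string>

    // Pick the kernel variant: small arrays use 32-bit (uint) index math,
    // arrays with more than UINT32_MAX elements need 64-bit (size_t) math.
    std::string pick_kernel_variant(const std::string& base, size_t nelem) {
      bool large = nelem > UINT32_MAX;
      return large ? base + "large" : base;  // e.g. "gn4" -> "gn4large"
    }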
@@ -184,8 +184,8 @@ Let's time these two different versions:

 print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
 print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))

-On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
-vectorized version takes only ``0.025`` seconds, more than ten times faster.
+On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
+vectorized version takes only ``0.024`` seconds, more than 200 times faster.

 Of course, this operation is quite contrived. A better approach is to simply do
 ``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.
@@ -279,7 +279,7 @@ void Compiled::eval_cpu(

   // Figure out which kernel we are using
   auto& shape = outputs[0].shape();
-  bool contiguous = compiled_check_contiguity(inputs, shape);
+  auto contiguous = compiled_check_contiguity(inputs, shape);

   // Handle all broadcasting and collect function input arguments
   std::vector<void*> args;
@@ -22,37 +22,37 @@ std::string get_kernel_name(
     BinaryOpType bopt,
     const std::string& op,
     const array& a,
-    bool use_2d,
+    bool large,
     int ndim,
     int work_per_thread) {
-  std::ostringstream kname;
+  std::string kname;
   switch (bopt) {
     case BinaryOpType::ScalarScalar:
-      kname << "ss";
+      kname = "ss";
       break;
     case BinaryOpType::ScalarVector:
-      kname << (use_2d ? "sv2" : "sv");
+      kname = (large ? "sv2" : "sv");
       break;
     case BinaryOpType::VectorScalar:
-      kname << (use_2d ? "vs2" : "vs");
+      kname = (large ? "vs2" : "vs");
       break;
    case BinaryOpType::VectorVector:
-      kname << (use_2d ? "vv2" : "vv");
+      kname = (large ? "vv2" : "vv");
      break;
    case BinaryOpType::General:
-      kname << "g";
+      kname = "g";
      if (ndim <= 3) {
-        kname << ndim;
+        kname += std::to_string(ndim);
      } else {
-        kname << "n";
-        if (work_per_thread > 1) {
-          kname << work_per_thread;
-        }
+        concatenate(kname, "n", std::to_string(work_per_thread));
      }
+      if (large) {
+        kname += "large";
+      }
      break;
  }
-  kname << "_" << op << type_to_name(a);
-  return kname.str();
+  concatenate(kname, "_", op, type_to_name(a));
+  return kname;
 }

 void binary_op_gpu_inplace(
@@ -81,11 +81,16 @@ void binary_op_gpu_inplace(
   };
   auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();

-  bool use_2d = out.data_size() > UINT32_MAX;
+  bool large = out.data_size() > UINT32_MAX;
   auto ndim = shape.size();
-  int work_per_thread = (bopt == BinaryOpType::General) ? 4 : 1;
+  int work_per_thread;
+  if (bopt == BinaryOpType::General) {
+    work_per_thread = large ? 4 : 2;
+  } else {
+    work_per_thread = 1;
+  }
   std::string kernel_name =
-      get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
+      get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
   auto& d = metal::device(s.device);

   auto kernel = outputs.size() == 2
@@ -141,7 +146,7 @@ void binary_op_gpu_inplace(
     thread_group_size = nthreads;
   }
   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-  MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
+  MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
                                : MTL::Size(nthreads, 1, 1);
   compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
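Note the coupling above: with 32-bit indexing the general kernel processes 2 elements per thread, while the 64-bit "large" variant processes 4, amortizing the more expensive stride arithmetic over more elements. A rough C++ sketch of the per-thread loop shape implied by the kernel sources later in this diff (the offsets and `Op` are stand-ins):

    // Each thread computes its starting offsets once, then walks the
    // innermost dimension up to N times with one stride addition per input.
    template <typename T, typename Op, int N>
    void binary_row(const T* a, const T* b, T* c,
                    size_t a_off, size_t b_off, size_t out_off,
                    size_t a_xstride, size_t b_xstride,
                    int xshape, int x0) {
      for (int i = 0; i < N && x0 + i < xshape; ++i) {
        c[out_off++] = Op()(a[a_off], b[b_off]);
        a_off += a_xstride;
        b_off += b_xstride;
      }
    }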
@@ -1,5 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.
-
+#include <fmt/format.h>
+#include <iostream> //TODO
 #include <sstream>

 #include "mlx/backend/common/compiled.h"
@@ -11,12 +12,12 @@
 #include "mlx/primitives.h"
 #include "mlx/utils.h"

+using namespace fmt::literals;
+
 namespace mlx::core {

-constexpr int WORK_PER_THREAD = 4;
-
 inline void build_kernel(
-    std::ostream& os,
+    std::string& os,
     const std::string& kernel_name,
     const std::vector<array>& inputs,
     const std::vector<array>& outputs,
@@ -41,8 +42,8 @@ inline void build_kernel(
   int cnt = 0;

   // Start the kernel
-  os << "[[host_name(\"" << kernel_name << "\")]]\n"
-     << "[[kernel]] void " << kernel_name << "(\n";
+  os += fmt::format(
+      "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);

   // Add the input arguments
   for (auto& x : inputs) {
@@ -54,51 +55,61 @@ inline void build_kernel(
     }

     // Scalars and contiguous need no strides
-    if (is_scalar(x) || contiguous) {
-      os << " device const " << get_type_string(x.dtype()) << "* " << xname
-         << " [[buffer(" << cnt++ << ")]],\n";
-    } else {
+    if (!is_scalar(x) && !contiguous) {
       add_indices = true;
-      os << " device const " << get_type_string(x.dtype()) << "* " << xname
-         << " [[buffer(" << cnt++ << ")]],\n";
     }
+    os += fmt::format(
+        " device const {0}* {1} [[buffer({2})]],\n",
+        get_type_string(x.dtype()),
+        xname,
+        cnt++);
   }

   if (add_indices) {
-    os << " constant const size_t* in_strides [[buffer(" << cnt++
-       << ")]],\n";
+    os += fmt::format(
+        " constant const size_t* in_strides [[buffer({0})]],\n", cnt++);
   }

   // Add the output arguments
   for (auto& x : outputs) {
-    os << " device " << get_type_string(x.dtype()) << "* "
-       << namer.get_name(x) << " [[buffer(" << cnt++ << ")]],\n";
+    os += fmt::format(
+        " device {0}* {1} [[buffer({2})]],\n",
+        get_type_string(x.dtype()),
+        namer.get_name(x),
+        cnt++);
   }
   // Add output strides and shape to extract the indices.
   if (!contiguous) {
-    os << " constant const size_t* output_strides [[buffer(" << cnt++
-       << ")]],\n"
-       << " constant const int* output_shape [[buffer(" << cnt++ << ")]],\n";
+    os += fmt::format(
+        " constant const size_t* output_strides [[buffer({0})]],\n", cnt++);
+    os += fmt::format(
+        " constant const int* output_shape [[buffer({0})]],\n", cnt++);
   }
   if (dynamic_dims) {
-    os << " constant const int& ndim [[buffer(" << cnt++ << ")]],\n";
+    os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
   }

   // The thread index in the whole grid
-  os << " uint3 pos [[thread_position_in_grid]],\n"
-     << " uint3 grid [[threads_per_grid]]) {\n";
+  os += " uint3 pos [[thread_position_in_grid]],\n";
+  os += " uint3 grid [[threads_per_grid]]) {\n";

-  if (use_big_index) {
+  std::string idx_type = use_big_index ? "size_t" : "uint";
+  if (contiguous && use_big_index) {
     // This is only used for contiguous kernels which don't have
     // a third grid dimension
-    os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
+    os += " size_t index = pos.x + grid.x * size_t(pos.y);\n";
   } else if (work_per_thread > 1) {
-    os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n"
-       << " int xshape = output_shape["
-       << (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
-       << " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
+    os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
+    os += fmt::format(
+        " int xshape = output_shape[{0}];\n",
+        dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
+    os += fmt::format(
+        " {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
+        idx_type);
   } else {
-    os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
+    os += fmt::format(
+        " {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
+        idx_type);
   }

   // Read constant / contiguous inputs in tmps
@@ -109,16 +120,19 @@ inline void build_kernel(

     if (is_constant(x)) {
       auto type_str = get_type_string(x.dtype());
-      os << " auto tmp_" << xname << " = static_cast<"
-         << get_type_string(x.dtype()) << ">(";
-      print_constant(os, x);
-      os << ");\n";
+      std::ostringstream ss;
+      print_constant(ss, x);
+      os += fmt::format(
+          " auto tmp_{0} = static_cast<{1}>({2});\n",
+          xname,
+          get_type_string(x.dtype()),
+          ss.str());
     } else if (is_scalar(x)) {
-      os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
-         << xname << "[0];\n";
+      os += fmt::format(
+          " {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
     } else if (contiguous) {
-      os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
-         << xname << "[index];\n";
+      os += fmt::format(
+          " {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
     } else {
       nc_inputs.push_back(x);
     }
@@ -127,83 +141,98 @@ inline void build_kernel(
   // Initialize the indices for non-contiguous inputs
   for (int i = 0; i < nc_inputs.size(); ++i) {
     auto& xname = namer.get_name(nc_inputs[i]);
+    os += fmt::format(" {0} index_{1} = ", idx_type, xname);
     if (ndim == 1) {
       int offset = i * ndim;
-      os << " size_t index_" << xname << " = elem_to_loc_1(pos.x, "
-         << "in_strides[" << offset << "]);\n";
+      os += fmt::format(
+          "elem_to_loc_1<size_t, uint>(pos.x, in_strides[{0}]);\n", offset);
     } else if (ndim == 2) {
       int offset = i * ndim;
-      os << " size_t index_" << xname << " = elem_to_loc_2({pos.x, pos.y}, "
-         << "in_strides + " << offset << ");\n";
+      os += fmt::format(
+          "elem_to_loc_2<size_t, {0}>({{pos.x, pos.y}}, in_strides + {1});\n",
+          idx_type,
+          offset);
     } else if (ndim == 3) {
       int offset = i * ndim;
-      os << " size_t index_" << xname << " = elem_to_loc_3(pos, "
-         << "in_strides + " << offset << ");\n";
+      os += fmt::format(
+          "elem_to_loc_3<size_t, {0}>(pos, in_strides + {1});\n",
+          idx_type,
+          offset);
     } else if (!dynamic_dims) {
-      int offset = i * ndim;
-      os << " size_t index_" << xname << " = N_ * pos.x * in_strides["
-         << offset + ndim - 1 << "]"
-         << " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
+      int offset = (i + 1) * ndim;
+      os += fmt::format(
+          "N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
+          idx_type,
+          offset - 1,
+          offset - 2);
     } else {
-      os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * "
-         << i << " + ndim - 1]"
-         << " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
+      os += fmt::format(
+          "N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
+          idx_type,
+          i);
     }
   }

   if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
-    os << " uint zpos = pos.z;\n";
+    os += " uint zpos = pos.z;\n";
     if (dynamic_dims) {
-      os << " for (int d = ndim - 3; d >= 0; --d) {\n";
+      os += " for (int d = ndim - 3; d >= 0; --d) {\n";
     } else {
-      os << " for (int d = " << ndim - 3 << "; d >= 0; --d) {\n";
+      os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
     }
-    os << " uint l = zpos % output_shape[d];\n";
+    os += " uint l = zpos % output_shape[d];\n";
     for (int i = 0; i < nc_inputs.size(); ++i) {
       auto& xname = namer.get_name(nc_inputs[i]);
-      os << " index_" << xname << " += ";
+      os += fmt::format(" index_{0} += ", xname);
       if (dynamic_dims) {
-        os << "l * in_strides[" << i << " * ndim + d];\n";
+        os +=
+            fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
       } else {
-        os << "l * in_strides[" << i * ndim << " + d];\n";
+        os +=
+            fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
       }
     }
-    os << " zpos /= output_shape[d];\n }\n";
+    os += " zpos /= output_shape[d];\n }\n";
   }

   // Open per-thread loop
   if (work_per_thread > 1) {
-    os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
+    os +=
+        " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
   }

   // Read non-contiguous inputs into tmps
   for (int i = 0; i < nc_inputs.size(); ++i) {
     auto& x = nc_inputs[i];
     auto& xname = namer.get_name(x);
-    os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
-       << xname << "[index_" << xname << "];\n";
+    os += fmt::format(
+        " {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
   }

   // Actually write the computation
   for (auto& x : tape) {
-    os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
-       << " = ";
+    os += fmt::format(
+        " {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
     if (is_static_cast(x.primitive())) {
-      os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
-         << namer.get_name(x.inputs()[0]) << ");\n";
+      os += fmt::format(
+          "static_cast<{0}>(tmp_{1});\n",
+          get_type_string(x.dtype()),
+          namer.get_name(x.inputs()[0]));
     } else {
-      x.primitive().print(os);
-      os << "()(";
+      std::ostringstream ss;
+      x.primitive().print(ss);
+      os += ss.str();
+      os += "()(";
       for (int i = 0; i < x.inputs().size() - 1; i++) {
-        os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
+        os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
       }
-      os << "tmp_" << namer.get_name(x.inputs().back()) << ");\n";
+      os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
     }
   }

   // Write the outputs from tmps
   for (auto& x : outputs) {
-    os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
-       << ";\n";
+    os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
   }
   // Increment indices and close per thread loop
   if (work_per_thread > 1) {
@@ -211,18 +240,18 @@ inline void build_kernel(
       auto& x = nc_inputs[i];
       auto& xname = namer.get_name(x);
       if (!dynamic_dims) {
-        os << " index_" << xname << " += "
-           << "in_strides[" << i * ndim + ndim - 1 << "];\n";
+        os += fmt::format(
+            " index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
       } else {
-        os << " index_" << xname << " += "
-           << "in_strides[" << i << " * ndim + ndim - 1];\n";
+        os += fmt::format(
+            " index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
       }
     }
-    os << " index++;\n }\n";
+    os += " index++;\n }\n";
   }

   // Finish the kernel
-  os << "}\n";
+  os += "}\n";

   if (cnt > 31) {
     std::ostringstream msg;
@@ -246,9 +275,9 @@ void Compiled::eval_gpu(
   auto& s = stream();
   auto& d = metal::device(s.device);
   auto lib = d.get_library(kernel_lib_, [&]() {
-    std::ostringstream kernel;
-    kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
-           << metal::ternary_ops();
+    std::string kernel = metal::utils();
+    concatenate(
+        kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
     build_kernel(
         kernel,
         kernel_lib_ + "_contiguous",
@@ -261,7 +290,7 @@ void Compiled::eval_gpu(
         /* dynamic_dims = */ false);
     build_kernel(
         kernel,
-        kernel_lib_ + "_contiguous_big",
+        kernel_lib_ + "_contiguous_large",
         inputs_,
         outputs_,
         tape_,
@@ -282,7 +311,21 @@ void Compiled::eval_gpu(
           /* ndim = */ i,
           /* dynamic_dims = */ false,
           /* use_big_index = */ false,
-          /* work_per_thread = */ i > 3 ? WORK_PER_THREAD : 1);
+          /* work_per_thread = */ i > 3 ? 2 : 1);
+      if (i > 1) {
+        build_kernel(
+            kernel,
+            kernel_lib_ + "_strided_" + std::to_string(i) + "_large",
+            inputs_,
+            outputs_,
+            tape_,
+            constant_ids_,
+            /* contiguous = */ false,
+            /* ndim = */ i,
+            /* dynamic_dims = */ false,
+            /* use_big_index = */ true,
+            /* work_per_thread = */ i > 3 ? 4 : 1);
+      }
     }
     build_kernel(
         kernel,
@@ -295,13 +338,25 @@ void Compiled::eval_gpu(
         /* ndim = */ 0,
         /* dynamic_dims = */ true,
         /* use_big_index = */ false,
-        /* work_per_thread = */ WORK_PER_THREAD);
-    return kernel.str();
+        /* work_per_thread = */ 2);
+    build_kernel(
+        kernel,
+        kernel_lib_ + "_strided_dynamic_large",
+        inputs_,
+        outputs_,
+        tape_,
+        constant_ids_,
+        /* contiguous = */ false,
+        /* ndim = */ 0,
+        /* dynamic_dims = */ true,
+        /* use_big_index = */ true,
+        /* work_per_thread = */ 4);
+    return kernel;
   });

   // Figure out which kernel we are using
   auto& output_shape = outputs[0].shape();
-  bool contiguous = compiled_check_contiguity(inputs, output_shape);
+  auto contiguous = compiled_check_contiguity(inputs, output_shape);

   // Collapse contiguous dims to route to a faster kernel if possible. Also
   // handle all broadcasting.
@@ -349,13 +404,19 @@ void Compiled::eval_gpu(
     collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
   }

-  bool use_2d = false;
+  bool large;
   if (contiguous) {
     size_t max_size = 0;
     for (auto& in : inputs) {
       max_size = std::max(max_size, in.data_size());
     }
-    use_2d = (max_size > UINT32_MAX);
+    large = (max_size > UINT32_MAX);
+  } else {
+    size_t max_size = 0;
+    for (auto& o : outputs) {
+      max_size = std::max(max_size, o.size());
+    }
+    large = (max_size > UINT32_MAX);
   }

   // Get the kernel from the lib
@@ -368,8 +429,9 @@ void Compiled::eval_gpu(
     } else {
       kernel_name += std::to_string(shape.size());
     }
-  } else if (use_2d) {
-    kernel_name += "_big";
+  }
+  if (large) {
+    kernel_name += "_large";
   }
   auto kernel = d.get_kernel(kernel_name, lib);
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -422,7 +484,7 @@ void Compiled::eval_gpu(
     MTL::Size group_dims(
         std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);

-    MTL::Size grid_dims = use_2d
+    MTL::Size grid_dims = large
         ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
         : MTL::Size(nthreads, 1, 1);
     compute_encoder.dispatch_threads(grid_dims, group_dims);
@@ -430,7 +492,7 @@ void Compiled::eval_gpu(
     size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
     size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
     size_t rest = outputs[0].size() / (dim0 * dim1);
-    int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
+    int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
     dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     int pow2;
@@ -74,40 +74,42 @@ void copy_gpu_inplace(
   };
   auto [shape, strides_in_, strides_out_] = maybe_collapse();
   int ndim = shape.size();
-  bool use_2d = out.data_size() > UINT32_MAX;
+  bool large;
+  if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
+    // Allow for negative strides
+    large = out.data_size() > INT32_MAX;
+  } else {
+    large = out.data_size() > UINT32_MAX;
+  }
   auto& d = metal::device(s.device);
   int work_per_thread = 1;
   std::string kernel_name;
-  {
-    std::ostringstream kname;
   switch (ctype) {
     case CopyType::Scalar:
-      kname << (use_2d ? "s2" : "s");
+      kernel_name = (large ? "s2" : "s");
       break;
     case CopyType::Vector:
-      kname << (use_2d ? "v2" : "v");
+      kernel_name = (large ? "v2" : "v");
       break;
     case CopyType::General:
-      kname << "g";
+      kernel_name = "g";
       break;
     case CopyType::GeneralGeneral:
-      kname << "gg";
+      kernel_name = "gg";
       break;
   }
   if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
     if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
-      kname << shape.size();
+      kernel_name += std::to_string(shape.size());
     } else {
-      work_per_thread = 4;
-      kname << "n4";
+      work_per_thread = large ? 4 : 2;
+      concatenate(kernel_name, "n", std::to_string(work_per_thread));
+    }
+    if (large) {
+      kernel_name += "large";
     }
   }
-  kname << "_copy";
-  kname << type_to_name(in) << type_to_name(out);
-  kernel_name = kname.str();
-  }
+  concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));

   auto kernel = get_copy_kernel(d, kernel_name, in, out);

   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -159,7 +161,7 @@ void copy_gpu_inplace(
     thread_group_size = nthreads;
   }
   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-  MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
+  MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
                                : MTL::Size(nthreads, 1, 1);
   compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
@@ -193,9 +195,9 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
     return;
   }
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  bool use_2d = out.data_size() > UINT32_MAX;
+  bool large = out.data_size() > UINT32_MAX;
   auto& d = metal::device(s.device);
-  std::string kernel_name = std::string(use_2d ? "s2" : "s") + "_copy" +
+  std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
       type_to_name(val) + type_to_name(out);
   auto kernel = get_copy_kernel(d, kernel_name, val, out);
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -210,7 +212,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
     thread_group_size = nthreads;
   }
   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-  MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
+  MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
                                : MTL::Size(nthreads, 1, 1);
   compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
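One subtlety in the copy change above: general copies can carry negative strides (for example, views taken with a negative step), so per-element offsets must stay within a signed 32-bit range before the small-index variant is safe, while contiguous copies only need the offset to fit an unsigned 32-bit integer. A sketch restating that threshold rule (the helper name is hypothetical):

    #include <cstdint>

    // Signed range when strides may be negative, unsigned range otherwise.
    bool needs_64bit_indexing(size_t data_size, bool may_have_negative_strides) {
      return may_have_negative_strides ? data_size > INT32_MAX
                                       : data_size > UINT32_MAX;
    }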
@@ -53,27 +53,31 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   int idx_ndim = nidx ? inputs[1].ndim() : 0;
   size_t ndim = src.ndim();

-  std::string lib_name;
-  std::string kernel_name;
+  bool large_index = nidx && inputs[1].size() > UINT32_MAX;
+  bool large_src = src.size() > UINT32_MAX;
+  bool large_out = out.size() > UINT32_MAX;
+  bool large = large_index || large_src || large_out;
+
   std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
-  {
-    std::ostringstream kname;
-    kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
-          << "_" << idx_ndim;
-    lib_name = kname.str();
-    kernel_name = lib_name;
-  }
+  std::string kernel_name = fmt::format(
+      "gather{0}{1}_{2}_{3}_{4}",
+      type_to_name(out),
+      idx_type_name,
+      nidx,
+      idx_ndim,
+      large ? "size_t" : "uint");
+  std::string lib_name = kernel_name;

   auto lib = d.get_library(lib_name, [&]() {
-    std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::gather();
+    std::string kernel_source = metal::utils();
+    kernel_source += metal::gather();
     std::string out_type_str = get_type_string(out.dtype());
     std::string idx_type_str =
         nidx ? get_type_string(inputs[1].dtype()) : "bool";
     auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);

     // Index dimension specializations
-    kernel_source << fmt::format(
+    kernel_source += fmt::format(
         gather_kernels,
         type_to_name(out) + idx_type_name,
         out_type_str,
@@ -81,8 +85,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
         nidx,
         idx_args,
         idx_arr,
-        idx_ndim);
-    return kernel_source.str();
+        idx_ndim,
+        large ? "size_t" : "uint");
+    return kernel_source;
   });

   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -209,8 +214,6 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     nwork = 32;
   }

-  std::string lib_name;
-  std::string kernel_name;
   std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
   std::string op_name;
   switch (reduce_type_) {
@@ -231,18 +234,24 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
       break;
   }
   auto upd_contig = upd.flags().row_contiguous;
-  {
-    std::ostringstream kname;
-    kname << "scatter" << type_to_name(out) << idx_type_name;
-    kname << "_" << op_name << "_" << nidx << "_"
-          << (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork;
-    lib_name = kname.str();
-    kernel_name = kname.str();
-  }
+  bool large_out = out.size() > UINT32_MAX;
+  bool large_idx = nidx && (inputs[1].size() > UINT32_MAX);
+  bool large_upd = upd.size() > UINT32_MAX;
+  bool large = large_out || large_idx || large_upd;
+  std::string kernel_name = fmt::format(
+      "scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
+      type_to_name(out),
+      idx_type_name,
+      op_name,
+      nidx,
+      upd_contig ? "updc_true" : "updc_false",
+      nwork,
+      large ? "size_t" : "uint");
+  std::string lib_name = kernel_name;

   auto lib = d.get_library(lib_name, [&]() {
-    std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::reduce_utils()
-                  << metal::scatter();
+    std::string kernel_source = metal::utils();
+    concatenate(kernel_source, metal::reduce_utils(), metal::scatter());

     std::string out_type_str = get_type_string(out.dtype());
     std::string idx_type_str =
@@ -270,7 +279,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
     auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);

-    kernel_source << fmt::format(
+    kernel_source += fmt::format(
         scatter_kernels,
         type_to_name(out) + idx_type_name + "_" + op_name,
         out_type_str,
@@ -280,8 +289,9 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
         idx_args,
         idx_arr,
         upd_contig,
-        nwork);
-    return kernel_source.str();
+        nwork,
+        large ? "size_t" : "uint");
+    return kernel_source;
   });

   auto& compute_encoder = d.get_command_encoder(s.index);
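Gather and scatter now bake the chosen index type into the kernel (and library) name, so each specialization JIT-compiles and caches independently. A sketch of the resulting name, using the format string from the diff above with made-up argument values:

    #include <fmt/format.h>
    #include <string>

    std::string example_scatter_name() {
      bool large = false;  // every participating array fits 32-bit indexing
      // Example values only; the real ones come from the output, index, and
      // update arrays at dispatch time.
      return fmt::format(
          "scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
          "float32", "uint32", "sum", 1, "updc_true", 32,
          large ? "size_t" : "uint");
      // -> "scatterfloat32uint32_sum_1_updc_true_nwork32_uint"
    }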
@@ -1,7 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.

 constexpr std::string_view gather_kernels = R"(
-[[kernel]] void gather{0}_{3}_{6}(
+[[kernel]] void gather{0}_{3}_{6}_{7}(
     const device {1}* src [[buffer(0)]],
     device {1}* out [[buffer(1)]],
     const constant int* src_shape [[buffer(2)]],
@@ -19,7 +19,7 @@ constexpr std::string_view gather_kernels = R"(
   Indices<{2}, {3}> idxs{{
       {{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};

-  return gather_impl<{1}, {2}, {3}, {6}>(
+  return gather_impl<{1}, {2}, {3}, {6}, {7}>(
       src,
       out,
       src_shape,
@@ -34,7 +34,7 @@ constexpr std::string_view gather_kernels = R"(
 )";

 constexpr std::string_view scatter_kernels = R"(
-[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}(
+[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}(
     const device {1}* updates [[buffer(1)]],
     device mlx_atomic<{1}>* out [[buffer(2)]],
     const constant int* upd_shape [[buffer(3)]],
@@ -54,7 +54,7 @@ constexpr std::string_view scatter_kernels = R"(
     uint2 gid [[thread_position_in_grid]]) {{
   Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};

-  return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}>(
+  return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>(
       updates,
       out,
       upd_shape,
@@ -46,25 +46,27 @@ MTL::ComputePipelineState* get_unary_kernel(
   auto lib = d.get_library(lib_name, [&]() {
     auto in_t = get_type_string(in_type);
     auto out_t = get_type_string(out_type);
-    std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
-    kernel_source << get_template_definition(
-        "v_" + lib_name, "unary_v", in_t, out_t, op);
-    kernel_source << get_template_definition(
-        "v2_" + lib_name, "unary_v2", in_t, out_t, op);
-    kernel_source << get_template_definition(
-        "gn4_" + lib_name, "unary_g", in_t, out_t, op, 4);
-    return kernel_source.str();
+    std::string kernel_source = metal::utils();
+    concatenate(kernel_source, metal::unary_ops(), metal::unary());
+    kernel_source +=
+        get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
+    kernel_source +=
+        get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
+    kernel_source += get_template_definition(
+        "gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint");
+    kernel_source += get_template_definition(
+        "gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
+    return kernel_source;
   });
   return d.get_kernel(kernel_name, lib);
 }

-void add_binary_kernels(
+void append_binary_kernels(
     const std::string lib_name,
     Dtype in_type,
     Dtype out_type,
     const std::string op,
-    std::ostringstream& kernel_source) {
+    std::string& kernel_source) {
   const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
       {"ss", "binary_ss"},
       {"vs", "binary_vs"},
@@ -74,26 +76,24 @@ void add_binary_kernels(
       {"sv2", "binary_sv2"},
       {"vv2", "binary_vv2"},
       {"g1", "binary_g_nd1"},
-      {"g2", "binary_g_nd2"},
-      {"g3", "binary_g_nd3"},
+      {"g2large", "binary_g_nd2"},
+      {"g3large", "binary_g_nd3"},
   }};
+  auto in_t = get_type_string(in_type);
+  auto out_t = get_type_string(out_type);

   for (auto& [name, func] : kernel_types) {
-    std::string template_def;
-    template_def = get_template_definition(
-        name + "_" + lib_name,
-        func,
-        get_type_string(in_type),
-        get_type_string(out_type),
-        op);
-    kernel_source << template_def;
+    kernel_source +=
+        get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
   }
-  kernel_source << get_template_definition(
-      "gn4_" + lib_name,
-      "binary_g",
-      get_type_string(in_type),
-      get_type_string(out_type),
-      op,
-      4);
+  kernel_source += get_template_definition(
+      "g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint");
+  kernel_source += get_template_definition(
+      "g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint");
+  kernel_source += get_template_definition(
+      "gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint");
+  kernel_source += get_template_definition(
+      "gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
 }

 MTL::ComputePipelineState* get_binary_kernel(
@@ -104,10 +104,11 @@ MTL::ComputePipelineState* get_binary_kernel(
     const std::string op) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name, [&]() {
-    std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
-    add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
-    return kernel_source.str();
+    std::string kernel_source;
+    kernel_source = metal::utils();
+    concatenate(kernel_source, metal::binary_ops(), metal::binary());
+    append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
+    return kernel_source;
   });
   return d.get_kernel(kernel_name, lib);
 }
@@ -120,11 +121,10 @@ MTL::ComputePipelineState* get_binary_two_kernel(
     const std::string op) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name, [&]() {
-    std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::binary_ops()
-                  << metal::binary_two();
-    add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
-    return kernel_source.str();
+    std::string kernel_source = metal::utils();
+    concatenate(kernel_source, metal::binary_ops(), metal::binary_two());
+    append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
+    return kernel_source;
   });
   return d.get_kernel(kernel_name, lib);
 }
@@ -136,24 +136,29 @@ MTL::ComputePipelineState* get_ternary_kernel(
     const std::string op) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name, [&]() {
-    std::ostringstream kernel_source;
+    auto t_str = get_type_string(type);
+    std::string kernel_source = metal::utils();
+    concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
     const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
         {"v", "ternary_v"},
         {"v2", "ternary_v2"},
         {"g1", "ternary_g_nd1"},
-        {"g2", "ternary_g_nd2"},
-        {"g3", "ternary_g_nd3"},
+        {"g2large", "ternary_g_nd2"},
+        {"g3large", "ternary_g_nd3"},
     }};
-    kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
     for (auto& [name, func] : kernel_types) {
-      std::string template_def;
-      template_def = get_template_definition(
-          name + "_" + lib_name, func, get_type_string(type), op);
-      kernel_source << template_def;
+      kernel_source +=
+          get_template_definition(name + "_" + lib_name, func, t_str, op);
     }
-    kernel_source << get_template_definition(
-        "gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
-    return kernel_source.str();
+    kernel_source += get_template_definition(
+        "g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint");
+    kernel_source += get_template_definition(
+        "g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint");
+    kernel_source += get_template_definition(
+        "gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint");
+    kernel_source += get_template_definition(
+        "gn4large_" + lib_name, "ternary_g", t_str, op, 4);
+    return kernel_source;
   });
   return d.get_kernel(kernel_name, lib);
 }
@@ -165,31 +170,43 @@ MTL::ComputePipelineState* get_copy_kernel(
     const array& out) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name, [&]() {
-    std::ostringstream kernel_source;
+    std::string kernel_source = metal::utils();
+    kernel_source += metal::copy();
     auto in_type = get_type_string(in.dtype());
     auto out_type = get_type_string(out.dtype());
-    kernel_source << metal::utils() << metal::copy()
-                  << get_template_definition(
-                         "s_" + lib_name, "copy_s", in_type, out_type)
-                  << get_template_definition(
-                         "v_" + lib_name, "copy_v", in_type, out_type)
-                  << get_template_definition(
-                         "g1_" + lib_name, "copy_g_nd1", in_type, out_type)
-                  << get_template_definition(
-                         "g2_" + lib_name, "copy_g_nd2", in_type, out_type)
-                  << get_template_definition(
-                         "g3_" + lib_name, "copy_g_nd3", in_type, out_type)
-                  << get_template_definition(
-                         "gn4_" + lib_name, "copy_g", in_type, out_type, 4)
-                  << get_template_definition(
-                         "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
-                  << get_template_definition(
-                         "gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
-                  << get_template_definition(
-                         "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
-                  << get_template_definition(
-                         "ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
-    return kernel_source.str();
+    kernel_source +=
+        get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
+    kernel_source +=
+        get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
+    kernel_source += get_template_definition(
+        "g1_" + lib_name, "copy_g_nd1", in_type, out_type);
+    kernel_source += get_template_definition(
+        "g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int");
+    kernel_source += get_template_definition(
+        "g3_" + lib_name, "copy_g_nd3", in_type, out_type, "int");
+    kernel_source += get_template_definition(
+        "gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int");
+    kernel_source += get_template_definition(
+        "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type);
+    kernel_source += get_template_definition(
+        "gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int");
+    kernel_source += get_template_definition(
+        "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int");
+    kernel_source += get_template_definition(
+        "ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int");
+    kernel_source += get_template_definition(
+        "g2large_" + lib_name, "copy_g_nd2", in_type, out_type);
+    kernel_source += get_template_definition(
+        "g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
+    kernel_source += get_template_definition(
+        "gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
+    kernel_source += get_template_definition(
+        "gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
+    kernel_source += get_template_definition(
+        "gg3large_" + lib_name, "copy_gg_nd3", in_type, out_type);
+    kernel_source += get_template_definition(
+        "ggn4large_" + lib_name, "copy_gg", in_type, out_type, 4);
+    return kernel_source;
   });
   return d.get_kernel(kernel_name, lib);
 }
@@ -77,12 +77,12 @@ template <typename T, typename U, typename Op>
     constant const size_t& a_stride,
     constant const size_t& b_stride,
     uint index [[thread_position_in_grid]]) {
-  auto a_idx = elem_to_loc_1(index, a_stride);
-  auto b_idx = elem_to_loc_1(index, b_stride);
+  auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
+  auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
   c[index] = Op()(a[a_idx], b[b_idx]);
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, typename IdxT = size_t>
 [[kernel]] void binary_g_nd2(
     device const T* a,
     device const T* b,
@@ -91,13 +91,13 @@ template <typename T, typename U, typename Op>
     constant const size_t b_strides[2],
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto a_idx = elem_to_loc_2(index, a_strides);
-  auto b_idx = elem_to_loc_2(index, b_strides);
-  size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
+  auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
+  auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
+  IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
   c[out_idx] = Op()(a[a_idx], b[b_idx]);
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, typename IdxT = size_t>
 [[kernel]] void binary_g_nd3(
     device const T* a,
     device const T* b,
@@ -106,14 +106,18 @@ template <typename T, typename U, typename Op>
     constant const size_t b_strides[3],
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  auto a_idx = elem_to_loc_3(index, a_strides);
-  auto b_idx = elem_to_loc_3(index, b_strides);
-  size_t out_idx =
-      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
+  auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
+  auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
+  IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
   c[out_idx] = Op()(a[a_idx], b[b_idx]);
 }

-template <typename T, typename U, typename Op, int N = 1>
+template <
+    typename T,
+    typename U,
+    typename Op,
+    int N = 1,
+    typename IdxT = size_t>
 [[kernel]] void binary_g(
     device const T* a,
     device const T* b,
@@ -124,13 +128,12 @@ template <typename T, typename U, typename Op, int N = 1>
     constant const int& ndim,
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_2_nd(
+  auto idx = elem_to_loc_2_nd<size_t, IdxT>(
       {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
   auto xshape = shape[ndim - 1];
-  size_t out_idx =
-      N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
-  auto a_xstride = a_strides[ndim - 1];
-  auto b_xstride = b_strides[ndim - 1];
+  IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
+  IdxT a_xstride = a_strides[ndim - 1];
+  IdxT b_xstride = b_strides[ndim - 1];
   for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
     c[out_idx++] = Op()(a[idx.x], b[idx.y]);
     idx.x += a_xstride;
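The `IdxT` template parameter threads the same idea through the element-to-location helpers: strides stay `size_t`, but the accumulating index can be a 32-bit `uint` when the array is known to be small. A C++ sketch of the assumed shape of `elem_to_loc_2` (the Metal version takes a `uint2` position; the stride order here is an assumption based on row-major layout):

    #include <cstdint>

    // Accumulate a flat offset from a 2-d position and row-major strides.
    // With IdxT = uint32_t the multiplies are 32-bit, which is cheaper on GPU.
    template <typename StrideT, typename IdxT>
    IdxT elem_to_loc_2(uint32_t x, uint32_t y, const StrideT* strides) {
      return IdxT(x) * IdxT(strides[1]) + IdxT(y) * IdxT(strides[0]);
    }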
@@ -17,10 +17,13 @@
   instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op)         \
   instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op)         \
   instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op)         \
-  instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4)        \
+  instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint)  \
+  instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4)   \
   instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op)        \
-  instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op)        \
-  instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op)        \
+  instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint)  \
+  instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op)   \
+  instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint)  \
+  instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)

 #define instantiate_binary_integer(op)                  \
   instantiate_binary_all(op, uint8, uint8_t, uint8_t)   \
diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h
@@ -99,14 +99,14 @@ template <typename T, typename U, typename Op>
     constant const size_t& a_stride,
     constant const size_t& b_stride,
     uint index [[thread_position_in_grid]]) {
-  auto a_idx = elem_to_loc_1(index, a_stride);
-  auto b_idx = elem_to_loc_1(index, b_stride);
+  auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
+  auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
   auto out = Op()(a[a_idx], b[b_idx]);
   c[index] = out[0];
   d[index] = out[1];
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, typename IdxT = size_t>
 [[kernel]] void binary_g_nd2(
     device const T* a,
     device const T* b,
@@ -116,15 +116,15 @@ template <typename T, typename U, typename Op>
     constant const size_t b_strides[2],
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto a_idx = elem_to_loc_2(index, a_strides);
-  auto b_idx = elem_to_loc_2(index, b_strides);
-  size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
+  auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
+  auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
+  IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
   auto out = Op()(a[a_idx], b[b_idx]);
   c[out_idx] = out[0];
   d[out_idx] = out[1];
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, typename IdxT = size_t>
 [[kernel]] void binary_g_nd3(
     device const T* a,
     device const T* b,
@@ -134,16 +134,20 @@ template <typename T, typename U, typename Op>
     constant const size_t b_strides[3],
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  auto a_idx = elem_to_loc_3(index, a_strides);
-  auto b_idx = elem_to_loc_3(index, b_strides);
-  size_t out_idx =
-      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
+  auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
+  auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
+  IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
   auto out = Op()(a[a_idx], b[b_idx]);
   c[out_idx] = out[0];
   d[out_idx] = out[1];
 }

-template <typename T, typename U, typename Op, int N = 1>
+template <
+    typename T,
+    typename U,
+    typename Op,
+    int N = 1,
+    typename IdxT = size_t>
 [[kernel]] void binary_g(
     device const T* a,
     device const T* b,
@@ -155,13 +159,12 @@ template <typename T, typename U, typename Op, int N = 1>
     constant const int& ndim,
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_2_nd(
+  auto idx = elem_to_loc_2_nd<size_t, IdxT>(
       {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
   auto xshape = shape[ndim - 1];
-  size_t out_idx =
-      N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
-  auto a_xstride = a_strides[ndim - 1];
-  auto b_xstride = b_strides[ndim - 1];
+  IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
+  IdxT a_xstride = a_strides[ndim - 1];
+  IdxT b_xstride = b_strides[ndim - 1];
   for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
     auto out = Op()(a[idx.x], b[idx.y]);
     c[out_idx] = out[0];
diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal
@@ -15,10 +15,13 @@
   instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
   instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
   instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
-  instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
+  instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
+  instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
   instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
-  instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
-  instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
+  instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
+  instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
+  instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
+  instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)

 #define instantiate_binary_float(op) \
   instantiate_binary_all(op, float16, half, half) \
diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h
@@ -42,36 +42,36 @@ template <typename T, typename U>
     device U* dst [[buffer(1)]],
     constant const int64_t& src_stride [[buffer(3)]],
     uint index [[thread_position_in_grid]]) {
-  auto src_idx = elem_to_loc_1(index, src_stride);
+  auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
   dst[index] = static_cast<U>(src[src_idx]);
 }

-template <typename T, typename U>
+template <typename T, typename U, typename IdxT = int64_t>
 [[kernel]] void copy_g_nd2(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
     constant const int64_t* src_strides [[buffer(3)]],
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto src_idx = elem_to_loc_2(index, src_strides);
-  int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
+  auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
+  IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }

-template <typename T, typename U>
+template <typename T, typename U, typename IdxT = int64_t>
 [[kernel]] void copy_g_nd3(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
     constant const int64_t* src_strides [[buffer(3)]],
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  auto src_idx = elem_to_loc_3(index, src_strides);
-  int64_t dst_idx =
-      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
+  auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
+  IdxT dst_idx =
+      index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }

-template <typename T, typename U, int N = 1>
+template <typename T, typename U, int N = 1, typename IdxT = int64_t>
 [[kernel]] void copy_g(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
@@ -80,17 +80,16 @@ template <typename T, typename U, int N = 1>
     constant const int& ndim [[buffer(5)]],
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  auto src_idx = elem_to_loc(
+  auto src_idx = elem_to_loc<int64_t, IdxT>(
       {N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
   if (N == 1) {
-    int64_t dst_idx =
-        index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
+    IdxT dst_idx =
+        index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
     dst[dst_idx] = static_cast<U>(src[src_idx]);
     return;
   }
   auto xshape = src_shape[ndim - 1];
-  int64_t dst_idx =
-      N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
+  IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
   auto src_xstride = src_strides[ndim - 1];
   for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
     dst[dst_idx + i] = static_cast<U>(src[src_idx]);
@@ -105,36 +104,36 @@ template <typename T, typename U>
     constant const int64_t& src_stride [[buffer(3)]],
     constant const int64_t& dst_stride [[buffer(4)]],
     uint index [[thread_position_in_grid]]) {
-  auto src_idx = elem_to_loc_1(index, src_stride);
-  auto dst_idx = elem_to_loc_1(index, dst_stride);
+  auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
+  auto dst_idx = elem_to_loc_1<int64_t, int>(index, dst_stride);
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }

-template <typename T, typename U>
+template <typename T, typename U, typename IdxT = int64_t>
 [[kernel]] void copy_gg_nd2(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
     constant const int64_t* src_strides [[buffer(3)]],
     constant const int64_t* dst_strides [[buffer(4)]],
     uint2 index [[thread_position_in_grid]]) {
-  auto src_idx = elem_to_loc_2(index, src_strides);
-  auto dst_idx = elem_to_loc_2(index, dst_strides);
+  auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
+  auto dst_idx = elem_to_loc_2<int64_t, IdxT>(index, dst_strides);
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }

-template <typename T, typename U>
+template <typename T, typename U, typename IdxT = int64_t>
 [[kernel]] void copy_gg_nd3(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
     constant const int64_t* src_strides [[buffer(3)]],
     constant const int64_t* dst_strides [[buffer(4)]],
     uint3 index [[thread_position_in_grid]]) {
-  auto src_idx = elem_to_loc_3(index, src_strides);
-  auto dst_idx = elem_to_loc_3(index, dst_strides);
+  auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
+  auto dst_idx = elem_to_loc_3<int64_t, IdxT>(index, dst_strides);
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }

-template <typename T, typename U, int N = 1>
+template <typename T, typename U, int N = 1, typename IdxT = int64_t>
 [[kernel]] void copy_gg(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
@@ -143,7 +142,7 @@ template <typename T, typename U, int N = 1>
     constant const int64_t* dst_strides [[buffer(4)]],
     constant const int& ndim [[buffer(5)]],
     uint3 index [[thread_position_in_grid]]) {
-  auto idx = elem_to_loc_2_nd(
+  auto idx = elem_to_loc_2_nd<int64_t, IdxT>(
       {N * index.x, index.y, index.z},
       src_shape,
       src_strides,
@@ -153,8 +152,8 @@ template <typename T, typename U, int N = 1>
     dst[idx.y] = static_cast<U>(src[idx.x]);
     return;
   }
-  auto src_xstride = src_strides[ndim - 1];
-  auto dst_xstride = dst_strides[ndim - 1];
+  IdxT src_xstride = src_strides[ndim - 1];
+  IdxT dst_xstride = dst_strides[ndim - 1];
   auto xshape = src_shape[ndim - 1];
   for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
     dst[idx.y] = static_cast<U>(src[idx.x]);
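
Note: unlike the binary and ternary kernels, copy specializes to signed types (`int64_t` storage, `int` for the small case). Copy is the family where strides can be negative, for example a view of a reversed axis, so the narrow index type has to be a signed `int` rather than `uint`. A host-side analogue (values made up for illustration):

    #include <cstdint>

    int main() {
      int64_t stride = -1;           // stride of a reversed 1-D view
      int offset = int(5 * stride);  // -5: representable in int, not in uint
      return offset == -5 ? 0 : 1;
    }
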
diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal
@@ -10,13 +10,19 @@
   instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
   instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
   instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
-  instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \
-  instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \
+  instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
+  instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
   instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
-  instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
-  instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
-  instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
-  instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
+  instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
+  instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
+  instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
+  instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
+  instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
+  instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
+  instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
+  instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
+  instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
+  instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)

 #define instantiate_copy_itype(itname, itype) \
   instantiate_copy_all(itname ##bool_, itype, bool) \
diff --git a/mlx/backend/metal/kernels/gather.h b/mlx/backend/metal/kernels/gather.h
@@ -4,7 +4,7 @@

 #include "mlx/backend/metal/kernels/indexing.h"

-template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
+template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
 METAL_FUNC void gather_impl(
     const device T* src [[buffer(0)]],
     device T* out [[buffer(1)]],
@@ -16,18 +16,18 @@ METAL_FUNC void gather_impl(
     const thread Indices<IdxT, NIDX>& indices,
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  size_t src_idx = 0;
+  LocT src_idx = 0;
   for (int i = 0; i < NIDX; ++i) {
-    size_t idx_loc;
+    LocT idx_loc;
     if (IDX_NDIM == 0) {
       idx_loc = 0;
     } else if (IDX_NDIM == 1) {
-      idx_loc = index.x * indices.strides[indices.ndim * i];
+      idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
     } else {
-      idx_loc = index.x * indices.strides[indices.ndim * i];
+      idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
       idx_loc += indices.row_contiguous[i]
           ? index.y
-          : elem_to_loc(
+          : elem_to_loc<size_t, LocT>(
                 index.y,
                 &indices.shapes[indices.ndim * i + 1],
                 &indices.strides[indices.ndim * i + 1],
@@ -35,17 +35,17 @@ METAL_FUNC void gather_impl(
     }
     auto ax = axes[i];
     auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
-    src_idx += idx_val * src_strides[ax];
+    src_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(src_strides[ax]);
   }

-  auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim);
+  auto src_offset =
+      elem_to_loc<size_t, LocT>(index.z, slice_sizes, src_strides, src_ndim);

-  size_t out_idx = index.z;
+  LocT out_idx = index.z;
   if (IDX_NDIM == 1) {
-    out_idx += static_cast<size_t>(grid_dim.z) * index.x;
+    out_idx += static_cast<LocT>(grid_dim.z) * index.x;
   } else if (IDX_NDIM >= 2) {
-    out_idx +=
-        grid_dim.z * (index.x * static_cast<size_t>(grid_dim.y) + index.y);
+    out_idx += grid_dim.z * (index.x * static_cast<LocT>(grid_dim.y) + index.y);
   }
   out[out_idx] = src[src_offset + src_idx];
 }
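
Note: the `static_cast<LocT>` on an operand before the multiply pins the arithmetic to `LocT`: 32-bit when the host chose `uint` (the fast path), 64-bit when it chose `size_t`. When mixing narrower index types with wider results, casting the operands rather than the result is what keeps the arithmetic in the intended width; widening after a narrow multiply is too late, as this self-contained host-side demonstration shows:

    #include <cassert>
    #include <cstdint>

    int main() {
      uint32_t idx_val = 3'000'000;
      uint32_t stride = 2'000;
      // 6'000'000'000 does not fit in 32 bits, so this product wraps
      // before the assignment widens it.
      uint64_t wrapped = idx_val * stride;
      // Widening one operand first keeps the full 64-bit product.
      uint64_t correct = uint64_t(idx_val) * stride;
      assert(wrapped != correct);
      return 0;
    }
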
diff --git a/mlx/backend/metal/kernels/indexing.h b/mlx/backend/metal/kernels/indexing.h
@@ -14,7 +14,7 @@ struct Indices {
 };

 template <typename IdxT>
-METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
+METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
   if (is_unsigned_v<IdxT>) {
     return idx;
   } else {
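
Note: `offset_neg_idx` resolves a user-supplied, possibly negative, index against one dimension's extent, and a single extent always fits in `int` (shapes are `int` throughout these kernels), so narrowing the parameter from `size_t` keeps the signed branch in 32-bit arithmetic. A hedged sketch of that branch, which falls outside this hunk (an assumption for illustration, not the diff's code):

    // e.g. idx = -1 with size = 8 resolves to 7
    template <typename IdxT>
    size_t offset_neg_idx_sketch(IdxT idx, int size) {
      return (idx < 0) ? idx + size : idx;
    }
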
diff --git a/mlx/backend/metal/kernels/scatter.h b/mlx/backend/metal/kernels/scatter.h
@@ -10,7 +10,8 @@ template <
     typename Op,
     int NIDX,
     bool UPD_ROW_CONTIG,
-    int NWORK>
+    int NWORK,
+    typename LocT>
 METAL_FUNC void scatter_impl(
     const device T* updates,
     device mlx_atomic<T>* out,
@@ -28,29 +29,31 @@ METAL_FUNC void scatter_impl(
   Op op;

   auto ind_idx = gid.y * NWORK;
-  size_t out_offset = 0;
+  LocT out_offset = 0;
   if (upd_size > 1) {
-    out_offset =
-        elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
+    out_offset = elem_to_loc<size_t, LocT>(
+        gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
   }

   for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
-    size_t out_idx = out_offset;
+    LocT out_idx = out_offset;
     for (int i = 0; i < NIDX; ++i) {
       auto idx_loc = indices.row_contiguous[i]
           ? ind_idx
-          : elem_to_loc(
+          : elem_to_loc<size_t, LocT>(
                 ind_idx,
                 &indices.shapes[indices.ndim * i],
                 &indices.strides[indices.ndim * i],
                 indices.ndim);
       auto ax = axes[i];
       auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
-      out_idx += idx_val * out_strides[ax];
+      out_idx +=
+          static_cast<LocT>(idx_val) * static_cast<LocT>(out_strides[ax]);
     }
-    auto upd_idx = ind_idx * upd_size + gid.x;
+    auto upd_idx = ind_idx * static_cast<LocT>(upd_size) + gid.x;
     if constexpr (!UPD_ROW_CONTIG) {
-      upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
+      upd_idx =
+          elem_to_loc<size_t, LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);
     }
     op.atomic_update(out, updates[upd_idx], out_idx);
   }
diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h
@@ -32,13 +32,13 @@ template <typename T, typename Op>
     constant const size_t& b_strides,
     constant const size_t& c_strides,
     uint index [[thread_position_in_grid]]) {
-  auto a_idx = elem_to_loc_1(index, a_strides);
-  auto b_idx = elem_to_loc_1(index, b_strides);
-  auto c_idx = elem_to_loc_1(index, c_strides);
+  auto a_idx = elem_to_loc_1<size_t, uint>(index, a_strides);
+  auto b_idx = elem_to_loc_1<size_t, uint>(index, b_strides);
+  auto c_idx = elem_to_loc_1<size_t, uint>(index, c_strides);
   d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
 }

-template <typename T, typename Op>
+template <typename T, typename Op, typename IdxT = size_t>
 [[kernel]] void ternary_g_nd2(
     device const bool* a,
     device const T* b,
@@ -49,14 +49,14 @@ template <typename T, typename Op>
     constant const size_t c_strides[2],
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto a_idx = elem_to_loc_2(index, a_strides);
-  auto b_idx = elem_to_loc_2(index, b_strides);
-  auto c_idx = elem_to_loc_2(index, c_strides);
-  size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
+  auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
+  auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
+  auto c_idx = elem_to_loc_2<size_t, IdxT>(index, c_strides);
+  IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
   d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
 }

-template <typename T, typename Op>
+template <typename T, typename Op, typename IdxT = size_t>
 [[kernel]] void ternary_g_nd3(
     device const bool* a,
     device const T* b,
@@ -67,15 +67,14 @@ template <typename T, typename Op>
     constant const size_t c_strides[3],
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  auto a_idx = elem_to_loc_3(index, a_strides);
-  auto b_idx = elem_to_loc_3(index, b_strides);
-  auto c_idx = elem_to_loc_3(index, c_strides);
-  size_t out_idx =
-      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
+  auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
+  auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
+  auto c_idx = elem_to_loc_3<size_t, IdxT>(index, c_strides);
+  IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
   d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
 }

-template <typename T, typename Op, int N = 1>
+template <typename T, typename Op, int N = 1, typename IdxT = size_t>
 [[kernel]] void ternary_g(
     device const bool* a,
     device const T* b,
@@ -88,7 +87,7 @@ template <typename T, typename Op, int N = 1>
     constant const int& ndim,
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_3_nd(
+  auto idx = elem_to_loc_3_nd<IdxT>(
       {N * index.x, index.y, index.z},
       shape,
       a_strides,
@@ -96,11 +95,10 @@ template <typename T, typename Op, int N = 1>
       c_strides,
       ndim);
   auto xshape = shape[ndim - 1];
-  size_t out_idx =
-      N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
-  auto a_xstride = a_strides[ndim - 1];
-  auto b_xstride = b_strides[ndim - 1];
-  auto c_xstride = c_strides[ndim - 1];
+  IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
+  IdxT a_xstride = a_strides[ndim - 1];
+  IdxT b_xstride = b_strides[ndim - 1];
+  IdxT c_xstride = c_strides[ndim - 1];
   for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
     d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]);
     idx.x += a_xstride;
diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal
@@ -11,10 +11,13 @@
 #define instantiate_ternary_all(op, tname, type) \
   instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
   instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
-  instantiate_kernel("gn4_" #op #tname, ternary_g, type, op, 4) \
+  instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \
   instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
-  instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
-  instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op)
+  instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \
+  instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \
+  instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
+  instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
+  instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \

 #define instantiate_ternary_types(op) \
   instantiate_ternary_all(op, bool_, bool) \
diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h
@@ -18,7 +18,12 @@ template <typename T, typename U, typename Op>
   out[offset] = Op()(in[offset]);
 }

-template <typename T, typename U, typename Op, int N = 1>
+template <
+    typename T,
+    typename U,
+    typename Op,
+    int N = 1,
+    typename IdxT = size_t>
 [[kernel]] void unary_g(
     device const T* in,
     device U* out,
@@ -27,12 +32,11 @@ template <typename T, typename U, typename Op, int N = 1>
     device const int& ndim,
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  auto idx =
-      elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
+  auto idx = elem_to_loc<size_t, IdxT>(
+      {N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
   auto xshape = in_shape[ndim - 1];
-  auto xstride = in_strides[ndim - 1];
-  size_t out_idx =
-      N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
+  IdxT xstride = in_strides[ndim - 1];
+  IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
   for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
     out[out_idx++] = Op()(in[idx]);
     idx += xstride;
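
Note: the `N`-per-thread loop is what makes the index math cheap to amortize: `elem_to_loc` costs O(ndim), but consecutive elements along the innermost axis differ by exactly `xstride`, so each thread pays for one full index computation and then `N - 1` single adds. A host-side analogue of the traversal (shapes and strides made up; the helper mirrors the `elem_to_loc` in kernels/utils.h):

    #include <cstdint>

    // Same recurrence as elem_to_loc in kernels/utils.h.
    int64_t elem_to_loc(
        int64_t elem, const int* shape, const int64_t* strides, int ndim) {
      int64_t loc = 0;
      for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
        loc += (elem % shape[i]) * strides[i];
        elem /= shape[i];
      }
      return loc;
    }

    int main() {
      int shape[2] = {4, 8};
      int64_t strides[2] = {16, 2};  // non-contiguous view
      // Full O(ndim) computation once, for element (2, 0)...
      int64_t loc = elem_to_loc(2 * 8, shape, strides, 2);
      for (int i = 1; i < 8; ++i) {
        loc += strides[1];  // ...then one add each for (2, 1) .. (2, 7)
      }
      return int(loc);
    }
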
diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal
@@ -8,8 +8,10 @@
 #define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
   instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
   instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
-  instantiate_kernel("gn4_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
+  instantiate_kernel( \
+      "gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, uint) \
+  instantiate_kernel( \
+      "gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)

 #define instantiate_unary_all_same(op, tname, type) \
   instantiate_unary_all(op, tname, tname, type, type)
diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h
@@ -89,44 +89,45 @@ struct Limits<complex64_t> {
 ///////////////////////////////////////////////////////////////////////////////
 // Single Array with generic dims

-template <typename stride_t>
-METAL_FUNC stride_t elem_to_loc(
+template <typename StrideT, typename IdxT = StrideT>
+METAL_FUNC IdxT elem_to_loc(
     uint elem,
     constant const int* shape,
-    constant const stride_t* strides,
+    constant const StrideT* strides,
     int ndim) {
-  stride_t loc = 0;
+  IdxT loc = 0;
   for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
-    loc += (elem % shape[i]) * strides[i];
+    loc += (elem % shape[i]) * IdxT(strides[i]);
     elem /= shape[i];
   }
   return loc;
 }

-template <typename stride_t>
-METAL_FUNC stride_t elem_to_loc(
-    stride_t elem,
+template <typename StrideT, typename IdxT = StrideT>
+METAL_FUNC IdxT elem_to_loc(
+    StrideT elem,
     constant const int* shape,
-    constant const stride_t* strides,
+    constant const StrideT* strides,
     int ndim) {
-  stride_t loc = 0;
+  IdxT loc = 0;
   for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
-    loc += (elem % shape[i]) * strides[i];
+    loc += (elem % shape[i]) * IdxT(strides[i]);
     elem /= shape[i];
   }
   return loc;
 }

 // Non templated version to handle arbitrary dims
-template <typename stride_t>
-METAL_FUNC stride_t elem_to_loc(
+template <typename StrideT, typename IdxT = StrideT>
+METAL_FUNC IdxT elem_to_loc(
     uint3 elem,
     constant const int* shape,
-    constant const stride_t* strides,
+    constant const StrideT* strides,
     int ndim) {
-  stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
+  IdxT loc =
+      elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
   for (int d = ndim - 3; d >= 0; --d) {
-    loc += (elem.z % shape[d]) * strides[d];
+    loc += (elem.z % shape[d]) * IdxT(strides[d]);
     elem.z /= shape[d];
   }
   return loc;
@@ -135,61 +136,65 @@ METAL_FUNC stride_t elem_to_loc(
 ///////////////////////////////////////////////////////////////////////////////
 // Single Array with fixed N dims

-template <typename stride_t>
-METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
-  return elem * stride;
+template <typename StrideT, typename IdxT = StrideT>
+METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const StrideT& stride) {
+  return elem * IdxT(stride);
 }

-template <typename stride_t>
-METAL_FUNC stride_t
-elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
-  return elem.x * strides[1] + elem.y * strides[0];
+template <typename StrideT, typename IdxT = StrideT>
+METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const StrideT strides[2]) {
+  return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
 }

-template <typename stride_t>
-METAL_FUNC stride_t
-elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
-  return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
+template <typename StrideT, typename IdxT = StrideT>
+METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) {
+  return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
+      elem.z * IdxT(strides[0]);
 }

 ///////////////////////////////////////////////////////////////////////////////
 // Multiple Arrays with generic dims

-template <typename stride_t>
-METAL_FUNC ulong2 elem_to_loc_2_nd(
+template <typename StrideT, typename IdxT = StrideT>
+METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
     uint3 elem,
     constant const int* shape,
-    constant const stride_t* a_strides,
-    constant const stride_t* b_strides,
+    constant const StrideT* a_strides,
+    constant const StrideT* b_strides,
     int ndim) {
-  ulong2 loc = {
-      ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
-      ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
+  vec<IdxT, 2> loc = {
+      IdxT(
+          elem.x * IdxT(a_strides[ndim - 1]) +
+          IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
+      IdxT(
+          elem.x * IdxT(b_strides[ndim - 1]) +
+          elem.y * IdxT(b_strides[ndim - 2]))};
   for (int d = ndim - 3; d >= 0; --d) {
     uint l = elem.z % shape[d];
-    loc.x += l * a_strides[d];
-    loc.y += l * b_strides[d];
+    loc.x += l * IdxT(a_strides[d]);
+    loc.y += l * IdxT(b_strides[d]);
     elem.z /= shape[d];
   }
   return loc;
 }

-METAL_FUNC ulong3 elem_to_loc_3_nd(
+template <typename IdxT = size_t>
+METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
     uint3 elem,
     constant const int* shape,
     constant const size_t* a_strides,
     constant const size_t* b_strides,
     constant const size_t* c_strides,
     int ndim) {
-  ulong3 loc = {
-      elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
-      elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2],
-      elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]};
+  vec<IdxT, 3> loc = {
+      elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
+      elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]),
+      elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])};
   for (int d = ndim - 3; d >= 0; --d) {
     uint l = elem.z % shape[d];
-    loc.x += l * a_strides[d];
-    loc.y += l * b_strides[d];
-    loc.z += l * c_strides[d];
+    loc.x += l * IdxT(a_strides[d]);
+    loc.y += l * IdxT(b_strides[d]);
+    loc.z += l * IdxT(c_strides[d]);
     elem.z /= shape[d];
   }
   return loc;
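
Note: this is the API change everything else rides on. The `elem_to_loc*` helpers now take the stride storage type (`StrideT`) and the arithmetic type (`IdxT`) separately, with `IdxT` defaulting to `StrideT` so untouched call sites keep their old 64-bit behavior. A caller that knows the array is small can request 32-bit math without changing how strides are stored:

    // Strides stay size_t in memory; the math runs in uint,
    // as the small binary/ternary kernel variants above do:
    uint a_idx = elem_to_loc_2<size_t, uint>(index, a_strides);
    // Default IdxT = StrideT: identical to the old 64-bit behavior:
    size_t b_idx = elem_to_loc_2(index, b_strides);
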
diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp
@@ -36,27 +36,31 @@ void ternary_op_gpu_inplace(
   };
   auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();

-  bool use_2d = out.data_size() > UINT_MAX;
+  bool large = out.data_size() > UINT_MAX;
   auto ndim = shape.size();
-  int work_per_thread = (topt == TernaryOpType::General) ? 4 : 1;
-  std::string kernel_name;
-  {
-    std::ostringstream kname;
-    if (topt == TernaryOpType::General) {
-      kname << "g";
-      if (shape.size() <= 3) {
-        kname << shape.size();
-      } else if (work_per_thread > 1) {
-        kname << "n" << work_per_thread;
-      }
-    } else if (use_2d) {
-      kname << "v2";
-    } else {
-      kname << "v";
-    }
-    kname << "_" << op << type_to_name(b);
-    kernel_name = kname.str();
-  }
+  int work_per_thread;
+  if (topt == TernaryOpType::General) {
+    work_per_thread = large ? 4 : 2;
+  } else {
+    work_per_thread = 1;
+  }
+  std::string kernel_name;
+  if (topt == TernaryOpType::General) {
+    kernel_name = "g";
+    if (shape.size() <= 3) {
+      kernel_name += std::to_string(shape.size());
+    } else if (work_per_thread > 1) {
+      concatenate(kernel_name, "n", std::to_string(work_per_thread));
+    }
+    if (large) {
+      kernel_name += "large";
+    }
+  } else if (large) {
+    kernel_name = "v2";
+  } else {
+    kernel_name = "v";
+  }
+  concatenate(kernel_name, "_", op, type_to_name(b));

   auto& d = metal::device(s.device);
@@ -107,7 +111,7 @@ void ternary_op_gpu_inplace(
     thread_group_size = nthreads;
   }
   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-  MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
-                               : MTL::Size(nthreads, 1, 1);
+  MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
+                              : MTL::Size(nthreads, 1, 1);
   compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
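
Note: combined with the ternary.metal instantiations above, the host logic resolves to the following kernel families. The trailing `_op` and dtype suffix is appended by the final `concatenate`; the mapping below is a worked summary of the code, with suffixes left out for brevity:

    topt != General, small      ->  v...          (contiguous)
    topt != General, large      ->  v2...         (2-D grid dispatch)
    General, ndim <= 3, small   ->  g1/g2/g3...   (IdxT = uint for 2-D/3-D)
    General, ndim <= 3, large   ->  g2large/g3large...
    General, ndim > 3, small    ->  gn2...        (work_per_thread = 2)
    General, ndim > 3, large    ->  gn4large...   (work_per_thread = 4)
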
diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp
@@ -35,16 +35,19 @@ void unary_op_gpu_inplace(
   };
   auto [shape, strides] = maybe_collapse();
   int ndim = shape.size();
-  int work_per_thread = !contig ? 4 : 1;
   size_t nthreads = contig ? in.data_size() : in.size();
-  bool use_2d = nthreads > UINT32_MAX;
+  bool large = nthreads > UINT32_MAX;
+  int work_per_thread = !contig && large ? 4 : 1;
   std::string kernel_name;
   if (contig) {
-    kernel_name = (use_2d ? "v2" : "v");
+    kernel_name = (large ? "v2" : "v");
   } else {
-    kernel_name = (work_per_thread == 4 ? "gn4" : "g");
+    kernel_name = "gn" + std::to_string(work_per_thread);
+    if (large) {
+      kernel_name += "_large";
+    }
   }
-  kernel_name += "_" + op + type_to_name(in) + type_to_name(out);
+  concatenate(kernel_name, "_", op, type_to_name(in), type_to_name(out));
   auto kernel = get_unary_kernel(d, kernel_name, in.dtype(), out.dtype(), op);
@@ -73,7 +76,7 @@ void unary_op_gpu_inplace(
     thread_group_size = nthreads;
   }
   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-  MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
-                               : MTL::Size(nthreads, 1, 1);
+  MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
+                              : MTL::Size(nthreads, 1, 1);
   compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h
@@ -61,4 +61,15 @@ inline void debug_set_primitive_buffer_label(

 std::string get_primitive_string(Primitive* primitive);

+template <typename T>
+void concatenate(std::string& acc, T first) {
+  acc += first;
+}
+
+template <typename T, typename... Args>
+void concatenate(std::string& acc, T first, Args... args) {
+  acc += first;
+  concatenate(acc, args...);
+}
+
 } // namespace mlx::core
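
Note: `concatenate` is a tiny variadic fold over `std::string::operator+=`, added so the dispatch paths can build kernel names without pulling in `std::ostringstream` and its allocations and stream state. Usage, matching the call sites above:

    std::string name = "g";
    concatenate(name, "n", std::to_string(4), "large");
    concatenate(name, "_", op, type_to_name(b));
    // name == "gn4large_" + op + type_to_name(b)
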