mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Faster indexing math in a few kernels (#1589)
* wip: faster compiled kernels * faster general unary with uint specialization * index type in compiled, unary, binary, ternary, copy * fix jit * jit fix * specialize gather + scatter * nit in docs
This commit is contained in:
@@ -77,12 +77,12 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1(index, b_stride);
|
||||
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
|
||||
c[index] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void binary_g_nd2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -91,13 +91,13 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_2(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2(index, b_strides);
|
||||
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
|
||||
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void binary_g_nd3(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -106,14 +106,18 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx =
|
||||
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N = 1>
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N = 1,
|
||||
typename IdxT = size_t>
|
||||
[[kernel]] void binary_g(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -124,13 +128,12 @@ template <typename T, typename U, typename Op, int N = 1>
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd(
|
||||
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
|
||||
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
|
||||
auto xshape = shape[ndim - 1];
|
||||
size_t out_idx =
|
||||
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto a_xstride = a_strides[ndim - 1];
|
||||
auto b_xstride = b_strides[ndim - 1];
|
||||
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
IdxT a_xstride = a_strides[ndim - 1];
|
||||
IdxT b_xstride = b_strides[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
c[out_idx++] = Op()(a[idx.x], b[idx.y]);
|
||||
idx.x += a_xstride;
|
||||
|
||||
@@ -9,18 +9,21 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_integer(op) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||
|
||||
@@ -99,14 +99,14 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1(index, b_stride);
|
||||
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
|
||||
auto out = Op()(a[a_idx], b[b_idx]);
|
||||
c[index] = out[0];
|
||||
d[index] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void binary_g_nd2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -116,15 +116,15 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_2(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2(index, b_strides);
|
||||
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
|
||||
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
|
||||
auto out = Op()(a[a_idx], b[b_idx]);
|
||||
c[out_idx] = out[0];
|
||||
d[out_idx] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void binary_g_nd3(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -134,16 +134,20 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx =
|
||||
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
auto out = Op()(a[a_idx], b[b_idx]);
|
||||
c[out_idx] = out[0];
|
||||
d[out_idx] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N = 1>
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N = 1,
|
||||
typename IdxT = size_t>
|
||||
[[kernel]] void binary_g(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@@ -155,13 +159,12 @@ template <typename T, typename U, typename Op, int N = 1>
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd(
|
||||
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
|
||||
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
|
||||
auto xshape = shape[ndim - 1];
|
||||
size_t out_idx =
|
||||
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto a_xstride = a_strides[ndim - 1];
|
||||
auto b_xstride = b_strides[ndim - 1];
|
||||
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
IdxT a_xstride = a_strides[ndim - 1];
|
||||
IdxT b_xstride = b_strides[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
auto out = Op()(a[idx.x], b[idx.y]);
|
||||
c[out_idx] = out[0];
|
||||
|
||||
@@ -7,18 +7,21 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary_two.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_float(op) \
|
||||
instantiate_binary_all(op, float16, half, half) \
|
||||
|
||||
@@ -42,36 +42,36 @@ template <typename T, typename U>
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
|
||||
dst[index] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_g_nd2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
|
||||
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
|
||||
IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_g_nd3(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
int64_t dst_idx =
|
||||
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
|
||||
IdxT dst_idx =
|
||||
index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int N = 1>
|
||||
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_g(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
@@ -80,17 +80,16 @@ template <typename T, typename U, int N = 1>
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc(
|
||||
auto src_idx = elem_to_loc<int64_t, IdxT>(
|
||||
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
|
||||
if (N == 1) {
|
||||
int64_t dst_idx =
|
||||
index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
|
||||
IdxT dst_idx =
|
||||
index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
return;
|
||||
}
|
||||
auto xshape = src_shape[ndim - 1];
|
||||
int64_t dst_idx =
|
||||
N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
|
||||
IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
auto src_xstride = src_strides[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
dst[dst_idx + i] = static_cast<U>(src[src_idx]);
|
||||
@@ -105,36 +104,36 @@ template <typename T, typename U>
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
constant const int64_t& dst_stride [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1(index, dst_stride);
|
||||
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1<int64_t, int>(index, dst_stride);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_gg_nd2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint2 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_2(index, dst_strides);
|
||||
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_2<int64_t, IdxT>(index, dst_strides);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_gg_nd3(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_3(index, dst_strides);
|
||||
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_3<int64_t, IdxT>(index, dst_strides);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int N = 1>
|
||||
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_gg(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
@@ -143,7 +142,7 @@ template <typename T, typename U, int N = 1>
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd(
|
||||
auto idx = elem_to_loc_2_nd<int64_t, IdxT>(
|
||||
{N * index.x, index.y, index.z},
|
||||
src_shape,
|
||||
src_strides,
|
||||
@@ -153,8 +152,8 @@ template <typename T, typename U, int N = 1>
|
||||
dst[idx.y] = static_cast<U>(src[idx.x]);
|
||||
return;
|
||||
}
|
||||
auto src_xstride = src_strides[ndim - 1];
|
||||
auto dst_xstride = dst_strides[ndim - 1];
|
||||
IdxT src_xstride = src_strides[ndim - 1];
|
||||
IdxT dst_xstride = dst_strides[ndim - 1];
|
||||
auto xshape = src_shape[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
dst[idx.y] = static_cast<U>(src[idx.x]);
|
||||
|
||||
@@ -4,19 +4,25 @@
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/copy.h"
|
||||
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
|
||||
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
|
||||
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
||||
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
||||
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
|
||||
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \
|
||||
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \
|
||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
|
||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
|
||||
instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
|
||||
instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
|
||||
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
|
||||
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
||||
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
||||
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
|
||||
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
|
||||
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
|
||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
|
||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
|
||||
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
|
||||
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
|
||||
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
|
||||
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
|
||||
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
|
||||
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
|
||||
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
|
||||
instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
|
||||
|
||||
#define instantiate_copy_itype(itname, itype) \
|
||||
instantiate_copy_all(itname ##bool_, itype, bool) \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/indexing.h"
|
||||
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
|
||||
METAL_FUNC void gather_impl(
|
||||
const device T* src [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
@@ -16,18 +16,18 @@ METAL_FUNC void gather_impl(
|
||||
const thread Indices<IdxT, NIDX>& indices,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
size_t src_idx = 0;
|
||||
LocT src_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
size_t idx_loc;
|
||||
LocT idx_loc;
|
||||
if (IDX_NDIM == 0) {
|
||||
idx_loc = 0;
|
||||
} else if (IDX_NDIM == 1) {
|
||||
idx_loc = index.x * indices.strides[indices.ndim * i];
|
||||
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
|
||||
} else {
|
||||
idx_loc = index.x * indices.strides[indices.ndim * i];
|
||||
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
|
||||
idx_loc += indices.row_contiguous[i]
|
||||
? index.y
|
||||
: elem_to_loc(
|
||||
: elem_to_loc<size_t, LocT>(
|
||||
index.y,
|
||||
&indices.shapes[indices.ndim * i + 1],
|
||||
&indices.strides[indices.ndim * i + 1],
|
||||
@@ -35,17 +35,17 @@ METAL_FUNC void gather_impl(
|
||||
}
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
|
||||
src_idx += idx_val * src_strides[ax];
|
||||
src_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(src_strides[ax]);
|
||||
}
|
||||
|
||||
auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim);
|
||||
auto src_offset =
|
||||
elem_to_loc<size_t, LocT>(index.z, slice_sizes, src_strides, src_ndim);
|
||||
|
||||
size_t out_idx = index.z;
|
||||
LocT out_idx = index.z;
|
||||
if (IDX_NDIM == 1) {
|
||||
out_idx += static_cast<size_t>(grid_dim.z) * index.x;
|
||||
out_idx += static_cast<LocT>(grid_dim.z) * index.x;
|
||||
} else if (IDX_NDIM >= 2) {
|
||||
out_idx +=
|
||||
grid_dim.z * (index.x * static_cast<size_t>(grid_dim.y) + index.y);
|
||||
out_idx += grid_dim.z * (index.x * static_cast<LocT>(grid_dim.y) + index.y);
|
||||
}
|
||||
out[out_idx] = src[src_offset + src_idx];
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ struct Indices {
|
||||
};
|
||||
|
||||
template <typename IdxT>
|
||||
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
|
||||
METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
|
||||
if (is_unsigned_v<IdxT>) {
|
||||
return idx;
|
||||
} else {
|
||||
|
||||
@@ -10,7 +10,8 @@ template <
|
||||
typename Op,
|
||||
int NIDX,
|
||||
bool UPD_ROW_CONTIG,
|
||||
int NWORK>
|
||||
int NWORK,
|
||||
typename LocT>
|
||||
METAL_FUNC void scatter_impl(
|
||||
const device T* updates,
|
||||
device mlx_atomic<T>* out,
|
||||
@@ -28,29 +29,31 @@ METAL_FUNC void scatter_impl(
|
||||
Op op;
|
||||
|
||||
auto ind_idx = gid.y * NWORK;
|
||||
size_t out_offset = 0;
|
||||
LocT out_offset = 0;
|
||||
if (upd_size > 1) {
|
||||
out_offset =
|
||||
elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||
out_offset = elem_to_loc<size_t, LocT>(
|
||||
gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||
}
|
||||
|
||||
for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
|
||||
size_t out_idx = out_offset;
|
||||
LocT out_idx = out_offset;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
auto idx_loc = indices.row_contiguous[i]
|
||||
? ind_idx
|
||||
: elem_to_loc(
|
||||
: elem_to_loc<size_t, LocT>(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
indices.ndim);
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
|
||||
out_idx += idx_val * out_strides[ax];
|
||||
out_idx +=
|
||||
static_cast<LocT>(idx_val) * static_cast<LocT>(out_strides[ax]);
|
||||
}
|
||||
auto upd_idx = ind_idx * upd_size + gid.x;
|
||||
auto upd_idx = ind_idx * static_cast<LocT>(upd_size) + gid.x;
|
||||
if constexpr (!UPD_ROW_CONTIG) {
|
||||
upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
|
||||
upd_idx =
|
||||
elem_to_loc<size_t, LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);
|
||||
}
|
||||
op.atomic_update(out, updates[upd_idx], out_idx);
|
||||
}
|
||||
|
||||
@@ -32,13 +32,13 @@ template <typename T, typename Op>
|
||||
constant const size_t& b_strides,
|
||||
constant const size_t& c_strides,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1(index, a_strides);
|
||||
auto b_idx = elem_to_loc_1(index, b_strides);
|
||||
auto c_idx = elem_to_loc_1(index, c_strides);
|
||||
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_strides);
|
||||
auto c_idx = elem_to_loc_1<size_t, uint>(index, c_strides);
|
||||
d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
template <typename T, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void ternary_g_nd2(
|
||||
device const bool* a,
|
||||
device const T* b,
|
||||
@@ -49,14 +49,14 @@ template <typename T, typename Op>
|
||||
constant const size_t c_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_2(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2(index, b_strides);
|
||||
auto c_idx = elem_to_loc_2(index, c_strides);
|
||||
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
|
||||
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
|
||||
auto c_idx = elem_to_loc_2<size_t, IdxT>(index, c_strides);
|
||||
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
|
||||
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
template <typename T, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void ternary_g_nd3(
|
||||
device const bool* a,
|
||||
device const T* b,
|
||||
@@ -67,15 +67,14 @@ template <typename T, typename Op>
|
||||
constant const size_t c_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
auto c_idx = elem_to_loc_3(index, c_strides);
|
||||
size_t out_idx =
|
||||
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
|
||||
auto c_idx = elem_to_loc_3<size_t, IdxT>(index, c_strides);
|
||||
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename Op, int N = 1>
|
||||
template <typename T, typename Op, int N = 1, typename IdxT = size_t>
|
||||
[[kernel]] void ternary_g(
|
||||
device const bool* a,
|
||||
device const T* b,
|
||||
@@ -88,7 +87,7 @@ template <typename T, typename Op, int N = 1>
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_3_nd(
|
||||
auto idx = elem_to_loc_3_nd<IdxT>(
|
||||
{N * index.x, index.y, index.z},
|
||||
shape,
|
||||
a_strides,
|
||||
@@ -96,11 +95,10 @@ template <typename T, typename Op, int N = 1>
|
||||
c_strides,
|
||||
ndim);
|
||||
auto xshape = shape[ndim - 1];
|
||||
size_t out_idx =
|
||||
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto a_xstride = a_strides[ndim - 1];
|
||||
auto b_xstride = b_strides[ndim - 1];
|
||||
auto c_xstride = c_strides[ndim - 1];
|
||||
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
IdxT a_xstride = a_strides[ndim - 1];
|
||||
IdxT b_xstride = b_strides[ndim - 1];
|
||||
IdxT c_xstride = c_strides[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]);
|
||||
idx.x += a_xstride;
|
||||
|
||||
@@ -8,13 +8,16 @@
|
||||
#include "mlx/backend/metal/kernels/ternary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/ternary.h"
|
||||
|
||||
#define instantiate_ternary_all(op, tname, type) \
|
||||
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
||||
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
|
||||
instantiate_kernel("gn4_" #op #tname, ternary_g, type, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
|
||||
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
|
||||
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op)
|
||||
#define instantiate_ternary_all(op, tname, type) \
|
||||
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
||||
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \
|
||||
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
|
||||
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \
|
||||
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \
|
||||
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
|
||||
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
|
||||
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
|
||||
|
||||
#define instantiate_ternary_types(op) \
|
||||
instantiate_ternary_all(op, bool_, bool) \
|
||||
|
||||
@@ -18,7 +18,12 @@ template <typename T, typename U, typename Op>
|
||||
out[offset] = Op()(in[offset]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N = 1>
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N = 1,
|
||||
typename IdxT = size_t>
|
||||
[[kernel]] void unary_g(
|
||||
device const T* in,
|
||||
device U* out,
|
||||
@@ -27,12 +32,11 @@ template <typename T, typename U, typename Op, int N = 1>
|
||||
device const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx =
|
||||
elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
|
||||
auto idx = elem_to_loc<size_t, IdxT>(
|
||||
{N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
|
||||
auto xshape = in_shape[ndim - 1];
|
||||
auto xstride = in_strides[ndim - 1];
|
||||
size_t out_idx =
|
||||
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
|
||||
IdxT xstride = in_strides[ndim - 1];
|
||||
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
out[out_idx++] = Op()(in[idx]);
|
||||
idx += xstride;
|
||||
|
||||
@@ -5,11 +5,13 @@
|
||||
#include "mlx/backend/metal/kernels/unary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/unary.h"
|
||||
|
||||
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
|
||||
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
|
||||
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
|
||||
instantiate_kernel("gn4_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
|
||||
|
||||
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
|
||||
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
|
||||
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
|
||||
instantiate_kernel( \
|
||||
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, uint) \
|
||||
instantiate_kernel( \
|
||||
"gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
|
||||
|
||||
#define instantiate_unary_all_same(op, tname, type) \
|
||||
instantiate_unary_all(op, tname, tname, type, type)
|
||||
|
||||
@@ -89,44 +89,45 @@ struct Limits<complex64_t> {
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Single Array with generic dims
|
||||
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t elem_to_loc(
|
||||
template <typename StrideT, typename IdxT = StrideT>
|
||||
METAL_FUNC IdxT elem_to_loc(
|
||||
uint elem,
|
||||
constant const int* shape,
|
||||
constant const stride_t* strides,
|
||||
constant const StrideT* strides,
|
||||
int ndim) {
|
||||
stride_t loc = 0;
|
||||
IdxT loc = 0;
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t elem_to_loc(
|
||||
stride_t elem,
|
||||
template <typename StrideT, typename IdxT = StrideT>
|
||||
METAL_FUNC IdxT elem_to_loc(
|
||||
StrideT elem,
|
||||
constant const int* shape,
|
||||
constant const stride_t* strides,
|
||||
constant const StrideT* strides,
|
||||
int ndim) {
|
||||
stride_t loc = 0;
|
||||
IdxT loc = 0;
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
// Non templated version to handle arbitrary dims
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t elem_to_loc(
|
||||
template <typename StrideT, typename IdxT = StrideT>
|
||||
METAL_FUNC IdxT elem_to_loc(
|
||||
uint3 elem,
|
||||
constant const int* shape,
|
||||
constant const stride_t* strides,
|
||||
constant const StrideT* strides,
|
||||
int ndim) {
|
||||
stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
|
||||
IdxT loc =
|
||||
elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
|
||||
for (int d = ndim - 3; d >= 0; --d) {
|
||||
loc += (elem.z % shape[d]) * strides[d];
|
||||
loc += (elem.z % shape[d]) * IdxT(strides[d]);
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
@@ -135,61 +136,65 @@ METAL_FUNC stride_t elem_to_loc(
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Single Array with fixed N dims
|
||||
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
|
||||
return elem * stride;
|
||||
template <typename StrideT, typename IdxT = StrideT>
|
||||
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const StrideT& stride) {
|
||||
return elem * IdxT(stride);
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t
|
||||
elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
|
||||
return elem.x * strides[1] + elem.y * strides[0];
|
||||
template <typename StrideT, typename IdxT = StrideT>
|
||||
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const StrideT strides[2]) {
|
||||
return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t
|
||||
elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
|
||||
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
|
||||
template <typename StrideT, typename IdxT = StrideT>
|
||||
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) {
|
||||
return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
|
||||
elem.z * IdxT(strides[0]);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Multiple Arrays with generic dims
|
||||
|
||||
template <typename stride_t>
|
||||
METAL_FUNC ulong2 elem_to_loc_2_nd(
|
||||
template <typename StrideT, typename IdxT = StrideT>
|
||||
METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
|
||||
uint3 elem,
|
||||
constant const int* shape,
|
||||
constant const stride_t* a_strides,
|
||||
constant const stride_t* b_strides,
|
||||
constant const StrideT* a_strides,
|
||||
constant const StrideT* b_strides,
|
||||
int ndim) {
|
||||
ulong2 loc = {
|
||||
ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
|
||||
ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
|
||||
vec<IdxT, 2> loc = {
|
||||
IdxT(
|
||||
elem.x * IdxT(a_strides[ndim - 1]) +
|
||||
IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
|
||||
IdxT(
|
||||
elem.x * IdxT(b_strides[ndim - 1]) +
|
||||
elem.y * IdxT(b_strides[ndim - 2]))};
|
||||
for (int d = ndim - 3; d >= 0; --d) {
|
||||
uint l = elem.z % shape[d];
|
||||
loc.x += l * a_strides[d];
|
||||
loc.y += l * b_strides[d];
|
||||
loc.x += l * IdxT(a_strides[d]);
|
||||
loc.y += l * IdxT(b_strides[d]);
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
METAL_FUNC ulong3 elem_to_loc_3_nd(
|
||||
template <typename IdxT = size_t>
|
||||
METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
|
||||
uint3 elem,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const size_t* c_strides,
|
||||
int ndim) {
|
||||
ulong3 loc = {
|
||||
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
|
||||
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2],
|
||||
elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]};
|
||||
vec<IdxT, 3> loc = {
|
||||
elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
|
||||
elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]),
|
||||
elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])};
|
||||
for (int d = ndim - 3; d >= 0; --d) {
|
||||
uint l = elem.z % shape[d];
|
||||
loc.x += l * a_strides[d];
|
||||
loc.y += l * b_strides[d];
|
||||
loc.z += l * c_strides[d];
|
||||
loc.x += l * IdxT(a_strides[d]);
|
||||
loc.y += l * IdxT(b_strides[d]);
|
||||
loc.z += l * IdxT(c_strides[d]);
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
|
||||
Reference in New Issue
Block a user