Faster indexing math in a few kernels (#1589)

* wip: faster compiled kernels

* faster general unary with uint specialization

* index type in compiled, unary, binary, ternary, copy

* fix jit

* jit fix

* specialize gather + scatter

* nit in docs
This commit is contained in:
Awni Hannun
2024-11-18 19:52:00 -08:00
committed by GitHub
parent bf481e8e5d
commit 2419edd5b2
25 changed files with 630 additions and 484 deletions

View File

@@ -77,12 +77,12 @@ template <typename T, typename U, typename Op>
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
@@ -91,13 +91,13 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
@@ -106,14 +106,18 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, int N = 1>
template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
@@ -124,13 +128,12 @@ template <typename T, typename U, typename Op, int N = 1>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
IdxT a_xstride = a_strides[ndim - 1];
IdxT b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
c[out_idx++] = Op()(a[idx.x], b[idx.y]);
idx.x += a_xstride;

View File

@@ -9,18 +9,21 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \

View File

@@ -99,14 +99,14 @@ template <typename T, typename U, typename Op>
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
auto out = Op()(a[a_idx], b[b_idx]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
@@ -116,15 +116,15 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
@@ -134,16 +134,20 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op, int N = 1>
template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
@@ -155,13 +159,12 @@ template <typename T, typename U, typename Op, int N = 1>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
IdxT a_xstride = a_strides[ndim - 1];
IdxT b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];

View File

@@ -7,18 +7,21 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \

View File

@@ -42,36 +42,36 @@ template <typename T, typename U>
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
IdxT dst_idx =
index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int N = 1>
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_g(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
@@ -80,17 +80,16 @@ template <typename T, typename U, int N = 1>
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(
auto src_idx = elem_to_loc<int64_t, IdxT>(
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
if (N == 1) {
int64_t dst_idx =
index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
IdxT dst_idx =
index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
return;
}
auto xshape = src_shape[ndim - 1];
int64_t dst_idx =
N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
auto src_xstride = src_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[dst_idx + i] = static_cast<U>(src[src_idx]);
@@ -105,36 +104,36 @@ template <typename T, typename U>
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
auto dst_idx = elem_to_loc_1(index, dst_stride);
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
auto dst_idx = elem_to_loc_1<int64_t, int>(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
auto dst_idx = elem_to_loc_2(index, dst_strides);
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_2<int64_t, IdxT>(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
auto dst_idx = elem_to_loc_3(index, dst_strides);
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_3<int64_t, IdxT>(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int N = 1>
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_gg(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
@@ -143,7 +142,7 @@ template <typename T, typename U, int N = 1>
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto idx = elem_to_loc_2_nd(
auto idx = elem_to_loc_2_nd<int64_t, IdxT>(
{N * index.x, index.y, index.z},
src_shape,
src_strides,
@@ -153,8 +152,8 @@ template <typename T, typename U, int N = 1>
dst[idx.y] = static_cast<U>(src[idx.x]);
return;
}
auto src_xstride = src_strides[ndim - 1];
auto dst_xstride = dst_strides[ndim - 1];
IdxT src_xstride = src_strides[ndim - 1];
IdxT dst_xstride = dst_strides[ndim - 1];
auto xshape = src_shape[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[idx.y] = static_cast<U>(src[idx.x]);

View File

@@ -4,19 +4,25 @@
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/copy.h"
#define instantiate_copy_all(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
#define instantiate_copy_all(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \

View File

@@ -4,7 +4,7 @@
#include "mlx/backend/metal/kernels/indexing.h"
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
METAL_FUNC void gather_impl(
const device T* src [[buffer(0)]],
device T* out [[buffer(1)]],
@@ -16,18 +16,18 @@ METAL_FUNC void gather_impl(
const thread Indices<IdxT, NIDX>& indices,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
size_t src_idx = 0;
LocT src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
size_t idx_loc;
LocT idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = index.x * indices.strides[indices.ndim * i];
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
} else {
idx_loc = index.x * indices.strides[indices.ndim * i];
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
idx_loc += indices.row_contiguous[i]
? index.y
: elem_to_loc(
: elem_to_loc<size_t, LocT>(
index.y,
&indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1],
@@ -35,17 +35,17 @@ METAL_FUNC void gather_impl(
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
src_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(src_strides[ax]);
}
auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim);
auto src_offset =
elem_to_loc<size_t, LocT>(index.z, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.z;
LocT out_idx = index.z;
if (IDX_NDIM == 1) {
out_idx += static_cast<size_t>(grid_dim.z) * index.x;
out_idx += static_cast<LocT>(grid_dim.z) * index.x;
} else if (IDX_NDIM >= 2) {
out_idx +=
grid_dim.z * (index.x * static_cast<size_t>(grid_dim.y) + index.y);
out_idx += grid_dim.z * (index.x * static_cast<LocT>(grid_dim.y) + index.y);
}
out[out_idx] = src[src_offset + src_idx];
}

View File

@@ -14,7 +14,7 @@ struct Indices {
};
template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
if (is_unsigned_v<IdxT>) {
return idx;
} else {

View File

@@ -10,7 +10,8 @@ template <
typename Op,
int NIDX,
bool UPD_ROW_CONTIG,
int NWORK>
int NWORK,
typename LocT>
METAL_FUNC void scatter_impl(
const device T* updates,
device mlx_atomic<T>* out,
@@ -28,29 +29,31 @@ METAL_FUNC void scatter_impl(
Op op;
auto ind_idx = gid.y * NWORK;
size_t out_offset = 0;
LocT out_offset = 0;
if (upd_size > 1) {
out_offset =
elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
out_offset = elem_to_loc<size_t, LocT>(
gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
}
for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
size_t out_idx = out_offset;
LocT out_idx = out_offset;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = indices.row_contiguous[i]
? ind_idx
: elem_to_loc(
: elem_to_loc<size_t, LocT>(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
out_idx +=
static_cast<LocT>(idx_val) * static_cast<LocT>(out_strides[ax]);
}
auto upd_idx = ind_idx * upd_size + gid.x;
auto upd_idx = ind_idx * static_cast<LocT>(upd_size) + gid.x;
if constexpr (!UPD_ROW_CONTIG) {
upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
upd_idx =
elem_to_loc<size_t, LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);
}
op.atomic_update(out, updates[upd_idx], out_idx);
}

View File

@@ -32,13 +32,13 @@ template <typename T, typename Op>
constant const size_t& b_strides,
constant const size_t& c_strides,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_strides);
auto b_idx = elem_to_loc_1(index, b_strides);
auto c_idx = elem_to_loc_1(index, c_strides);
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_strides);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_strides);
auto c_idx = elem_to_loc_1<size_t, uint>(index, c_strides);
d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op>
template <typename T, typename Op, typename IdxT = size_t>
[[kernel]] void ternary_g_nd2(
device const bool* a,
device const T* b,
@@ -49,14 +49,14 @@ template <typename T, typename Op>
constant const size_t c_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
auto c_idx = elem_to_loc_2(index, c_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_2<size_t, IdxT>(index, c_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op>
template <typename T, typename Op, typename IdxT = size_t>
[[kernel]] void ternary_g_nd3(
device const bool* a,
device const T* b,
@@ -67,15 +67,14 @@ template <typename T, typename Op>
constant const size_t c_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
auto c_idx = elem_to_loc_3(index, c_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_3<size_t, IdxT>(index, c_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op, int N = 1>
template <typename T, typename Op, int N = 1, typename IdxT = size_t>
[[kernel]] void ternary_g(
device const bool* a,
device const T* b,
@@ -88,7 +87,7 @@ template <typename T, typename Op, int N = 1>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_3_nd(
auto idx = elem_to_loc_3_nd<IdxT>(
{N * index.x, index.y, index.z},
shape,
a_strides,
@@ -96,11 +95,10 @@ template <typename T, typename Op, int N = 1>
c_strides,
ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
auto c_xstride = c_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
IdxT a_xstride = a_strides[ndim - 1];
IdxT b_xstride = b_strides[ndim - 1];
IdxT c_xstride = c_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]);
idx.x += a_xstride;

View File

@@ -8,13 +8,16 @@
#include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h"
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn4_" #op #tname, ternary_g, type, op, 4) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op)
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
#define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \

View File

@@ -18,7 +18,12 @@ template <typename T, typename U, typename Op>
out[offset] = Op()(in[offset]);
}
template <typename T, typename U, typename Op, int N = 1>
template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
[[kernel]] void unary_g(
device const T* in,
device U* out,
@@ -27,12 +32,11 @@ template <typename T, typename U, typename Op, int N = 1>
device const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx =
elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
auto idx = elem_to_loc<size_t, IdxT>(
{N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
auto xshape = in_shape[ndim - 1];
auto xstride = in_strides[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
IdxT xstride = in_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
out[out_idx++] = Op()(in[idx]);
idx += xstride;

View File

@@ -5,11 +5,13 @@
#include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h"
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel("gn4_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel( \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, uint) \
instantiate_kernel( \
"gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
#define instantiate_unary_all_same(op, tname, type) \
instantiate_unary_all(op, tname, tname, type, type)

View File

@@ -89,44 +89,45 @@ struct Limits<complex64_t> {
///////////////////////////////////////////////////////////////////////////////
// Single Array with generic dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc(
uint elem,
constant const int* shape,
constant const stride_t* strides,
constant const StrideT* strides,
int ndim) {
stride_t loc = 0;
IdxT loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
stride_t elem,
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc(
StrideT elem,
constant const int* shape,
constant const stride_t* strides,
constant const StrideT* strides,
int ndim) {
stride_t loc = 0;
IdxT loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
// Non templated version to handle arbitrary dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc(
uint3 elem,
constant const int* shape,
constant const stride_t* strides,
constant const StrideT* strides,
int ndim) {
stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
IdxT loc =
elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
for (int d = ndim - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
loc += (elem.z % shape[d]) * IdxT(strides[d]);
elem.z /= shape[d];
}
return loc;
@@ -135,61 +136,65 @@ METAL_FUNC stride_t elem_to_loc(
///////////////////////////////////////////////////////////////////////////////
// Single Array with fixed N dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
return elem * stride;
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const StrideT& stride) {
return elem * IdxT(stride);
}
template <typename stride_t>
METAL_FUNC stride_t
elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
return elem.x * strides[1] + elem.y * strides[0];
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const StrideT strides[2]) {
return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
}
template <typename stride_t>
METAL_FUNC stride_t
elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) {
return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
elem.z * IdxT(strides[0]);
}
///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with generic dims
template <typename stride_t>
METAL_FUNC ulong2 elem_to_loc_2_nd(
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,
constant const stride_t* a_strides,
constant const stride_t* b_strides,
constant const StrideT* a_strides,
constant const StrideT* b_strides,
int ndim) {
ulong2 loc = {
ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
vec<IdxT, 2> loc = {
IdxT(
elem.x * IdxT(a_strides[ndim - 1]) +
IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
IdxT(
elem.x * IdxT(b_strides[ndim - 1]) +
elem.y * IdxT(b_strides[ndim - 2]))};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
loc.x += l * IdxT(a_strides[d]);
loc.y += l * IdxT(b_strides[d]);
elem.z /= shape[d];
}
return loc;
}
METAL_FUNC ulong3 elem_to_loc_3_nd(
template <typename IdxT = size_t>
METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
int ndim) {
ulong3 loc = {
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2],
elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]};
vec<IdxT, 3> loc = {
elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]),
elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
loc.z += l * c_strides[d];
loc.x += l * IdxT(a_strides[d]);
loc.y += l * IdxT(b_strides[d]);
loc.z += l * IdxT(c_strides[d]);
elem.z /= shape[d];
}
return loc;