Refactor reductions and fix scatter atomics for large sizes (#1300)

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Authored by Awni Hannun on 2024-08-22 16:03:31 -07:00; committed by GitHub
parent f9e00efe31
commit 98b6ce3460
18 changed files with 1584 additions and 1235 deletions


@ -49,7 +49,7 @@ struct ReductionPlan {
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes);
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
// Helper for the ndimensional strided loop
// Should this be in utils?


@ -19,7 +19,7 @@ std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
return std::make_pair(shape, strides);
}
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
x.flags().contiguous) {
@ -41,6 +41,14 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
}
}
// Remove singleton axes from the plan
for (int i = shape.size() - 1; i >= 0; i--) {
if (shape[i] == 1) {
shape.erase(shape.begin() + i);
strides.erase(strides.begin() + i);
}
}
if (strides.back() == 1) {
return ReductionPlan(ContiguousReduce, shape, strides);
} else if (strides.back() > 1) {
@ -63,10 +71,14 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// have a contiguous reduction.
std::vector<std::pair<int, size_t>> reductions;
for (auto a : axes) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
if (x.shape(a) > 1) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
}
}
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
return a.second > b.second;
bool a_is_zero = a.second == 0;
bool b_is_zero = b.second == 0;
return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
});
// Extract the two smallest and try to merge them in case the contiguous
// reduction can be bigger than just the last axis.
@ -98,16 +110,33 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// strides.back() are contiguous.
if (strides.back() > 1) {
int size = 1;
bool have_expand = false;
for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) {
continue;
}
if (x.strides()[i] != size) {
size_t stride_i = x.strides()[i];
int shape_i = x.shape(i);
if (stride_i == 0) {
if (shape_i == 1) {
continue;
}
have_expand = true;
break;
}
size *= x.shape(i);
if (stride_i != size && shape_i != 1) {
break;
}
size *= shape_i;
}
if (size >= strides.back()) {
// In the case of an expanded dimension we are being conservative and
// require the smallest reduction stride to be smaller than the maximum row
// contiguous size. The reason is that we can't easily know if the reduced
// axis is before or after an expanded dimension.
if (size > strides.back() || (size == strides.back() && !have_expand)) {
return ReductionPlan(GeneralStridedReduce, shape, strides);
}
}
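For readers skimming the new comparator above: zero-stride entries correspond to broadcast (expanded) reduction axes, and the sort now moves them to the front, presumably so that reductions.back() stays the smallest non-zero stride when the later merging and contiguity checks run. A minimal host-side sketch with made-up sizes (not MLX code):

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

int main() {
  // (size, stride) pairs for the reduced axes; stride 0 marks a broadcast axis.
  std::vector<std::pair<int, size_t>> reductions = {{8, 0}, {16, 4}, {4, 64}};
  std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
    bool a_is_zero = a.second == 0;
    bool b_is_zero = b.second == 0;
    return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
  });
  // Sorted order: (8, 0), (4, 64), (16, 4) -- the broadcast axis comes first,
  // the remaining axes follow by decreasing stride, and the smallest real
  // stride sits at the back.
}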


@ -104,6 +104,33 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
std::vector<array>{std::forward<Arrays>(xs)...});
}
// The single array version of the above.
inline std::tuple<std::vector<int>, std::vector<size_t>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
std::vector<int> collapsed_shape;
std::vector<size_t> collapsed_strides;
if (shape.size() > 0) {
collapsed_shape.push_back(shape[0]);
collapsed_strides.push_back(strides[0]);
for (int i = 1; i < shape.size(); i++) {
if (strides[i] * shape[i] != collapsed_strides.back() ||
collapsed_shape.back() * static_cast<size_t>(shape[i]) >
std::numeric_limits<int>::max()) {
collapsed_shape.push_back(shape[i]);
collapsed_strides.push_back(strides[i]);
} else {
collapsed_shape.back() *= shape[i];
collapsed_strides.back() = strides[i];
}
}
}
return std::make_tuple(collapsed_shape, collapsed_strides);
}
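A quick illustration of the merging rule in the single-array collapse above, assuming the helper is visible via this header (shapes and strides are made up):

// Row-contiguous (10, 2, 6) with strides (12, 6, 1): each inner axis satisfies
// strides[i] * shape[i] == previous collapsed stride, so everything merges
// into a single axis of size 120 with stride 1.
auto [cshape, cstrides] = collapse_contiguous_dims(
    std::vector<int>{10, 2, 6}, std::vector<size_t>{12, 6, 1});
// cshape == {120}, cstrides == {1}

// Padded rows (4, 5) with strides (10, 1): 1 * 5 != 10, so nothing merges and
// the shape/strides come back unchanged.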
template <typename stride_t>
inline auto check_contiguity(
const std::vector<int>& shape,


@ -37,13 +37,13 @@ struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
mlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {
return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {
atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
}
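The uint -> size_t widening throughout this file is the "scatter atomics for large sizes" part of the commit: a 32-bit offset silently wraps once a flat element index passes 2^32 - 1, so atomic updates would land on the wrong element of a sufficiently large output. A small host-side illustration (the index is hypothetical):

#include <cstdint>
#include <cstdio>

int main() {
  size_t idx = 5'000'000'000;                    // flat index into a >4G-element buffer
  uint32_t wrapped = static_cast<uint32_t>(idx); // what a uint offset would see
  std::printf("%zu -> %u\n", idx, wrapped);      // 5000000000 -> 705032704
}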
@ -51,13 +51,15 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_and_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
METAL_FUNC void mlx_atomic_fetch_or_explicit(
device mlx_atomic<T>* object,
T val,
size_t offset) {
atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
}
@ -65,7 +67,7 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_min_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
}
@ -73,7 +75,7 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_max_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
}
@ -81,7 +83,7 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_add_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
}
@ -89,7 +91,7 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
T expected = mlx_atomic_load_explicit(object, offset);
while (!mlx_atomic_compare_exchange_weak_explicit(
object, &expected, val * expected, offset)) {
@ -101,7 +103,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
device mlx_atomic<T>* object,
thread T* expected,
T val,
uint offset) {
size_t offset) {
return atomic_compare_exchange_weak_explicit(
&(object[offset].val),
expected,
@ -115,7 +117,7 @@ template <>
METAL_FUNC void mlx_atomic_fetch_min_explicit<float>(
device mlx_atomic<float>* object,
float val,
uint offset) {
size_t offset) {
float expected = mlx_atomic_load_explicit(object, offset);
while (val < expected) {
if (mlx_atomic_compare_exchange_weak_explicit(
@ -130,7 +132,7 @@ template <>
METAL_FUNC void mlx_atomic_fetch_max_explicit<float>(
device mlx_atomic<float>* object,
float val,
uint offset) {
size_t offset) {
float expected = mlx_atomic_load_explicit(object, offset);
while (val > expected) {
if (mlx_atomic_compare_exchange_weak_explicit(
@ -157,7 +159,7 @@ union uint_or_packed {
template <typename T, typename Op>
struct mlx_atomic_update_helper {
uint operator()(uint_or_packed<T> init, T update, uint elem_offset) {
uint operator()(uint_or_packed<T> init, T update, size_t elem_offset) {
Op op;
init.val[elem_offset] = op(update, init.val[elem_offset]);
return init.bits;
@ -168,9 +170,9 @@ template <typename T, typename Op>
METAL_FUNC void mlx_atomic_update_and_store(
device mlx_atomic<T>* object,
T update,
uint offset) {
uint pack_offset = offset / packing_size<T>;
uint elem_offset = offset % packing_size<T>;
size_t offset) {
size_t pack_offset = offset / packing_size<T>;
size_t elem_offset = offset % packing_size<T>;
mlx_atomic_update_helper<T, Op> helper;
uint_or_packed<T> expected;
@ -251,9 +253,9 @@ struct __Min {
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
uint pack_offset = offset / sizeof(T);
uint elem_offset = offset % sizeof(T);
mlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {
size_t pack_offset = offset / sizeof(T);
size_t elem_offset = offset % sizeof(T);
uint_or_packed<T> packed_val;
packed_val.bits =
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
@ -262,7 +264,7 @@ mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {
mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
}
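For the non-natively-atomic types in this part of the file, several values are packed into one 32-bit atomic word, so the element offset is split into a word index and a lane inside that word. A worked example with an assumed packing_size<T> of 4 (e.g. a 1-byte type), not taken from MLX itself:

#include <cstddef>
#include <cstdio>

int main() {
  // Assumed packing: four 1-byte values per 32-bit atomic word.
  size_t offset = 10;               // flat element offset
  size_t pack_offset = offset / 4;  // word 2: which atomic<uint> is touched
  size_t elem_offset = offset % 4;  // lane 2: which packed element inside that word
  std::printf("word %zu, lane %zu\n", pack_offset, elem_offset);
}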
@ -270,9 +272,9 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_and_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
uint pack_offset = offset / packing_size<T>;
uint elem_offset = offset % packing_size<T>;
size_t offset) {
size_t pack_offset = offset / packing_size<T>;
size_t elem_offset = offset % packing_size<T>;
uint_or_packed<T> identity;
identity.bits = __UINT32_MAX__;
identity.val[elem_offset] = val;
@ -282,10 +284,12 @@ METAL_FUNC void mlx_atomic_fetch_and_explicit(
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
uint pack_offset = offset / packing_size<T>;
uint elem_offset = offset % packing_size<T>;
METAL_FUNC void mlx_atomic_fetch_or_explicit(
device mlx_atomic<T>* object,
T val,
size_t offset) {
size_t pack_offset = offset / packing_size<T>;
size_t elem_offset = offset % packing_size<T>;
uint_or_packed<T> identity;
identity.bits = 0;
identity.val[elem_offset] = val;
@ -298,7 +302,7 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_min_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
}
@ -306,7 +310,7 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_max_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
}
@ -314,7 +318,7 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_add_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
}
@ -322,7 +326,7 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
}
@ -331,7 +335,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
device mlx_atomic<T>* object,
thread uint* expected,
uint val,
uint offset) {
size_t offset) {
return atomic_compare_exchange_weak_explicit(
&(object[offset].val),
expected,


@ -23,6 +23,8 @@ struct complex64_t {
// Constructors
constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
constexpr complex64_t() : real(0), imag(0) {};
constexpr complex64_t() threadgroup : real(0), imag(0) {};
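A hedged note on the added constructor: it is address-space qualified because Metal needs a threadgroup-qualified constructor to default-construct objects that live in threadgroup memory, and the refactored reduction kernels now keep complex64 accumulators in threadgroup arrays (for example a kernel-local declaration like threadgroup complex64_t shared_vals[32];), which is presumably why it appears in this commit.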
// Conversions to complex64_t
template <


@ -9,7 +9,8 @@
#endif
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int REDUCE_N_READS = 4;
static MTL_CONST constexpr int REDUCE_N_WRITES = 4;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;


@ -28,7 +28,8 @@
#define instantiate_reduce_helper_64b(inst_f, name, op) \
inst_f(name, int64, int64_t, op) \
inst_f(name, uint64, uint64_t, op)
inst_f(name, uint64, uint64_t, op) \
inst_f(name, complex64, complex64_t, op)
#define instantiate_reduce_helper_types(inst_f, name, op) \
instantiate_reduce_helper_floats(inst_f, name, op) \
@ -97,40 +98,24 @@ instantiate_init_reduce(andbool_, bool, And<bool>)
instantiate_init_reduce(orbool_, bool, Or<bool>)
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] [[kernel]] void \
all_reduce<itype, otype, op>( \
template [[host_name("all_reduce_" #name)]] \
[[kernel]] void all_reduce<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
device otype* out [[buffer(1)]], \
const constant size_t& in_size [[buffer(2)]], \
const constant size_t& row_size [[buffer(3)]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("allNoAtomics_reduce_" #name)]] [[kernel]] void \
all_reduce_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name##tname, type, type, op<type>)
#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
@ -138,153 +123,143 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("colGeneral_reduce_" #name)]] [[kernel]] void \
col_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
template [[host_name("colSmall" #dim "_reduce_" #name)]] \
[[kernel]] void col_reduce_small<itype, otype, op, dim>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
const constant int* reduce_shape [[buffer(7)]], \
const constant size_t* reduce_strides [[buffer(8)]], \
const constant int& reduce_ndim [[buffer(9)]], \
const constant size_t& non_col_reductions [[buffer(10)]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 gsize [[threadgroups_per_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[thread_position_in_grid]], \
uint3 tsize [[threads_per_grid]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template \
[[host_name("colGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
col_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
template [[host_name("colLooped" #dim "_" #bm "_" #bn "_reduce_" #name)]] \
[[kernel]] void col_reduce_looped<itype, otype, op, dim, bm, bn>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
const constant int* reduce_shape [[buffer(7)]], \
const constant size_t* reduce_strides [[buffer(8)]], \
const constant int& reduce_ndim [[buffer(9)]], \
const constant size_t& non_col_reductions [[buffer(10)]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 gsize [[threadgroups_per_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_col_reduce_small(name, itype, otype, op) \
template [[host_name("colSmall_reduce_" #name)]] [[kernel]] void \
col_reduce_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
const constant size_t& non_col_reductions [[buffer(8)]], \
const constant int* non_col_shapes [[buffer(9)]], \
const constant size_t* non_col_strides [[buffer(10)]], \
const constant int& non_col_ndim [[buffer(11)]], \
uint tid [[thread_position_in_grid]]);
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general(name ##tname, type, type, op<type>)
#define instantiate_col_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_small(name, itype, otype, op, 0) \
instantiate_col_reduce_small(name, itype, otype, op, 1) \
instantiate_col_reduce_small(name, itype, otype, op, 2) \
instantiate_col_reduce_small(name, itype, otype, op, 3) \
instantiate_col_reduce_small(name, itype, otype, op, 4) \
instantiate_col_reduce_looped(name, itype, otype, op, 0) \
instantiate_col_reduce_looped(name, itype, otype, op, 1) \
instantiate_col_reduce_looped(name, itype, otype, op, 2) \
instantiate_col_reduce_looped(name, itype, otype, op, 3) \
instantiate_col_reduce_looped(name, itype, otype, op, 4)
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_general(name##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_64b)
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or<bool>)
#define instantiate_row_reduce_small(name, itype, otype, op) \
template [[host_name("rowGeneralSmall_reduce_" #name)]] [[kernel]] void \
row_reduce_general_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint lid [[thread_position_in_grid]]); \
template [[host_name("rowGeneralMed_reduce_" #name)]] [[kernel]] void \
row_reduce_general_med<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
template [[host_name("rowSmall" #dim "_reduce_" #name)]] [[kernel]] void \
row_reduce_small<itype, otype, op, dim>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& row_size [[buffer(2)]], \
const constant size_t& non_row_reductions [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
const constant int* reduce_shape [[buffer(7)]], \
const constant size_t* reduce_strides [[buffer(8)]], \
const constant int& reduce_ndim [[buffer(9)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 gsize [[threadgroups_per_grid]], \
uint3 tid [[thread_position_in_grid]], \
uint3 tsize [[threads_per_grid]]);
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
template \
[[host_name("rowLooped" #dim "_reduce_" #name)]] [[kernel]] void \
row_reduce_looped<itype, otype, op, dim>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& row_size [[buffer(2)]], \
const constant size_t& non_row_reductions [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
const constant int* reduce_shape [[buffer(7)]], \
const constant size_t* reduce_strides [[buffer(8)]], \
const constant int& reduce_ndim [[buffer(9)]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 gsize [[threadgroups_per_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op, 0) \
instantiate_row_reduce_small(name, itype, otype, op, 1) \
instantiate_row_reduce_small(name, itype, otype, op, 2) \
instantiate_row_reduce_small(name, itype, otype, op, 3) \
instantiate_row_reduce_small(name, itype, otype, op, 4) \
instantiate_row_reduce_looped(name, itype, otype, op, 0) \
instantiate_row_reduce_looped(name, itype, otype, op, 1) \
instantiate_row_reduce_looped(name, itype, otype, op, 2) \
instantiate_row_reduce_looped(name, itype, otype, op, 3) \
instantiate_row_reduce_looped(name, itype, otype, op, 4) \
template \
[[host_name("rowGeneral_reduce_" #name)]] [[kernel]] void \
row_reduce_general<itype, otype, op>( \
[[host_name("rowSimple_reduce_" #name)]] [[kernel]] void \
row_reduce_simple<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 gsize [[threadgroups_per_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template \
[[host_name("rowGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
row_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
instantiate_row_reduce_general(name##tname, type, type, op<type>)
#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)
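For orientation, each instantiation macro stamps out one template specialization under a host-visible name that encodes the variant. Assuming the usual name/tname pairing these helpers use elsewhere (e.g. sum / float32), the mapping looks like:

instantiate_row_reduce_small(sumfloat32, float, float, Sum<float>, 2)
    -> kernel "rowSmall2_reduce_sumfloat32"
       = row_reduce_small<float, float, Sum<float>, 2>
instantiate_col_reduce_looped_tile(sumfloat32, float, float, Sum<float>, 1, 8, 128)
    -> kernel "colLooped1_8_128_reduce_sumfloat32"
       = col_reduce_looped<float, float, Sum<float>, 1, 8, 128>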


@ -5,6 +5,20 @@
#include <metal_atomic>
#include <metal_simdgroup>
#define DEFINE_SIMD_REDUCE() \
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_reduce(T val) { \
return simd_reduce_impl(val); \
} \
\
template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true> \
T simd_reduce(T val) { \
for (short i = simd_size / 2; i > 0; i /= 2) { \
val = operator()(val, simd_shuffle_down(val, i)); \
} \
return val; \
}
static constant constexpr const uint8_t simd_size = 32;
union bool4_or_uint {
@ -14,14 +28,16 @@ union bool4_or_uint {
struct None {
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
mlx_atomic_store_explicit(out, val, offset);
}
};
template <typename U = bool>
struct And {
bool simd_reduce(bool val) {
DEFINE_SIMD_REDUCE()
bool simd_reduce_impl(bool val) {
return simd_all(val);
}
@ -31,7 +47,7 @@ struct And {
device mlx_atomic<unsigned int>* out,
bool val,
int elem_idx,
int offset = 0) {
size_t offset = 0) {
if (!val) {
bool4_or_uint update;
update.b = {true, true, true, true};
@ -40,7 +56,8 @@ struct And {
}
}
void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
void
atomic_update(device mlx_atomic<bool>* out, bool val, size_t offset = 0) {
if (!val) {
mlx_atomic_store_explicit(out, val, offset);
}
@ -59,7 +76,9 @@ struct And {
template <typename U = bool>
struct Or {
bool simd_reduce(bool val) {
DEFINE_SIMD_REDUCE()
bool simd_reduce_impl(bool val) {
return simd_any(val);
}
@ -68,8 +87,8 @@ struct Or {
void atomic_update(
device mlx_atomic<unsigned int>* out,
bool val,
uint elem_idx,
uint offset = 0) {
int elem_idx,
size_t offset = 0) {
if (val) {
bool4_or_uint update;
update.b = {false, false, false, false};
@ -78,7 +97,8 @@ struct Or {
}
}
void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
void
atomic_update(device mlx_atomic<bool>* out, bool val, size_t offset = 0) {
if (val) {
mlx_atomic_store_explicit(out, val, offset);
}
@ -97,15 +117,17 @@ struct Or {
template <typename U>
struct Sum {
DEFINE_SIMD_REDUCE()
template <typename T>
T simd_reduce(T val) {
T simd_reduce_impl(T val) {
return simd_sum(val);
}
static constexpr constant U init = U(0);
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
mlx_atomic_fetch_add_explicit(out, val, offset);
}
@ -117,15 +139,17 @@ struct Sum {
template <typename U>
struct Prod {
DEFINE_SIMD_REDUCE()
template <typename T>
T simd_reduce(T val) {
T simd_reduce_impl(T val) {
return simd_product(val);
}
static constexpr constant U init = U(1);
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
mlx_atomic_fetch_mul_explicit(out, val, offset);
}
@ -137,15 +161,17 @@ struct Prod {
template <typename U>
struct Min {
DEFINE_SIMD_REDUCE()
template <typename T>
T simd_reduce(T val) {
T simd_reduce_impl(T val) {
return simd_min(val);
}
static constexpr constant U init = Limits<U>::max;
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
mlx_atomic_fetch_min_explicit(out, val, offset);
}
@ -157,15 +183,17 @@ struct Min {
template <typename U>
struct Max {
DEFINE_SIMD_REDUCE()
template <typename T>
T simd_reduce(T val) {
T simd_reduce_impl(T val) {
return simd_max(val);
}
static constexpr constant U init = Limits<U>::min;
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
mlx_atomic_fetch_max_explicit(out, val, offset);
}


@ -1,135 +1,61 @@
// Copyright © 2023-2024 Apple Inc.
///////////////////////////////////////////////////////////////////////////////
// All reduce helper
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_all_reduce(
const device T* in,
const device size_t& in_size,
uint gid,
uint grid_size) {
Op op;
U total_val = Op::init;
if (gid * N_READS < in_size) {
in += gid * N_READS;
int r = 0;
for (; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
U vals[N_READS] = {op.init};
for (int i = 0; i < N_READS; i++) {
vals[i] = static_cast<U>(in[i]);
}
for (int i = 0; i < N_READS; i++) {
total_val = op(vals[i], total_val);
}
in += grid_size * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
if (curr_idx < in_size) {
int max_reads = in_size - curr_idx;
T vals[N_READS];
for (int i = 0, idx = 0; i < N_READS; i++, idx++) {
idx = idx < max_reads ? idx : max_reads - 1;
vals[i] = in[idx];
}
for (int i = 0; i < N_READS; i++) {
U val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
}
return total_val;
}
///////////////////////////////////////////////////////////////////////////////
// All reduce kernel
///////////////////////////////////////////////////////////////////////////////
// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T* in [[buffer(0)]],
device mlx_atomic<U>* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
device U* out [[buffer(1)]],
const constant size_t& in_size [[buffer(2)]],
const constant size_t& row_size [[buffer(3)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U local_vals[simd_size];
threadgroup U shared_vals[simd_size];
U total_val =
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
U total = Op::init;
int64_t start_idx = gid.y * row_size;
int64_t actual_row =
(start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
int64_t blocks = actual_row / (lsize.x * N_READS);
int extra = actual_row - blocks * (lsize.x * N_READS);
extra -= lid.x * N_READS;
start_idx += lid.x * N_READS;
in += start_idx;
if (extra >= N_READS) {
blocks++;
extra = 0;
}
for (int64_t b = 0; b < blocks; b++) {
for (int i = 0; i < N_READS; i++) {
total = op(static_cast<U>(in[i]), total);
}
in += lsize.x * N_READS;
}
if (extra > 0) {
for (int i = 0; i < extra; i++) {
total = op(static_cast<U>(in[i]), total);
}
}
// Reduction within simd group
total_val = op.simd_reduce(total_val);
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
total = op.simd_reduce(total);
if (simd_per_group > 1) {
if (simd_lane_id == 0) {
shared_vals[simd_group_id] = total;
}
// Reduction within thread group
threadgroup_barrier(mem_flags::mem_threadgroup);
total = lid.x < simd_per_group ? shared_vals[lid.x] : op.init;
total = op.simd_reduce(total);
}
// Reduction within thread group
threadgroup_barrier(mem_flags::mem_threadgroup);
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
total_val = op.simd_reduce(total_val);
// Reduction across threadgroups
if (lid == 0) {
op.atomic_update(out, total_val);
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val =
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
// Reduction within simd group (simd_add isn't supported for uint64/int64
// types)
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Write simd group reduction results to local memory
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction of simdgroup reduction results within threadgroup.
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Reduction across threadgroups
if (lid == 0) {
out[thread_group_id] = total_val;
if (lid.x == 0) {
out[gid.y] = total;
}
}
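The blocks/extra bookkeeping above partitions each row exactly once across the threadgroup: a thread whose leftover tail is still a full N_READS chunk promotes it to an extra block, while the remaining threads read whatever partial tail is left. A small host-side check of that invariant with illustrative sizes (not MLX code):

#include <cassert>

int main() {
  const int lsize_x = 128, N_READS = 4;   // threadgroup width and reads per thread
  const long actual_row = 1300;           // row length handled by this threadgroup
  long blocks = actual_row / (lsize_x * N_READS);                   // 2 full 512-element blocks
  int base_extra = int(actual_row - blocks * (lsize_x * N_READS));  // 276 leftover elements
  long covered = 0;
  for (int lid = 0; lid < lsize_x; lid++) {
    long b = blocks;
    int extra = base_extra - lid * N_READS;
    if (extra >= N_READS) { b++; extra = 0; }          // promote a full tail chunk
    covered += b * N_READS + (extra > 0 ? extra : 0);  // full blocks + partial tail
  }
  assert(covered == actual_row);  // every element of the row is read exactly once
}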


@ -1,165 +1,298 @@
// Copyright © 2023-2024 Apple Inc.
///////////////////////////////////////////////////////////////////////////////
// Small column reduce kernel
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op>
template <
typename T,
typename U,
typename Op,
int NDIMS = 0,
int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_small(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
const constant size_t& non_col_reductions [[buffer(8)]],
const constant int* non_col_shapes [[buffer(9)]],
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]) {
// Appease the compiler
(void)out_size;
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[thread_position_in_grid]],
uint3 tsize [[threads_per_grid]]) {
Op op;
U total_val = Op::init;
looped_elem_to_loc<NDIMS> loop;
const device T* row;
auto out_idx = tid;
// Case 1:
// reduction_stride is small, reduction_size is small and non_col_reductions
// is small. Each thread computes reduction_stride outputs.
if (reduction_size * non_col_reductions < 64) {
U totals[31];
for (int i = 0; i < 31; i++) {
totals[i] = Op::init;
}
in += elem_to_loc(
out_idx,
shape + non_col_ndim,
strides + non_col_ndim,
ndim - non_col_ndim);
short stride = reduction_stride;
short size = reduction_size;
short blocks = stride / N_READS;
short extra = stride - blocks * N_READS;
for (uint i = 0; i < non_col_reductions; i++) {
size_t in_idx =
elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
size_t out_idx = tid.x + tsize.y * size_t(tid.y);
in += elem_to_loc(out_idx, shape, strides, ndim);
for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
U val = static_cast<U>(in[in_idx]);
total_val = op(total_val, val);
for (uint r = 0; r < non_col_reductions; r++) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
for (short i = 0; i < size; i++) {
for (short j = 0; j < blocks; j++) {
for (short k = 0; k < N_READS; k++) {
totals[j * N_READS + k] =
op(totals[j * N_READS + k],
static_cast<U>(row[i * stride + j * N_READS + k]));
}
}
for (short k = 0; k < extra; k++) {
totals[blocks * N_READS + k] =
op(totals[blocks * N_READS + k],
static_cast<U>(row[i * stride + blocks * N_READS + k]));
}
}
loop.next(reduce_shape, reduce_strides);
}
out += out_idx * reduction_stride;
for (short j = 0; j < stride; j++) {
out[j] = totals[j];
}
}
out[out_idx] = total_val;
}
///////////////////////////////////////////////////////////////////////////////
// Column reduce helper
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U _contiguous_strided_reduce(
const device T* in,
threadgroup U* local_data,
uint in_idx,
uint reduction_size,
uint reduction_stride,
uint2 tid,
uint2 lid,
uint2 lsize) {
Op op;
U total_val = Op::init;
uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
for (uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
uint offset = base_offset + r;
total_val =
op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]);
}
local_data[lsize.y * lid.x + lid.y] = total_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
U val = Op::init;
if (lid.y == 0) {
// Perform reduction across columns in thread group
for (uint i = 0; i < lsize.y; i++) {
val = op(val, local_data[lsize.y * lid.x + i]);
// Case 2:
// Reduction stride is small but everything else can be big. We loop both
// across reduction size and non_col_reductions. Each simdgroup produces
// N_READS outputs.
else {
threadgroup U shared_vals[1024];
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = Op::init;
}
}
return val;
}
short stride = reduction_stride;
short lid = simd_group_id * simd_size + simd_lane_id;
short2 tile((stride + N_READS - 1) / N_READS, 32);
short2 offset((lid % tile.x) * N_READS, lid / tile.x);
short sm_stride = tile.x * N_READS;
bool safe = offset.x + N_READS <= stride;
///////////////////////////////////////////////////////////////////////////////
// Column reduce kernel
///////////////////////////////////////////////////////////////////////////////
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x;
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general(
const device T* in [[buffer(0)]],
device mlx_atomic<U>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
// Read cooperatively and contiguously and aggregate the partial results.
size_t total = non_col_reductions * reduction_size;
loop.next(offset.y, reduce_shape, reduce_strides);
for (size_t r = offset.y; r < total; r += simd_size) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
Op op;
if (out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
if (safe) {
for (int i = 0; i < N_READS; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]);
}
} else {
U vals[N_READS];
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset.x + i < stride) ? static_cast<U>(row[i]) : op.init;
}
for (int i = 0; i < N_READS; i++) {
totals[i] = op(vals[i], totals[i]);
}
}
// Write out reduction results generated by threadgroups working on specific
// output element, contiguously.
if (lid.y == 0) {
op.atomic_update(out, val, out_idx);
loop.next(simd_size, reduce_shape, reduce_strides);
}
// Each thread holds N_READS partial results but the simdgroups are not
// aligned to do the reduction across the simdgroup so we write our results
// in the shared memory and read them back according to the simdgroup.
for (int i = 0; i < N_READS; i++) {
shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_READS; i++) {
totals[i] = op.simd_reduce(
shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);
}
// Write the output.
if (simd_lane_id == 0) {
short column = simd_group_id * N_READS;
out += out_idx * reduction_stride + column;
if (column + N_READS <= stride) {
for (int i = 0; i < N_READS; i++) {
out[i] = totals[i];
}
} else {
for (int i = 0; column + i < stride; i++) {
out[i] = totals[i];
}
}
}
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general_no_atomics(
/**
* Our approach is the following simple looped approach:
* 1. Each thread keeps running totals for BN / n_simdgroups outputs.
* 2. Load a tile BM, BN in shared memory.
* 3. Add the values from shared memory to the current running totals.
* Neighboring threads access different rows (transposed access).
* 4. Move ahead to the next tile until the M axis is exhausted.
* 5. Move ahead to the next non column reduction.
* 6. Simd reduce the running totals
* 7. Write them to the output
*
* The kernel becomes verbose because we support all kinds of OOB checks. For
* instance if we choose that reduction_stride must be larger than BN then we
* can get rid of half the kernel.
*/
template <
typename T,
typename U,
typename Op,
int NDIMS = 0,
int BM = 8,
int BN = 128>
[[kernel]] void col_reduce_looped(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
constexpr int n_simdgroups = 4;
constexpr short tgp_size = n_simdgroups * simd_size;
constexpr short n_reads = (BM * BN) / tgp_size;
constexpr short n_read_blocks = BN / n_reads;
if (out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
threadgroup U shared_vals[BN * BM];
U totals[n_reads];
looped_elem_to_loc<NDIMS> loop;
const device T* row;
// Write out reduction results generated by threadgroups working on specific
// output element, contiguously.
if (lid.y == 0) {
uint tgsize_y = ceildiv(gsize.y, lsize.y);
uint tgsize_z = ceildiv(gsize.z, lsize.z);
out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
for (int i = 0; i < n_reads; i++) {
totals[i] = Op::init;
}
short lid = simd_group_id * simd_size + simd_lane_id;
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
size_t column = BN * gid.x + offset.x;
bool safe = column + n_reads <= reduction_stride;
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + column;
size_t total = non_col_reductions * reduction_size;
loop.next(offset.y, reduce_shape, reduce_strides);
for (size_t r = offset.y; r < total; r += BM) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (int i = 0; i < n_reads; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]);
}
} else {
U vals[n_reads];
for (int i = 0; i < n_reads; i++) {
vals[i] =
(column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
}
for (int i = 0; i < n_reads; i++) {
totals[i] = op(vals[i], totals[i]);
}
}
loop.next(BM, reduce_shape, reduce_strides);
}
// We can use a simd reduction to accumulate across BM so each thread writes
// the partial output to SM and then each simdgroup does BN / n_simdgroups
// accumulations.
if (BM == 32) {
constexpr int n_outputs = BN / n_simdgroups;
static_assert(
BM != 32 || n_outputs == n_reads,
"The tile should be selected such that n_outputs == n_reads");
for (int i = 0; i < n_reads; i++) {
shared_vals[offset.y * BN + offset.x + i] = totals[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
for (int i = 0; i < n_outputs; i++) {
totals[i] =
op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
}
// Write the output.
if (simd_lane_id == 0) {
size_t out_column = BN * gid.x + out_offset.x;
out += out_idx * reduction_stride + out_column;
if (out_column + n_outputs <= reduction_stride) {
for (int i = 0; i < n_outputs; i++) {
out[i] = totals[i];
}
} else {
for (int i = 0; out_column + i < reduction_stride; i++) {
out[i] = totals[i];
}
}
}
}
// Each thread holds n_reads partial results. We write them all out to shared
// memory and threads with offset.y == 0 aggregate the columns and write the
// outputs.
else {
short x_block = offset.x / n_reads;
for (int i = 0; i < n_reads; i++) {
shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (offset.y == 0) {
for (int i = 0; i < n_reads; i++) {
for (int j = 1; j < BM; j++) {
totals[i] =
op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);
}
}
}
// Write the output.
if (offset.y == 0) {
out += out_idx * reduction_stride + column;
if (safe) {
for (int i = 0; i < n_reads; i++) {
out[i] = totals[i];
}
} else {
for (int i = 0; column + i < reduction_stride; i++) {
out[i] = totals[i];
}
}
}
}
}
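Plugging the two instantiated tiles into the constants above (n_simdgroups = 4, simd_size = 32, so tgp_size = 128 threads) gives:

BM = 8,  BN = 128: n_reads = 8 * 128 / 128 = 8, n_read_blocks = 128 / 8 = 16
    each thread owns 8 adjacent columns; lid % 16 picks the column block and
    lid / 16 the row of the 8-row tile, and this tile takes the else branch
    where the offset.y == 0 threads aggregate the columns.
BM = 32, BN = 32:  n_reads = 32 * 32 / 128 = 8, n_read_blocks = 32 / 8 = 4
    n_outputs = BN / n_simdgroups = 8 == n_reads, which is exactly what the
    static_assert in the BM == 32 branch requires.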


@ -1,287 +1,366 @@
// Copyright © 2023-2024 Apple Inc.
///////////////////////////////////////////////////////////////////////////////
// Small row reductions
///////////////////////////////////////////////////////////////////////////////
// Row reduction utilities
// - `per_thread_row_reduce` collaborative partial reduction in the threadgroup
// - `threadgroup_reduce` collaborative reduction in the threadgroup such that
// lid.x == 0 holds the reduced value
// - `thread_reduce` simple loop and reduce the row
// Each thread reduces for one output
template <typename T, typename U, typename Op>
[[kernel]] void row_reduce_general_small(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]) {
/**
* The thread group collaboratively reduces across the rows with bounds
* checking. In the end each thread holds a part of the reduction.
*/
template <
typename T,
typename U,
typename Op,
int N_READS = REDUCE_N_READS,
int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
thread U totals[N_WRITES],
const device T* inputs[N_WRITES],
int blocks,
int extra,
uint lsize_x,
uint lid_x) {
Op op;
uint out_idx = lid;
if (out_idx >= out_size) {
return;
// Set up the accumulator registers
for (int i = 0; i < N_WRITES; i++) {
totals[i] = Op::init;
}
U total_val = Op::init;
// Loop over the reduction size within thread group
for (int i = 0; i < blocks; i++) {
for (int j = 0; j < N_WRITES; j++) {
for (int i = 0; i < N_READS; i++) {
totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
}
for (short r = 0; r < short(non_row_reductions); r++) {
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T* in_row = in + in_idx;
for (short i = 0; i < short(reduction_size); i++) {
total_val = op(static_cast<U>(in_row[i]), total_val);
inputs[j] += lsize_x * N_READS;
}
}
out[out_idx] = total_val;
}
// Each simdgroup reduces for one output
template <typename T, typename U, typename Op>
[[kernel]] void row_reduce_general_med(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
uint out_idx = simd_per_group * tid + simd_group_id;
if (out_idx >= out_size) {
return;
}
U total_val = Op::init;
if (short(non_row_reductions) == 1) {
uint in_idx = elem_to_loc(out_idx, shape, strides, ndim);
const device T* in_row = in + in_idx;
for (short i = simd_lane_id; i < short(reduction_size); i += 32) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
}
else if (short(non_row_reductions) >= 32) {
for (short r = simd_lane_id; r < short(non_row_reductions); r += 32) {
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T* in_row = in + in_idx;
for (short i = 0; i < short(reduction_size); i++) {
total_val = op(static_cast<U>(in_row[i]), total_val);
// Separate case for the last set as we close the reduction size
int index = lid_x * N_READS;
if (index + N_READS <= extra) {
for (int j = 0; j < N_WRITES; j++) {
for (int i = 0; i < N_READS; i++) {
totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
}
}
}
else {
const short n_reductions =
short(reduction_size) * short(non_row_reductions);
const short reductions_per_thread =
(n_reductions + simd_size - 1) / simd_size;
const short r_st = simd_lane_id / reductions_per_thread;
const short r_ed = short(non_row_reductions);
const short r_jump = simd_size / reductions_per_thread;
const short i_st = simd_lane_id % reductions_per_thread;
const short i_ed = short(reduction_size);
const short i_jump = reductions_per_thread;
if (r_st < r_jump) {
for (short r = r_st; r < r_ed; r += r_jump) {
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T* in_row = in + in_idx;
for (short i = i_st; i < i_ed; i += i_jump) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
} else {
for (int j = 0; j < N_WRITES; j++) {
for (int i = 0; index + i < extra; i++) {
totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
}
}
}
total_val = op.simd_reduce(total_val);
if (simd_lane_id == 0) {
out[out_idx] = total_val;
}
}
///////////////////////////////////////////////////////////////////////////////
// Large row reductions
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_row_reduce(
    const device T* in,
    const constant size_t& reduction_size,
    const constant size_t& out_size,
    const constant int* shape,
    const constant size_t* strides,
    const constant int& ndim,
    uint lsize_x,
    uint lid_x,
    uint2 tid) {
  Op op;
  // Each threadgroup handles 1 reduction
  // TODO: Specializing elem_to_loc would be slightly faster
  int idx = tid.y * out_size + tid.x;
  int extra_offset = elem_to_loc(idx, shape, strides, ndim);
  in += extra_offset + lid_x * N_READS;
  // The reduction is accumulated here
  U total_val = Op::init;
  // Loop over the reduction size within thread group
  int r = 0;
  for (; r < (int)ceildiv(reduction_size, N_READS * lsize_x) - 1; r++) {
    T vals[N_READS];
    for (int i = 0; i < N_READS; i++) {
      vals[i] = in[i];
    }
    for (int i = 0; i < N_READS; i++) {
      total_val = op(static_cast<U>(vals[i]), total_val);
    }
    in += lsize_x * N_READS;
  }
  // Separate case for the last set as we close the reduction size
  size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
  if (reduction_index < reduction_size) {
    int max_reads = reduction_size - reduction_index;
    T vals[N_READS];
    for (int i = 0; i < N_READS; i++) {
      int idx = min(i, max_reads - 1);
      vals[i] = static_cast<U>(in[idx]);
    }
    for (int i = 0; i < N_READS; i++) {
      T val = i < max_reads ? vals[i] : Op::init;
      total_val = op(static_cast<U>(val), total_val);
    }
  }
  return total_val;
}
/**
 * Consecutive rows in a contiguous array.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* in,
    const constant size_t& reduction_size,
    int blocks,
    int extra,
    uint lsize_x,
    uint lid_x) {
  // Set up the input pointers
  const device T* inputs[N_WRITES];
  inputs[0] = in + lid_x * N_READS;
  for (int i = 1; i < N_WRITES; i++) {
    inputs[i] = inputs[i - 1] + reduction_size;
  }
  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, inputs, blocks, extra, lsize_x, lid_x);
}
/**
 * Consecutive rows in an arbitrarily ordered array.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* in,
    const size_t row_idx,
    int blocks,
    int extra,
    const constant int* shape,
    const constant size_t* strides,
    const constant int& ndim,
    uint lsize_x,
    uint lid_x) {
  // Set up the input pointers
  const device T* inputs[N_WRITES];
  in += lid_x * N_READS;
  for (int i = 0; i < N_WRITES; i++) {
    inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim);
  }
  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, inputs, blocks, extra, lsize_x, lid_x);
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_general(
    const device T* in [[buffer(0)]],
    device mlx_atomic<U>* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  (void)non_row_reductions;
  Op op;
  threadgroup U local_vals[simd_size];
  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
      in,
      reduction_size,
      out_size,
      shape,
      strides,
      ndim,
      lsize.x,
      lid.x,
      tid.xy);
  total_val = op.simd_reduce(total_val);
  // Prepare next level
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Reduction within thread group
  // Only needed if multiple simd groups
  if (reduction_size > simd_size) {
    total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
    total_val = op.simd_reduce(total_val);
  }
  // Update output
  if (lid.x == 0) {
    op.atomic_update(out, total_val, tid.x);
  }
}
/**
 * Reduce within the threadgroup.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void threadgroup_reduce(
    thread U totals[N_WRITES],
    threadgroup U* shared_vals,
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  // Simdgroup first
  for (int i = 0; i < N_WRITES; i++) {
    totals[i] = op.simd_reduce(totals[i]);
  }
  // Across simdgroups
  if (simd_per_group > 1) {
    if (simd_lane_id == 0) {
      for (int i = 0; i < N_WRITES; i++) {
        shared_vals[simd_group_id * N_WRITES + i] = totals[i];
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    U values[N_WRITES];
    for (int i = 0; i < N_WRITES; i++) {
      values[i] = (lid.x < simd_per_group) ? shared_vals[lid.x * N_WRITES + i]
                                           : op.init;
    }
    for (int i = 0; i < N_WRITES; i++) {
      totals[i] = op.simd_reduce(values[i]);
    }
  }
}
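threadgroup_reduce is a two-level fold: each simdgroup reduces its lanes first, then the per-simdgroup partials staged in threadgroup memory are folded again by the first lanes. A CPU analog of that pattern, assuming a sum reduction and made-up lane/group counts, looks like this:

// CPU analog of the two-level threadgroup_reduce pattern above. Illustrative
// only; the sizes and the sum operator are assumptions.
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  const int simd_size = 32;
  const int simd_per_group = 4;
  const int n_threads = simd_size * simd_per_group;

  // Every thread starts with a partial value (here: its own index).
  std::vector<int> partials(n_threads);
  std::iota(partials.begin(), partials.end(), 0);

  // Level 1: each simdgroup reduces its 32 lanes (op.simd_reduce in the kernel).
  std::vector<int> shared_vals(simd_per_group, 0);
  for (int g = 0; g < simd_per_group; g++) {
    for (int l = 0; l < simd_size; l++) {
      shared_vals[g] += partials[g * simd_size + l];
    }
  }

  // Level 2: fold the per-group values that were staged in threadgroup memory.
  int total = 0;
  for (int g = 0; g < simd_per_group; g++) {
    total += shared_vals[g];
  }

  std::printf("total = %d (expected %d)\n", total, n_threads * (n_threads - 1) / 2);
  return 0;
}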
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_general_no_atomics(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  (void)non_row_reductions;
  Op op;
  threadgroup U local_vals[simd_size];
  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
      in,
      reduction_size,
      out_size,
      shape,
      strides,
      ndim,
      lsize.x,
      lid.x,
      tid.xy);
  // Reduction within simd group - simd_add isn't supported for int64 types
  for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
    total_val = op(total_val, simd_shuffle_down(total_val, i));
  }
  // Prepare next level
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Reduction within thread group
  // Only needed if thread group has multiple simd groups
  if (ceildiv(reduction_size, N_READS) > simd_size) {
    total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
    for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
      total_val = op(total_val, simd_shuffle_down(total_val, i));
    }
  }
  // Write row reduce output for threadgroup with 1st thread in thread group
  if (lid.x == 0) {
    out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
  }
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC void
thread_reduce(thread U& total, const device T* row, int blocks, int extra) {
Op op;
for (int i = 0; i < blocks; i++) {
U vals[N_READS];
for (int j = 0; j < N_READS; j++) {
vals[j] = row[j];
}
for (int j = 0; j < N_READS; j++) {
total = op(vals[j], total);
}
row += N_READS;
}
for (int i = 0; i < extra; i++) {
total = op(*row++, total);
}
}
// Reduction kernels
// - `row_reduce_small` depending on the non-row reductions and row size it
// either just loops over everything or a simd collaboratively reduces the
// non_row reductions. In the first case one thread is responsible for one
// output on the 2nd one simd is responsible for one output.
// - `row_reduce_simple` simple contiguous row reduction
// - `row_reduce_looped` simply loop and reduce each row for each non-row
// reduction. One threadgroup is responsible for one output.
template <
typename T,
typename U,
typename Op,
int NDIMS = 0,
int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_small(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& row_size [[buffer(2)]],
const constant size_t& non_row_reductions [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint3 tid [[thread_position_in_grid]],
uint3 tsize [[threads_per_grid]]) {
Op op;
U total_val = Op::init;
looped_elem_to_loc<NDIMS> loop;
// Precompute some row reduction numbers
const device T* row;
int blocks = row_size / N_READS;
int extra = row_size % N_READS;
if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
// Simple loop over non_row_reductions and reduce the row in the thread.
size_t out_idx = tid.x + tsize.y * size_t(tid.y);
in += elem_to_loc(out_idx, shape, strides, ndim);
for (uint r = 0; r < non_row_reductions; r++) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
loop.next(reduce_shape, reduce_strides);
}
out[out_idx] = total_val;
} else {
// Collaboratively reduce over non_row_reductions in the simdgroup. Each
// thread reduces every 32nd row and then a simple simd reduce.
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
in += elem_to_loc(out_idx, shape, strides, ndim);
loop.next(simd_lane_id, reduce_shape, reduce_strides);
for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
loop.next(simd_size, reduce_shape, reduce_strides);
}
total_val = op.simd_reduce(total_val);
if (simd_lane_id == 0) {
out[out_idx] = total_val;
}
}
}
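row_reduce_small therefore switches between one thread per output and one simdgroup per output based on the row size and the number of non-row reductions. A hypothetical host-side sketch of just that predicate follows; the real kernel selection lives in the C++ backend and may differ:

// Sketch of the split used inside row_reduce_small above; function name and
// driver are hypothetical.
#include <cstddef>
#include <cstdio>

// Returns true when one thread handles a whole output by looping over all
// non-row reductions, false when a whole simdgroup collaborates per output.
bool one_thread_per_output(std::size_t row_size, std::size_t non_row_reductions) {
  return (non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8;
}

int main() {
  std::printf("%d\n", one_thread_per_output(8, 4));    // 1: loop in one thread
  std::printf("%d\n", one_thread_per_output(8, 1000)); // 0: simd collaborates
  return 0;
}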
template <
typename T,
typename U,
typename Op,
int N_READS = REDUCE_N_READS,
int N_WRITES = REDUCE_N_WRITES>
[[kernel]] void row_reduce_simple(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  (void)non_row_reductions;
  threadgroup U shared_vals[simd_size * N_WRITES];
  U totals[N_WRITES];
  // Move to the row
  size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z));
  if (out_idx + N_WRITES > out_size) {
    out_idx = out_size - N_WRITES;
  }
  in += out_idx * reduction_size;
  out += out_idx;
  // Each thread reduces across the row
  int blocks = reduction_size / (lsize.x * N_READS);
  int extra = reduction_size - blocks * (lsize.x * N_READS);
  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, in, reduction_size, blocks, extra, lsize.x, lid.x);
  // Reduce across the threadgroup
  threadgroup_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);
  // Write the output
  if (lid.x == 0) {
    for (int i = 0; i < N_WRITES; i++) {
      out[i] = totals[i];
    }
  }
}
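The out_idx clamp in row_reduce_simple keeps the last threadgroup in bounds when out_size is not a multiple of N_WRITES: it shifts the final block back so its writes overlap the previous block instead of running past the end. A small sketch with assumed sizes shows the resulting write ranges:

// Sketch of the out_idx clamp above. N_WRITES and out_size are assumptions
// chosen so the tail block overlaps.
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t N_WRITES = 4;
  const std::size_t out_size = 10; // not divisible by N_WRITES

  for (std::size_t block = 0; block * N_WRITES < out_size; block++) {
    std::size_t out_idx = N_WRITES * block;
    if (out_idx + N_WRITES > out_size) {
      out_idx = out_size - N_WRITES; // shift the tail block back
    }
    std::printf("block %zu writes outputs [%zu, %zu)\n",
                block, out_idx, out_idx + N_WRITES);
  }
  return 0;
}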
template <
typename T,
typename U,
typename Op,
int NDIMS = 0,
int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_looped(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& row_size [[buffer(2)]],
const constant size_t& non_row_reductions [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U shared_vals[simd_size];
U total = Op::init;
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
// lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
// needs a small refactor.
in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS;
looped_elem_to_loc<NDIMS> loop;
const device T* row;
int blocks = row_size / (lsize.x * N_READS);
int extra = row_size - blocks * (lsize.x * N_READS);
for (size_t i = 0; i < non_row_reductions; i++) {
row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim);
// Each thread reduces across the row
U row_total;
per_thread_row_reduce<T, U, Op, N_READS, 1>(
&row_total, &row, blocks, extra, lsize.x, lid.x);
// Aggregate across rows
total = op(total, row_total);
loop.next(reduce_shape, reduce_strides);
}
// Reduce across the threadgroup
threadgroup_reduce<T, U, Op, N_READS, 1>(
&total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);
// Write the output
if (lid.x == 0) {
out[out_idx] = total;
}
}


@ -18,7 +18,7 @@ METAL_FUNC void scatter_1d_index_impl(
uint2 gid [[thread_position_in_grid]]) {
Op op;
uint out_idx = 0;
size_t out_idx = 0;
for (int i = 0; i < NIDX; i++) {
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
out_idx += idx_val * out_strides[i];
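The switch from uint to size_t matters because out_idx accumulates idx_val * out_strides[i]; once the output holds more than 2^32 elements that product overflows a 32-bit accumulator and the scatter lands at the wrong location. A standalone sketch with made-up sizes illustrates the wraparound:

// Minimal demonstration of why the scatter index above must be 64-bit. The
// stride and index values are assumptions for illustration.
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  const std::size_t out_strides0 = std::size_t(1) << 20; // 1M elements per row
  const std::size_t idx_val = 5000;                      // scatter into row 5000

  uint32_t narrow = 0;
  std::size_t wide = 0;
  narrow += uint32_t(idx_val * out_strides0); // wraps past 4,294,967,295
  wide += idx_val * out_strides0;

  std::printf("32-bit index: %u\n", narrow);
  std::printf("64-bit index: %zu\n", wide);
  return 0;
}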


@ -64,6 +64,16 @@ struct Limits<bool> {
static constexpr constant bool min = false;
};
template <>
struct Limits<complex64_t> {
static constexpr constant complex64_t max = complex64_t(
metal::numeric_limits<float>::infinity(),
metal::numeric_limits<float>::infinity());
static constexpr constant complex64_t min = complex64_t(
-metal::numeric_limits<float>::infinity(),
-metal::numeric_limits<float>::infinity());
};
///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////
@ -101,6 +111,34 @@ METAL_FUNC stride_t elem_to_loc(
return loc;
}
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
stride_t elem,
device const int* shape,
device const stride_t* strides,
int ndim) {
stride_t loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
return loc;
}
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
stride_t elem,
constant const int* shape,
constant const stride_t* strides,
int ndim) {
stride_t loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
return loc;
}
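These overloads unravel a flat element index into a memory offset by peeling off the trailing dimension at each step. A host-side copy of the same arithmetic, checked on an assumed 3x4 transposed layout:

// Host-side copy of the elem_to_loc arithmetic above, for illustration. The
// shape/strides below describe a (3, 4) view of a row-major (4, 3) buffer.
#include <cstddef>
#include <cstdio>
#include <vector>

std::size_t elem_to_loc(std::size_t elem, const std::vector<int>& shape,
                        const std::vector<std::size_t>& strides) {
  std::size_t loc = 0;
  for (int i = int(shape.size()) - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  std::vector<int> shape = {3, 4};
  std::vector<std::size_t> strides = {1, 3}; // transpose of the contiguous (4, 1)
  for (std::size_t e = 0; e < 12; e++) {
    std::printf("element %zu -> offset %zu\n", e, elem_to_loc(e, shape, strides));
  }
  return 0;
}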
// Non templated version to handle arbitrary dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
@ -288,12 +326,87 @@ METAL_FUNC uint3 elem_to_loc_3_nd(
return loc;
}
///////////////////////////////////////////////////////////////////////////////
// Elem to loc in a loop utils
///////////////////////////////////////////////////////////////////////////////
template <int dim, typename offset_t = size_t>
struct looped_elem_to_loc {
looped_elem_to_loc<dim - 1, offset_t> inner_looper;
offset_t offset{0};
int index{0};
void next(const constant int* shape, const constant size_t* strides) {
index++;
offset += strides[dim - 1];
if (index >= shape[dim - 1]) {
index = 0;
inner_looper.next(shape, strides);
offset = inner_looper.offset;
}
}
void next(int n, const constant int* shape, const constant size_t* strides) {
index += n;
offset += n * strides[dim - 1];
if (index >= shape[dim - 1]) {
int extra = index - shape[dim - 1];
index = 0;
inner_looper.next(shape, strides);
offset = inner_looper.offset;
if (extra > 0) {
next(extra, shape, strides);
}
}
}
offset_t
location(offset_t, const constant int*, const constant size_t*, int) {
return offset;
}
};
template <typename offset_t>
struct looped_elem_to_loc<1, offset_t> {
offset_t offset{0};
void next(const constant int*, const constant size_t* strides) {
offset += strides[0];
}
void next(int n, const constant int*, const constant size_t* strides) {
offset += n * strides[0];
}
offset_t
location(offset_t, const constant int*, const constant size_t*, int) {
return offset;
}
};
template <typename offset_t>
struct looped_elem_to_loc<0, offset_t> {
void next(const constant int*, const constant size_t*) {}
void next(int, const constant int*, const constant size_t*) {}
offset_t location(
offset_t idx,
const constant int* shape,
const constant size_t* strides,
int ndim) {
return elem_to_loc(idx, shape, strides, ndim);
}
};
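looped_elem_to_loc keeps per-dimension counters so the offset of consecutive reduction elements can be advanced incrementally instead of recomputing elem_to_loc every time. A CPU emulation of the 2-D case (shape and strides are made up) that checks next() against the closed-form offset:

// CPU emulation of looped_elem_to_loc<2> above, for illustration only.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // Reduction shape (2, 3) with strides (100, 10) into some larger array.
  const std::vector<int> shape = {2, 3};
  const std::vector<std::size_t> strides = {100, 10};

  std::size_t offset = 0;       // innermost (dim - 1) state
  int index = 0;
  std::size_t outer_offset = 0; // state held by the inner_looper

  for (int r = 0; r < shape[0] * shape[1]; r++) {
    std::size_t expected =
        std::size_t(r / shape[1]) * strides[0] + std::size_t(r % shape[1]) * strides[1];
    std::printf("r=%d offset=%zu expected=%zu\n", r, offset, expected);

    // next(): advance the innermost dimension, carrying into the outer one.
    index++;
    offset += strides[1];
    if (index >= shape[1]) {
      index = 0;
      outer_offset += strides[0];
      offset = outer_offset;
    }
  }
  return 0;
}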
///////////////////////////////////////////////////////////////////////////////
// Calculation utils
///////////////////////////////////////////////////////////////////////////////
/** Compute ceil((float)N/(float)M) */
inline size_t ceildiv(size_t N, size_t M) {
template <typename T, typename U>
inline T ceildiv(T N, U M) {
return (N + M - 1) / M;
}
@ -339,3 +452,8 @@ inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
inline bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}
inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) {
return complex64_t(
simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta));
}


@ -135,8 +135,8 @@ void RMSNormVJP::eval_gpu(
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
// Allocate a temporary to store the gradients for w and initialize the
// gradient accumulator to 0.
// Allocate the gradient accumulator gw and a temporary to store the
// gradients before they are accumulated.
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
bool g_in_gw = false;
if (!g_in_gx && g.is_donatable()) {
@ -146,11 +146,7 @@ void RMSNormVJP::eval_gpu(
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
}
copies.push_back(gw_temp);
{
array zero(0, gw.dtype());
copy_gpu(zero, gw, CopyType::Scalar, s);
copies.push_back(std::move(zero));
}
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
const int simd_size = 32;
const int n_reads = RMS_N_READS;
@ -330,8 +326,8 @@ void LayerNormVJP::eval_gpu(
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
// Allocate a temporary to store the gradients for w and initialize the
// gradient accumulator to 0.
// Allocate a temporary to store the gradients for w and allocate the output
// gradient accumulators.
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
bool g_in_gw = false;
if (!g_in_gx && g.is_donatable()) {
@ -341,12 +337,8 @@ void LayerNormVJP::eval_gpu(
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
}
copies.push_back(gw_temp);
{
array zero(0, gw.dtype());
copy_gpu(zero, gw, CopyType::Scalar, s);
copy_gpu(zero, gb, CopyType::Scalar, s);
copies.push_back(std::move(zero));
}
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
// Finish with the gradient for b in case we had a b
auto& compute_encoder = d.get_command_encoder(s.index);



@ -16,9 +16,11 @@ namespace mlx::core {
namespace {
std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, bool>
compute_reduce_shape(
const std::vector<int>& axes,
const std::vector<int>& shape) {
bool is_noop = true;
std::set<int> axes_set;
auto ndim = shape.size();
for (auto ax : axes) {
@ -35,15 +37,18 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
throw std::invalid_argument("Duplicate axes detected in reduction.");
}
std::vector<int> out_shape;
std::vector<int> squeezed_shape;
for (int i = 0; i < ndim; ++i) {
if (axes_set.count(i) == 0) {
out_shape.push_back(shape[i]);
squeezed_shape.push_back(shape[i]);
} else {
out_shape.push_back(1);
}
is_noop &= (out_shape.back() == shape[i]);
}
std::vector<int> sorted_axes(axes_set.begin(), axes_set.end());
return {out_shape, sorted_axes};
return {out_shape, sorted_axes, squeezed_shape, is_noop};
}
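compute_reduce_shape now also reports the squeezed shape (consumed by the reshape that replaces squeeze in the ops below) and whether the reduction is a noop because every reduced axis already has size 1. A standalone sketch of that contract on an assumed (3, 1, 4) input:

// Standalone sketch of the new compute_reduce_shape contract above,
// for illustration only.
#include <cstdio>
#include <set>
#include <vector>

int main() {
  const std::vector<int> shape = {3, 1, 4};
  const std::set<int> axes = {1}; // reducing an axis of size 1 -> noop

  std::vector<int> out_shape, squeezed_shape;
  bool is_noop = true;
  for (int i = 0; i < int(shape.size()); ++i) {
    if (axes.count(i) == 0) {
      out_shape.push_back(shape[i]);
      squeezed_shape.push_back(shape[i]);
    } else {
      out_shape.push_back(1);
    }
    is_noop &= (out_shape.back() == shape[i]);
  }

  std::printf("is_noop = %d\n", is_noop);                 // 1
  std::printf("out_shape = (%d, %d, %d)\n",
              out_shape[0], out_shape[1], out_shape[2]);  // (3, 1, 4)
  std::printf("squeezed = (%d, %d)\n",
              squeezed_shape[0], squeezed_shape[1]);      // (3, 4)
  return 0;
}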
Dtype at_least_float(const Dtype& d) {
@ -1502,17 +1507,17 @@ array all(
const std::vector<int>& axes,
bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) {
if (axes.empty()) {
return astype(a, bool_, s);
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
bool_,
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? astype(a, bool_, s)
: array(
std::move(out_shape),
bool_,
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@ -1536,17 +1541,17 @@ array any(
const std::vector<int>& axes,
bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) {
if (axes.empty()) {
return astype(a, bool_, s);
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
bool_,
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? astype(a, bool_, s)
: array(
std::move(out_shape),
bool_,
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@ -1573,15 +1578,18 @@ array sum(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
auto out = array(
out_shape,
out_type,
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
{a});
auto out = (is_noop)
? astype(a, out_type, s)
: array(
std::move(out_shape),
out_type,
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@ -1715,14 +1723,17 @@ array prod(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? a
: array(
std::move(out_shape),
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@ -1749,17 +1760,17 @@ array max(
if (a.size() == 0) {
throw std::invalid_argument("[max] Cannot max reduce zero size array.");
}
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? a
: array(
std::move(out_shape),
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@ -1789,14 +1800,17 @@ array min(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? a
: array(
std::move(out_shape),
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@ -1829,15 +1843,18 @@ array argmin(
throw std::invalid_argument(
"[argmin] Cannot argmin reduce zero size array.");
}
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
auto out = array(
out_shape,
uint32,
std::make_shared<ArgReduce>(
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape({axis}, a.shape());
auto out = (is_noop)
? zeros(out_shape, uint32, s)
: array(
std::move(out_shape),
uint32,
std::make_shared<ArgReduce>(
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@ -1862,15 +1879,18 @@ array argmax(
throw std::invalid_argument(
"[argmax] Cannot argmax reduce zero size array.");
}
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
auto out = array(
out_shape,
uint32,
std::make_shared<ArgReduce>(
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape({axis}, a.shape());
auto out = (is_noop)
? zeros(out_shape, uint32, s)
: array(
std::move(out_shape),
uint32,
std::make_shared<ArgReduce>(
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}


@ -1763,7 +1763,7 @@ class TestOps(mlx_tests.MLXTestCase):
mat_t = mat.astype(t)
out = mx.cumsum(a_t, axis=-1)
expected = (mat_t * a_t[:, None, :]).sum(axis=-1)
self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3))
self.assertTrue(mx.allclose(out, expected, rtol=0.02, atol=1e-3))
sizes = [1023, 1024, 1025, 2047, 2048, 2049]
for s in sizes:
a = mx.ones((s,), mx.int32)


@ -43,6 +43,10 @@ class TestReduce(mlx_tests.MLXTestCase):
z_npy = np.sum(y_npy, axis=a) / 1000
z_mlx = mx.sum(y_mlx, axis=a) / 1000
mx.eval(z_mlx)
self.assertTrue(
np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
)
@ -57,6 +61,7 @@ class TestReduce(mlx_tests.MLXTestCase):
"uint32",
"int64",
"uint64",
"complex64",
]
float_dtypes = ["float32"]
@ -114,6 +119,15 @@ class TestReduce(mlx_tests.MLXTestCase):
b = getattr(np, op)(data)
self.assertEqual(a.item(), b)
def test_edge_case(self):
x = (mx.random.normal((100, 1, 100, 100)) * 128).astype(mx.int32)
x = x.transpose(0, 3, 1, 2)
y = x.sum((0, 2, 3))
mx.eval(y)
z = np.array(x).sum((0, 2, 3))
self.assertTrue(np.all(z == y))
if __name__ == "__main__":
unittest.main(failfast=True)