Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)

Commit 98b6ce3460: Refactor reductions and fix scatter atomics for large sizes (#1300)
Parent: f9e00efe31
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
@@ -49,7 +49,7 @@ struct ReductionPlan {
ReductionPlan(ReductionOpType type_) : type(type_) {}
};

ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes);
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);

// Helper for the ndimensional strided loop
// Should this be in utils?
@@ -19,7 +19,7 @@ std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
return std::make_pair(shape, strides);
}

ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
x.flags().contiguous) {

@@ -41,6 +41,14 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
}
}

// Remove singleton axes from the plan
for (int i = shape.size() - 1; i >= 0; i--) {
if (shape[i] == 1) {
shape.erase(shape.begin() + i);
strides.erase(strides.begin() + i);
}
}

if (strides.back() == 1) {
return ReductionPlan(ContiguousReduce, shape, strides);
} else if (strides.back() > 1) {

@@ -63,10 +71,14 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// have a contiguous reduction.
std::vector<std::pair<int, size_t>> reductions;
for (auto a : axes) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
if (x.shape(a) > 1) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
}
}
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
return a.second > b.second;
bool a_is_zero = a.second == 0;
bool b_is_zero = b.second == 0;
return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
});
// Extract the two smallest and try to merge them in case the contiguous
// reduction can be bigger than just the last axis.

@@ -98,16 +110,33 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// strides.back() are contiguous.
if (strides.back() > 1) {
int size = 1;
bool have_expand = false;
for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) {
continue;
}
if (x.strides()[i] != size) {

size_t stride_i = x.strides()[i];
int shape_i = x.shape(i);
if (stride_i == 0) {
if (shape_i == 1) {
continue;
}

have_expand = true;
break;
}
size *= x.shape(i);

if (stride_i != size && shape_i != 1) {
break;
}
size *= shape_i;
}
if (size >= strides.back()) {
// In the case of an expanded dimension we are being conservative and
// require the smallest reduction stride to be smaller than the maximum row
// contiguous size. The reason is that we can't easily know if the reduced
// axis is before or after an expanded dimension.
if (size > strides.back() || (size == strides.back() && !have_expand)) {
return ReductionPlan(GeneralStridedReduce, shape, strides);
}
}
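For context, here is a standalone sketch of how the updated comparator orders reduction axes. The values are made up and this is not part of the diff; it only exercises the same lambda:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  // (size, stride) pairs; the stride-0 entry stands for a broadcast/expanded axis.
  std::vector<std::pair<int, size_t>> reductions = {{4, 0}, {8, 16}, {16, 1}};
  std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
    bool a_is_zero = a.second == 0;
    bool b_is_zero = b.second == 0;
    return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
  });
  for (auto& r : reductions) {
    std::printf("size %d stride %zu\n", r.first, r.second);
  }
  // Prints (4, 0), (8, 16), (16, 1): the stride-0 axis now sorts to the front
  // instead of the back, so it is no longer treated as the most contiguous
  // (innermost) reduction axis when the smallest entries are merged.
}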
@@ -104,6 +104,33 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
std::vector<array>{std::forward<Arrays>(xs)...});
}

// The single array version of the above.
inline std::tuple<std::vector<int>, std::vector<size_t>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
std::vector<int> collapsed_shape;
std::vector<size_t> collapsed_strides;

if (shape.size() > 0) {
collapsed_shape.push_back(shape[0]);
collapsed_strides.push_back(strides[0]);
for (int i = 1; i < shape.size(); i++) {
if (strides[i] * shape[i] != collapsed_strides.back() ||
collapsed_shape.back() * static_cast<size_t>(shape[i]) >
std::numeric_limits<int>::max()) {
collapsed_shape.push_back(shape[i]);
collapsed_strides.push_back(strides[i]);
} else {
collapsed_shape.back() *= shape[i];
collapsed_strides.back() = strides[i];
}
}
}

return std::make_tuple(collapsed_shape, collapsed_strides);
}

template <typename stride_t>
inline auto check_contiguity(
const std::vector<int>& shape,
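As a quick illustration of the single-array overload added above, a usage sketch with made-up shapes (it assumes the declaration from this header is visible):

#include <cstddef>
#include <vector>

void example() {
  std::vector<int> shape = {2, 3, 4};
  std::vector<size_t> strides = {12, 4, 1};  // row-major contiguous
  auto [cshape, cstrides] = collapse_contiguous_dims(shape, strides);
  // cshape == {24}, cstrides == {1}: at every step strides[i] * shape[i]
  // equals the previous collapsed stride, so all dimensions merge.
  // Dimensions are kept separate when the merged extent would exceed
  // std::numeric_limits<int>::max(), which is the new overflow guard.
}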
@ -37,13 +37,13 @@ struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
|
||||
|
||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC T
|
||||
mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
|
||||
mlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {
|
||||
return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
|
||||
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {
|
||||
atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||
}
|
||||
|
||||
@ -51,13 +51,15 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void mlx_atomic_fetch_and_explicit(
|
||||
device mlx_atomic<T>* object,
|
||||
T val,
|
||||
uint offset) {
|
||||
size_t offset) {
|
||||
atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
|
||||
METAL_FUNC void mlx_atomic_fetch_or_explicit(
|
||||
device mlx_atomic<T>* object,
|
||||
T val,
|
||||
size_t offset) {
|
||||
atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||
}
|
||||
|
||||
@ -65,7 +67,7 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void mlx_atomic_fetch_min_explicit(
|
||||
device mlx_atomic<T>* object,
|
||||
T val,
|
||||
uint offset) {
|
||||
size_t offset) {
|
||||
atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||
}
|
||||
|
||||
@ -73,7 +75,7 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void mlx_atomic_fetch_max_explicit(
|
||||
device mlx_atomic<T>* object,
|
||||
T val,
|
||||
uint offset) {
|
||||
size_t offset) {
|
||||
atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||
}
|
||||
|
||||
@ -81,7 +83,7 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void mlx_atomic_fetch_add_explicit(
|
||||
device mlx_atomic<T>* object,
|
||||
T val,
|
||||
uint offset) {
|
||||
size_t offset) {
|
||||
atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||
}
|
||||
|
||||
@ -89,7 +91,7 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
|
||||
device mlx_atomic<T>* object,
|
||||
T val,
|
||||
uint offset) {
|
||||
size_t offset) {
|
||||
T expected = mlx_atomic_load_explicit(object, offset);
|
||||
while (!mlx_atomic_compare_exchange_weak_explicit(
|
||||
object, &expected, val * expected, offset)) {
|
||||
@ -101,7 +103,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
|
||||
device mlx_atomic<T>* object,
|
||||
thread T* expected,
|
||||
T val,
|
||||
uint offset) {
|
||||
size_t offset) {
|
||||
return atomic_compare_exchange_weak_explicit(
|
||||
&(object[offset].val),
|
||||
expected,
|
||||
@ -115,7 +117,7 @@ template <>
|
||||
METAL_FUNC void mlx_atomic_fetch_min_explicit<float>(
|
||||
device mlx_atomic<float>* object,
|
||||
float val,
|
||||
uint offset) {
|
||||
size_t offset) {
|
||||
float expected = mlx_atomic_load_explicit(object, offset);
|
||||
while (val < expected) {
|
||||
if (mlx_atomic_compare_exchange_weak_explicit(
|
||||
@ -130,7 +132,7 @@ template <>
|
||||
METAL_FUNC void mlx_atomic_fetch_max_explicit<float>(
|
||||
device mlx_atomic<float>* object,
|
||||
float val,
|
||||
uint offset) {
|
||||
size_t offset) {
|
||||
float expected = mlx_atomic_load_explicit(object, offset);
|
||||
while (val > expected) {
|
||||
if (mlx_atomic_compare_exchange_weak_explicit(
|
||||
@ -157,7 +159,7 @@ union uint_or_packed {
|
||||
|
||||
template <typename T, typename Op>
|
||||
struct mlx_atomic_update_helper {
|
||||
uint operator()(uint_or_packed<T> init, T update, uint elem_offset) {
|
||||
uint operator()(uint_or_packed<T> init, T update, size_t elem_offset) {
|
||||
Op op;
|
||||
init.val[elem_offset] = op(update, init.val[elem_offset]);
|
||||
return init.bits;
|
||||
@ -168,9 +170,9 @@ template <typename T, typename Op>
|
||||
METAL_FUNC void mlx_atomic_update_and_store(
|
||||
device mlx_atomic<T>* object,
|
||||
T update,
|
||||
uint offset) {
|
||||
uint pack_offset = offset / packing_size<T>;
|
||||
uint elem_offset = offset % packing_size<T>;
|
||||
size_t offset) {
|
||||
size_t pack_offset = offset / packing_size<T>;
|
||||
size_t elem_offset = offset % packing_size<T>;
|
||||
|
||||
mlx_atomic_update_helper<T, Op> helper;
|
||||
uint_or_packed<T> expected;
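Since the offsets above are now size_t, the pack/element split keeps working past the 32-bit range. A small host-side sketch of the arithmetic (illustrative only; taking packing_size<T> to be 4, i.e. four lanes per 32-bit atomic word, which is an assumption rather than part of the diff):

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr size_t packing_size = 4;            // assumed: 4 lanes per 32-bit word
  size_t offset = (1ull << 32) + 6;             // element offset past 2^32
  size_t pack_offset = offset / packing_size;   // which 32-bit atomic word
  size_t elem_offset = offset % packing_size;   // which lane inside that word
  std::printf("word %zu, lane %zu\n", pack_offset, elem_offset);
  // With the previous uint offsets an index this large would have been
  // truncated before the division and addressed the wrong word entirely.
}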
@ -251,9 +253,9 @@ struct __Min {
|
||||
|
||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC T
|
||||
mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
|
||||
uint pack_offset = offset / sizeof(T);
|
||||
uint elem_offset = offset % sizeof(T);
|
||||
mlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {
|
||||
size_t pack_offset = offset / sizeof(T);
|
||||
size_t elem_offset = offset % sizeof(T);
|
||||
uint_or_packed<T> packed_val;
|
||||
packed_val.bits =
|
||||
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
|
||||
@ -262,7 +264,7 @@ mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
|
||||
|
||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
|
||||
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {
|
||||
mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
|
||||
}
|
||||
|
||||
@ -270,9 +272,9 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void mlx_atomic_fetch_and_explicit(
|
||||
device mlx_atomic<T>* object,
|
||||
T val,
|
||||
uint offset) {
|
||||
uint pack_offset = offset / packing_size<T>;
|
||||
uint elem_offset = offset % packing_size<T>;
|
||||
size_t offset) {
|
||||
size_t pack_offset = offset / packing_size<T>;
|
||||
size_t elem_offset = offset % packing_size<T>;
|
||||
uint_or_packed<T> identity;
|
||||
identity.bits = __UINT32_MAX__;
|
||||
identity.val[elem_offset] = val;
|
||||
@ -282,10 +284,12 @@ METAL_FUNC void mlx_atomic_fetch_and_explicit(
|
||||
}
|
||||
|
||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void
|
||||
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
|
||||
uint pack_offset = offset / packing_size<T>;
|
||||
uint elem_offset = offset % packing_size<T>;
|
||||
METAL_FUNC void mlx_atomic_fetch_or_explicit(
|
||||
device mlx_atomic<T>* object,
|
||||
T val,
|
||||
size_t offset) {
|
||||
size_t pack_offset = offset / packing_size<T>;
|
||||
size_t elem_offset = offset % packing_size<T>;
|
||||
uint_or_packed<T> identity;
|
||||
identity.bits = 0;
|
||||
identity.val[elem_offset] = val;
|
||||
@ -298,7 +302,7 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void mlx_atomic_fetch_min_explicit(
|
||||
device mlx_atomic<T>* object,
|
||||
T val,
|
||||
uint offset) {
|
||||
size_t offset) {
|
||||
mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
|
||||
}
|
||||
|
||||
@ -306,7 +310,7 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void mlx_atomic_fetch_max_explicit(
|
||||
device mlx_atomic<T>* object,
|
||||
T val,
|
||||
uint offset) {
|
||||
size_t offset) {
|
||||
mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
|
||||
}
|
||||
|
||||
@ -314,7 +318,7 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void mlx_atomic_fetch_add_explicit(
|
||||
device mlx_atomic<T>* object,
|
||||
T val,
|
||||
uint offset) {
|
||||
size_t offset) {
|
||||
mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
|
||||
}
|
||||
|
||||
@ -322,7 +326,7 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
|
||||
device mlx_atomic<T>* object,
|
||||
T val,
|
||||
uint offset) {
|
||||
size_t offset) {
|
||||
mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
|
||||
}
|
||||
|
||||
@ -331,7 +335,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
|
||||
device mlx_atomic<T>* object,
|
||||
thread uint* expected,
|
||||
uint val,
|
||||
uint offset) {
|
||||
size_t offset) {
|
||||
return atomic_compare_exchange_weak_explicit(
|
||||
&(object[offset].val),
|
||||
expected,
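The float min/max specializations above are built from a compare-and-swap loop. A host-side analogue of the same pattern (a sketch using std::atomic, not the Metal code itself):

#include <atomic>
#include <cstdio>

// Fetch-min for float via compare_exchange_weak, mirroring the loop shape of
// mlx_atomic_fetch_min_explicit<float> above.
void atomic_fetch_min(std::atomic<float>& object, float val) {
  float expected = object.load(std::memory_order_relaxed);
  while (val < expected) {
    if (object.compare_exchange_weak(
            expected, val, std::memory_order_relaxed, std::memory_order_relaxed)) {
      return;
    }
    // On failure, expected is refreshed with the current value and the loop
    // condition decides whether another attempt is needed.
  }
}

int main() {
  std::atomic<float> x{3.5f};
  atomic_fetch_min(x, 1.25f);
  std::printf("%f\n", x.load());  // 1.250000
}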
@@ -23,6 +23,8 @@ struct complex64_t {

// Constructors
constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
constexpr complex64_t() : real(0), imag(0) {};
constexpr complex64_t() threadgroup : real(0), imag(0) {};

// Conversions to complex64_t
template <
@@ -9,7 +9,8 @@
#endif

static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int REDUCE_N_READS = 4;
static MTL_CONST constexpr int REDUCE_N_WRITES = 4;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
@@ -28,7 +28,8 @@

#define instantiate_reduce_helper_64b(inst_f, name, op) \
inst_f(name, int64, int64_t, op) \
inst_f(name, uint64, uint64_t, op)
inst_f(name, uint64, uint64_t, op) \
inst_f(name, complex64, complex64_t, op)

#define instantiate_reduce_helper_types(inst_f, name, op) \
instantiate_reduce_helper_floats(inst_f, name, op) \
@ -97,40 +98,24 @@ instantiate_init_reduce(andbool_, bool, And<bool>)
|
||||
instantiate_init_reduce(orbool_, bool, Or<bool>)
|
||||
|
||||
#define instantiate_all_reduce(name, itype, otype, op) \
|
||||
template [[host_name("all_reduce_" #name)]] [[kernel]] void \
|
||||
all_reduce<itype, otype, op>( \
|
||||
template [[host_name("all_reduce_" #name)]] \
|
||||
[[kernel]] void all_reduce<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device mlx_atomic<otype>* out [[buffer(1)]], \
|
||||
const device size_t& in_size [[buffer(2)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint grid_size [[threads_per_grid]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& in_size [[buffer(2)]], \
|
||||
const constant size_t& row_size [[buffer(3)]], \
|
||||
uint3 gid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 lsize [[threads_per_threadgroup]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
|
||||
template [[host_name("allNoAtomics_reduce_" #name)]] [[kernel]] void \
|
||||
all_reduce_no_atomics<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const device size_t& in_size [[buffer(2)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint grid_size [[threads_per_grid]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint thread_group_id [[threadgroup_position_in_grid]]);
|
||||
|
||||
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
|
||||
instantiate_all_reduce(name##tname, type, type, op<type>)
|
||||
|
||||
#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
|
||||
instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)
|
||||
|
||||
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
|
||||
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
|
||||
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_64b)
|
||||
|
||||
instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
|
||||
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
|
||||
@ -138,153 +123,143 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
|
||||
// special case bool with larger output type
|
||||
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
|
||||
#define instantiate_col_reduce_general(name, itype, otype, op) \
|
||||
template [[host_name("colGeneral_reduce_" #name)]] [[kernel]] void \
|
||||
col_reduce_general<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device mlx_atomic<otype>* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant size_t& out_size [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
threadgroup otype* local_data [[threadgroup(0)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 lsize [[threads_per_threadgroup]]);
|
||||
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
|
||||
template [[host_name("colSmall" #dim "_reduce_" #name)]] \
|
||||
[[kernel]] void col_reduce_small<itype, otype, op, dim>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant int* shape [[buffer(4)]], \
|
||||
const constant size_t* strides [[buffer(5)]], \
|
||||
const constant int& ndim [[buffer(6)]], \
|
||||
const constant int* reduce_shape [[buffer(7)]], \
|
||||
const constant size_t* reduce_strides [[buffer(8)]], \
|
||||
const constant int& reduce_ndim [[buffer(9)]], \
|
||||
const constant size_t& non_col_reductions [[buffer(10)]], \
|
||||
uint3 gid [[threadgroup_position_in_grid]], \
|
||||
uint3 gsize [[threadgroups_per_grid]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[thread_position_in_grid]], \
|
||||
uint3 tsize [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
|
||||
template \
|
||||
[[host_name("colGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
|
||||
col_reduce_general_no_atomics<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant size_t& out_size [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
threadgroup otype* local_data [[threadgroup(0)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 gid [[thread_position_in_grid]], \
|
||||
uint3 lsize [[threads_per_threadgroup]], \
|
||||
uint3 gsize [[threads_per_grid]]);
|
||||
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
template [[host_name("colLooped" #dim "_" #bm "_" #bn "_reduce_" #name)]] \
|
||||
[[kernel]] void col_reduce_looped<itype, otype, op, dim, bm, bn>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant int* shape [[buffer(4)]], \
|
||||
const constant size_t* strides [[buffer(5)]], \
|
||||
const constant int& ndim [[buffer(6)]], \
|
||||
const constant int* reduce_shape [[buffer(7)]], \
|
||||
const constant size_t* reduce_strides [[buffer(8)]], \
|
||||
const constant int& reduce_ndim [[buffer(9)]], \
|
||||
const constant size_t& non_col_reductions [[buffer(10)]], \
|
||||
uint3 gid [[threadgroup_position_in_grid]], \
|
||||
uint3 gsize [[threadgroups_per_grid]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_col_reduce_small(name, itype, otype, op) \
|
||||
template [[host_name("colSmall_reduce_" #name)]] [[kernel]] void \
|
||||
col_reduce_small<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant size_t& out_size [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
const constant size_t& non_col_reductions [[buffer(8)]], \
|
||||
const constant int* non_col_shapes [[buffer(9)]], \
|
||||
const constant size_t* non_col_strides [[buffer(10)]], \
|
||||
const constant int& non_col_ndim [[buffer(11)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
|
||||
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
|
||||
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
|
||||
|
||||
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
|
||||
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
|
||||
instantiate_col_reduce_general(name ##tname, type, type, op<type>)
|
||||
#define instantiate_col_reduce_general(name, itype, otype, op) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 0) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 1) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 2) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 3) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 4) \
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 0) \
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 1) \
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 2) \
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 3) \
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 4)
|
||||
|
||||
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
|
||||
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
|
||||
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
|
||||
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
|
||||
instantiate_col_reduce_general(name##tname, type, type, op<type>)
|
||||
|
||||
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
|
||||
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
|
||||
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_64b)
|
||||
|
||||
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)
|
||||
|
||||
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And<bool>)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or<bool>)
|
||||
|
||||
#define instantiate_row_reduce_small(name, itype, otype, op) \
|
||||
template [[host_name("rowGeneralSmall_reduce_" #name)]] [[kernel]] void \
|
||||
row_reduce_general_small<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& out_size [[buffer(3)]], \
|
||||
const constant size_t& non_row_reductions [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
uint lid [[thread_position_in_grid]]); \
|
||||
template [[host_name("rowGeneralMed_reduce_" #name)]] [[kernel]] void \
|
||||
row_reduce_general_med<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& out_size [[buffer(3)]], \
|
||||
const constant size_t& non_row_reductions [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
uint tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
|
||||
template [[host_name("rowSmall" #dim "_reduce_" #name)]] [[kernel]] void \
|
||||
row_reduce_small<itype, otype, op, dim>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& row_size [[buffer(2)]], \
|
||||
const constant size_t& non_row_reductions [[buffer(3)]], \
|
||||
const constant int* shape [[buffer(4)]], \
|
||||
const constant size_t* strides [[buffer(5)]], \
|
||||
const constant int& ndim [[buffer(6)]], \
|
||||
const constant int* reduce_shape [[buffer(7)]], \
|
||||
const constant size_t* reduce_strides [[buffer(8)]], \
|
||||
const constant int& reduce_ndim [[buffer(9)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint3 gid [[threadgroup_position_in_grid]], \
|
||||
uint3 gsize [[threadgroups_per_grid]], \
|
||||
uint3 tid [[thread_position_in_grid]], \
|
||||
uint3 tsize [[threads_per_grid]]);
|
||||
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
|
||||
template \
|
||||
[[host_name("rowLooped" #dim "_reduce_" #name)]] [[kernel]] void \
|
||||
row_reduce_looped<itype, otype, op, dim>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& row_size [[buffer(2)]], \
|
||||
const constant size_t& non_row_reductions [[buffer(3)]], \
|
||||
const constant int* shape [[buffer(4)]], \
|
||||
const constant size_t* strides [[buffer(5)]], \
|
||||
const constant int& ndim [[buffer(6)]], \
|
||||
const constant int* reduce_shape [[buffer(7)]], \
|
||||
const constant size_t* reduce_strides [[buffer(8)]], \
|
||||
const constant int& reduce_ndim [[buffer(9)]], \
|
||||
uint3 gid [[threadgroup_position_in_grid]], \
|
||||
uint3 gsize [[threadgroups_per_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_row_reduce_general(name, itype, otype, op) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 0) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 1) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 2) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 3) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 4) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 0) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 1) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 2) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 3) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 4) \
|
||||
template \
|
||||
[[host_name("rowGeneral_reduce_" #name)]] [[kernel]] void \
|
||||
row_reduce_general<itype, otype, op>( \
|
||||
[[host_name("rowSimple_reduce_" #name)]] [[kernel]] void \
|
||||
row_reduce_simple<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device mlx_atomic<otype>* out [[buffer(1)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& out_size [[buffer(3)]], \
|
||||
const constant size_t& non_row_reductions [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
uint3 gid [[threadgroup_position_in_grid]], \
|
||||
uint3 gsize [[threadgroups_per_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 lsize [[threads_per_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op) \
|
||||
template \
|
||||
[[host_name("rowGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
|
||||
row_reduce_general_no_atomics<itype, otype, op>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& out_size [[buffer(3)]], \
|
||||
const constant size_t& non_row_reductions [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 lsize [[threads_per_threadgroup]], \
|
||||
uint3 gsize [[threads_per_grid]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
|
||||
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
|
||||
instantiate_row_reduce_general(name##tname, type, type, op<type>)
|
||||
|
||||
#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
|
||||
instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)
|
||||
|
||||
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
|
||||
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)
|
||||
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_64b)
|
||||
|
||||
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
|
||||
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)
|
||||
|
@ -5,6 +5,20 @@
|
||||
#include <metal_atomic>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#define DEFINE_SIMD_REDUCE() \
|
||||
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
|
||||
T simd_reduce(T val) { \
|
||||
return simd_reduce_impl(val); \
|
||||
} \
|
||||
\
|
||||
template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true> \
|
||||
T simd_reduce(T val) { \
|
||||
for (short i = simd_size / 2; i > 0; i /= 2) { \
|
||||
val = operator()(val, simd_shuffle_down(val, i)); \
|
||||
} \
|
||||
return val; \
|
||||
}
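The second branch of the macro exists because, as a comment later in this diff notes, the simd_sum family is not supported for 64-bit integer types, so those fall back to a manual shuffle-down tree. A host-side emulation of that tree for one 32-lane simdgroup (a sketch, not Metal code):

#include <cstdint>
#include <cstdio>

int main() {
  constexpr int simd_size = 32;
  int64_t lanes[simd_size];
  for (int i = 0; i < simd_size; i++) lanes[i] = i + 1;  // 1..32, sum is 528

  // Tree reduction: at each step lane L accumulates lane L + step, the
  // host-side analogue of val = op(val, simd_shuffle_down(val, step)).
  // Only the lanes that still contribute are updated here; lane 0 ends up
  // with the full reduction, just as in the Metal fallback.
  for (int step = simd_size / 2; step > 0; step /= 2) {
    for (int lane = 0; lane < step; lane++) {
      lanes[lane] += lanes[lane + step];
    }
  }
  std::printf("%lld\n", static_cast<long long>(lanes[0]));  // 528
}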
|
||||
static constant constexpr const uint8_t simd_size = 32;
|
||||
|
||||
union bool4_or_uint {
|
||||
@ -14,14 +28,16 @@ union bool4_or_uint {
|
||||
|
||||
struct None {
|
||||
template <typename T>
|
||||
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
|
||||
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
|
||||
mlx_atomic_store_explicit(out, val, offset);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename U = bool>
|
||||
struct And {
|
||||
bool simd_reduce(bool val) {
|
||||
DEFINE_SIMD_REDUCE()
|
||||
|
||||
bool simd_reduce_impl(bool val) {
|
||||
return simd_all(val);
|
||||
}
|
||||
|
||||
@ -31,7 +47,7 @@ struct And {
|
||||
device mlx_atomic<unsigned int>* out,
|
||||
bool val,
|
||||
int elem_idx,
|
||||
int offset = 0) {
|
||||
size_t offset = 0) {
|
||||
if (!val) {
|
||||
bool4_or_uint update;
|
||||
update.b = {true, true, true, true};
|
||||
@ -40,7 +56,8 @@ struct And {
|
||||
}
|
||||
}
|
||||
|
||||
void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
|
||||
void
|
||||
atomic_update(device mlx_atomic<bool>* out, bool val, size_t offset = 0) {
|
||||
if (!val) {
|
||||
mlx_atomic_store_explicit(out, val, offset);
|
||||
}
|
||||
@ -59,7 +76,9 @@ struct And {
|
||||
|
||||
template <typename U = bool>
|
||||
struct Or {
|
||||
bool simd_reduce(bool val) {
|
||||
DEFINE_SIMD_REDUCE()
|
||||
|
||||
bool simd_reduce_impl(bool val) {
|
||||
return simd_any(val);
|
||||
}
|
||||
|
||||
@ -68,8 +87,8 @@ struct Or {
|
||||
void atomic_update(
|
||||
device mlx_atomic<unsigned int>* out,
|
||||
bool val,
|
||||
uint elem_idx,
|
||||
uint offset = 0) {
|
||||
int elem_idx,
|
||||
size_t offset = 0) {
|
||||
if (val) {
|
||||
bool4_or_uint update;
|
||||
update.b = {false, false, false, false};
|
||||
@ -78,7 +97,8 @@ struct Or {
|
||||
}
|
||||
}
|
||||
|
||||
void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
|
||||
void
|
||||
atomic_update(device mlx_atomic<bool>* out, bool val, size_t offset = 0) {
|
||||
if (val) {
|
||||
mlx_atomic_store_explicit(out, val, offset);
|
||||
}
|
||||
@ -97,15 +117,17 @@ struct Or {
|
||||
|
||||
template <typename U>
|
||||
struct Sum {
|
||||
DEFINE_SIMD_REDUCE()
|
||||
|
||||
template <typename T>
|
||||
T simd_reduce(T val) {
|
||||
T simd_reduce_impl(T val) {
|
||||
return simd_sum(val);
|
||||
}
|
||||
|
||||
static constexpr constant U init = U(0);
|
||||
|
||||
template <typename T>
|
||||
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
|
||||
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
|
||||
mlx_atomic_fetch_add_explicit(out, val, offset);
|
||||
}
|
||||
|
||||
@ -117,15 +139,17 @@ struct Sum {
|
||||
|
||||
template <typename U>
|
||||
struct Prod {
|
||||
DEFINE_SIMD_REDUCE()
|
||||
|
||||
template <typename T>
|
||||
T simd_reduce(T val) {
|
||||
T simd_reduce_impl(T val) {
|
||||
return simd_product(val);
|
||||
}
|
||||
|
||||
static constexpr constant U init = U(1);
|
||||
|
||||
template <typename T>
|
||||
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
|
||||
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
|
||||
mlx_atomic_fetch_mul_explicit(out, val, offset);
|
||||
}
|
||||
|
||||
@ -137,15 +161,17 @@ struct Prod {
|
||||
|
||||
template <typename U>
|
||||
struct Min {
|
||||
DEFINE_SIMD_REDUCE()
|
||||
|
||||
template <typename T>
|
||||
T simd_reduce(T val) {
|
||||
T simd_reduce_impl(T val) {
|
||||
return simd_min(val);
|
||||
}
|
||||
|
||||
static constexpr constant U init = Limits<U>::max;
|
||||
|
||||
template <typename T>
|
||||
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
|
||||
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
|
||||
mlx_atomic_fetch_min_explicit(out, val, offset);
|
||||
}
|
||||
|
||||
@ -157,15 +183,17 @@ struct Min {
|
||||
|
||||
template <typename U>
|
||||
struct Max {
|
||||
DEFINE_SIMD_REDUCE()
|
||||
|
||||
template <typename T>
|
||||
T simd_reduce(T val) {
|
||||
T simd_reduce_impl(T val) {
|
||||
return simd_max(val);
|
||||
}
|
||||
|
||||
static constexpr constant U init = Limits<U>::min;
|
||||
|
||||
template <typename T>
|
||||
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
|
||||
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
|
||||
mlx_atomic_fetch_max_explicit(out, val, offset);
|
||||
}
|
||||
|
||||
|
@ -1,135 +1,61 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// All reduce helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
METAL_FUNC U per_thread_all_reduce(
|
||||
const device T* in,
|
||||
const device size_t& in_size,
|
||||
uint gid,
|
||||
uint grid_size) {
|
||||
Op op;
|
||||
U total_val = Op::init;
|
||||
|
||||
if (gid * N_READS < in_size) {
|
||||
in += gid * N_READS;
|
||||
|
||||
int r = 0;
|
||||
for (; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
|
||||
U vals[N_READS] = {op.init};
|
||||
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = static_cast<U>(in[i]);
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total_val = op(vals[i], total_val);
|
||||
}
|
||||
|
||||
in += grid_size * N_READS;
|
||||
}
|
||||
|
||||
// Separate case for the last set as we close the reduction size
|
||||
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
|
||||
if (curr_idx < in_size) {
|
||||
int max_reads = in_size - curr_idx;
|
||||
T vals[N_READS];
|
||||
|
||||
for (int i = 0, idx = 0; i < N_READS; i++, idx++) {
|
||||
idx = idx < max_reads ? idx : max_reads - 1;
|
||||
vals[i] = in[idx];
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
U val = i < max_reads ? vals[i] : Op::init;
|
||||
total_val = op(static_cast<U>(val), total_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return total_val;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// All reduce kernel
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// NB: This kernel assumes threads_per_threadgroup is at most
|
||||
// 1024. This way with a simd_size of 32, we are guaranteed to
|
||||
// complete the reduction in two steps of simd-level reductions.
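// (Worked out: at most 1024 threads per threadgroup / 32 lanes per simdgroup
// = 32 partial values, which again fit in a single 32-lane simdgroup, so the
// second simd_reduce over threadgroup memory completes the reduction.)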
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void all_reduce(
|
||||
const device T* in [[buffer(0)]],
|
||||
device mlx_atomic<U>* out [[buffer(1)]],
|
||||
const device size_t& in_size [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint grid_size [[threads_per_grid]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& in_size [[buffer(2)]],
|
||||
const constant size_t& row_size [[buffer(3)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
Op op;
|
||||
threadgroup U local_vals[simd_size];
|
||||
threadgroup U shared_vals[simd_size];
|
||||
|
||||
U total_val =
|
||||
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
|
||||
U total = Op::init;
|
||||
int64_t start_idx = gid.y * row_size;
|
||||
int64_t actual_row =
|
||||
(start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
|
||||
int64_t blocks = actual_row / (lsize.x * N_READS);
|
||||
int extra = actual_row - blocks * (lsize.x * N_READS);
|
||||
extra -= lid.x * N_READS;
|
||||
start_idx += lid.x * N_READS;
|
||||
in += start_idx;
|
||||
|
||||
if (extra >= N_READS) {
|
||||
blocks++;
|
||||
extra = 0;
|
||||
}
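// (Illustration with hypothetical sizes: actual_row = 1000, lsize.x = 128,
// N_READS = 4 gives blocks = 1 and extra = 488. Threads 0-121 then see a
// per-thread extra >= N_READS, promote it to a second full block, and cover
// elements 0-487 and 512-999; threads 122-127 keep a single block and cover
// 488-511. The extra loop below is only needed when the row length is not a
// multiple of N_READS.)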
|
||||
for (int64_t b = 0; b < blocks; b++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total = op(static_cast<U>(in[i]), total);
|
||||
}
|
||||
in += lsize.x * N_READS;
|
||||
}
|
||||
if (extra > 0) {
|
||||
for (int i = 0; i < extra; i++) {
|
||||
total = op(static_cast<U>(in[i]), total);
|
||||
}
|
||||
}
|
||||
|
||||
// Reduction within simd group
|
||||
total_val = op.simd_reduce(total_val);
|
||||
if (simd_lane_id == 0) {
|
||||
local_vals[simd_group_id] = total_val;
|
||||
total = op.simd_reduce(total);
|
||||
if (simd_per_group > 1) {
|
||||
if (simd_lane_id == 0) {
|
||||
shared_vals[simd_group_id] = total;
|
||||
}
|
||||
|
||||
// Reduction within thread group
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
total = lid.x < simd_per_group ? shared_vals[lid.x] : op.init;
|
||||
total = op.simd_reduce(total);
|
||||
}
|
||||
|
||||
// Reduction within thread group
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
|
||||
total_val = op.simd_reduce(total_val);
|
||||
|
||||
// Reduction across threadgroups
|
||||
if (lid == 0) {
|
||||
op.atomic_update(out, total_val);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void all_reduce_no_atomics(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const device size_t& in_size [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint grid_size [[threads_per_grid]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint thread_group_id [[threadgroup_position_in_grid]]) {
|
||||
Op op;
|
||||
threadgroup U local_vals[simd_size];
|
||||
|
||||
U total_val =
|
||||
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
|
||||
|
||||
// Reduction within simd group (simd_add isn't supported for uint64/int64
|
||||
// types)
|
||||
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
|
||||
lane_offset /= 2) {
|
||||
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
|
||||
}
|
||||
// Write simd group reduction results to local memory
|
||||
if (simd_lane_id == 0) {
|
||||
local_vals[simd_group_id] = total_val;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Reduction of simdgroup reduction results within threadgroup.
|
||||
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
|
||||
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
|
||||
lane_offset /= 2) {
|
||||
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
|
||||
}
|
||||
|
||||
// Reduction across threadgroups
|
||||
if (lid == 0) {
|
||||
out[thread_group_id] = total_val;
|
||||
if (lid.x == 0) {
|
||||
out[gid.y] = total;
|
||||
}
|
||||
}
|
||||
|
@ -1,165 +1,298 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Small column reduce kernel
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIMS = 0,
|
||||
int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void col_reduce_small(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
const constant size_t& non_col_reductions [[buffer(8)]],
|
||||
const constant int* non_col_shapes [[buffer(9)]],
|
||||
const constant size_t* non_col_strides [[buffer(10)]],
|
||||
const constant int& non_col_ndim [[buffer(11)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
// Appease the compiler
|
||||
(void)out_size;
|
||||
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
const constant size_t& non_col_reductions [[buffer(10)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[thread_position_in_grid]],
|
||||
uint3 tsize [[threads_per_grid]]) {
|
||||
Op op;
|
||||
U total_val = Op::init;
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
const device T* row;
|
||||
|
||||
auto out_idx = tid;
|
||||
// Case 1:
|
||||
// reduction_stride is small, reduction_size is small and non_col_reductions
|
||||
// is small. Each thread computes reduction_stride outputs.
|
||||
if (reduction_size * non_col_reductions < 64) {
|
||||
U totals[31];
|
||||
for (int i = 0; i < 31; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
in += elem_to_loc(
|
||||
out_idx,
|
||||
shape + non_col_ndim,
|
||||
strides + non_col_ndim,
|
||||
ndim - non_col_ndim);
|
||||
short stride = reduction_stride;
|
||||
short size = reduction_size;
|
||||
short blocks = stride / N_READS;
|
||||
short extra = stride - blocks * N_READS;
|
||||
|
||||
for (uint i = 0; i < non_col_reductions; i++) {
|
||||
size_t in_idx =
|
||||
elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
|
||||
size_t out_idx = tid.x + tsize.y * size_t(tid.y);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim);
|
||||
|
||||
for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
|
||||
U val = static_cast<U>(in[in_idx]);
|
||||
total_val = op(total_val, val);
|
||||
for (uint r = 0; r < non_col_reductions; r++) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
|
||||
for (short i = 0; i < size; i++) {
|
||||
for (short j = 0; j < blocks; j++) {
|
||||
for (short k = 0; k < N_READS; k++) {
|
||||
totals[j * N_READS + k] =
|
||||
op(totals[j * N_READS + k],
|
||||
static_cast<U>(row[i * stride + j * N_READS + k]));
|
||||
}
|
||||
}
|
||||
for (short k = 0; k < extra; k++) {
|
||||
totals[blocks * N_READS + k] =
|
||||
op(totals[blocks * N_READS + k],
|
||||
static_cast<U>(row[i * stride + blocks * N_READS + k]));
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(reduce_shape, reduce_strides);
|
||||
}
|
||||
out += out_idx * reduction_stride;
|
||||
for (short j = 0; j < stride; j++) {
|
||||
out[j] = totals[j];
|
||||
}
|
||||
}
|
||||
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Column reduce helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
METAL_FUNC U _contiguous_strided_reduce(
|
||||
const device T* in,
|
||||
threadgroup U* local_data,
|
||||
uint in_idx,
|
||||
uint reduction_size,
|
||||
uint reduction_stride,
|
||||
uint2 tid,
|
||||
uint2 lid,
|
||||
uint2 lsize) {
|
||||
Op op;
|
||||
U total_val = Op::init;
|
||||
|
||||
uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
|
||||
for (uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
|
||||
uint offset = base_offset + r;
|
||||
total_val =
|
||||
op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]);
|
||||
}
|
||||
local_data[lsize.y * lid.x + lid.y] = total_val;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
U val = Op::init;
|
||||
if (lid.y == 0) {
|
||||
// Perform reduction across columns in thread group
|
||||
for (uint i = 0; i < lsize.y; i++) {
|
||||
val = op(val, local_data[lsize.y * lid.x + i]);
|
||||
// Case 2:
|
||||
// Reduction stride is small but everything else can be big. We loop both
|
||||
// across reduction size and non_col_reductions. Each simdgroup produces
|
||||
// N_READS outputs.
|
||||
else {
|
||||
threadgroup U shared_vals[1024];
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
}
|
||||
|
||||
return val;
|
||||
}
|
||||
short stride = reduction_stride;
|
||||
short lid = simd_group_id * simd_size + simd_lane_id;
|
||||
short2 tile((stride + N_READS - 1) / N_READS, 32);
|
||||
short2 offset((lid % tile.x) * N_READS, lid / tile.x);
|
||||
short sm_stride = tile.x * N_READS;
|
||||
bool safe = offset.x + N_READS <= stride;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Column reduce kernel
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x;
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void col_reduce_general(
|
||||
const device T* in [[buffer(0)]],
|
||||
device mlx_atomic<U>* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
threadgroup U* local_data [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]]) {
|
||||
auto out_idx = tid.x * lsize.x + lid.x;
|
||||
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
|
||||
// Read cooperatively and contiguously and aggregate the partial results.
|
||||
size_t total = non_col_reductions * reduction_size;
|
||||
loop.next(offset.y, reduce_shape, reduce_strides);
|
||||
for (size_t r = offset.y; r < total; r += simd_size) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
|
||||
Op op;
|
||||
if (out_idx < out_size) {
|
||||
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
local_data,
|
||||
in_idx,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
tid.xy,
|
||||
lid.xy,
|
||||
lsize.xy);
|
||||
if (safe) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(static_cast<U>(row[i]), totals[i]);
|
||||
}
|
||||
} else {
|
||||
U vals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = (offset.x + i < stride) ? static_cast<U>(row[i]) : op.init;
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Write out reduction results generated by threadgroups working on specific
|
||||
// output element, contiguously.
|
||||
if (lid.y == 0) {
|
||||
op.atomic_update(out, val, out_idx);
|
||||
loop.next(simd_size, reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
// Each thread holds N_READS partial results but the simdgroups are not
|
||||
// aligned to do the reduction across the simdgroup so we write our results
|
||||
// in the shared memory and read them back according to the simdgroup.
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op.simd_reduce(
|
||||
shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);
|
||||
}
|
||||
|
||||
// Write the output.
|
||||
if (simd_lane_id == 0) {
|
||||
short column = simd_group_id * N_READS;
|
||||
out += out_idx * reduction_stride + column;
|
||||
if (column + N_READS <= stride) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; column + i < stride; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void col_reduce_general_no_atomics(
|
||||
/**
|
||||
* Our approach is the following simple looped approach:
|
||||
* 1. Each thread keeps running totals for BN / n_simdgroups outputs.
|
||||
* 2. Load a tile BM, BN in shared memory.
|
||||
* 3. Add the values from shared memory to the current running totals.
|
||||
* Neighboring threads access different rows (transposed access).
|
||||
* 4. Move ahead to the next tile until the M axis is exhausted.
|
||||
* 5. Move ahead to the next non column reduction
|
||||
* 6. Simd reduce the running totals
|
||||
* 7. Write them to the output
|
||||
*
|
||||
* The kernel becomes verbose because we support all kinds of OOB checks. For
|
||||
* instance if we choose that reduction_stride must be larger than BN then we
|
||||
* can get rid of half the kernel.
|
||||
*/
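A small host-side sketch (not part of the kernel) of how the default BM = 8, BN = 128 tile is carved up across the 4 simdgroups, matching the offset arithmetic further down:

#include <cstdio>

int main() {
  constexpr int simd_size = 32;
  constexpr int n_simdgroups = 4;
  constexpr int BM = 8, BN = 128;
  constexpr int tgp_size = n_simdgroups * simd_size;  // 128 threads
  constexpr int n_reads = (BM * BN) / tgp_size;       // 8 values per thread
  constexpr int n_read_blocks = BN / n_reads;         // 16 thread columns

  for (int lid = 0; lid < tgp_size; lid += 37) {      // a few sample threads
    int offset_x = (lid % n_read_blocks) * n_reads;   // starting column in the tile
    int offset_y = lid / n_read_blocks;                // row in the tile (0..BM-1)
    std::printf("thread %3d reads columns [%3d, %3d) of row %d\n",
                lid, offset_x, offset_x + n_reads, offset_y);
  }
}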
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIMS = 0,
|
||||
int BM = 8,
|
||||
int BN = 128>
|
||||
[[kernel]] void col_reduce_looped(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
threadgroup U* local_data [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 gid [[thread_position_in_grid]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 gsize [[threads_per_grid]]) {
|
||||
auto out_idx = tid.x * lsize.x + lid.x;
|
||||
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
const constant size_t& non_col_reductions [[buffer(10)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
Op op;
|
||||
constexpr int n_simdgroups = 4;
|
||||
constexpr short tgp_size = n_simdgroups * simd_size;
|
||||
constexpr short n_reads = (BM * BN) / tgp_size;
|
||||
constexpr short n_read_blocks = BN / n_reads;
|
||||
|
||||
if (out_idx < out_size) {
|
||||
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
local_data,
|
||||
in_idx,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
tid.xy,
|
||||
lid.xy,
|
||||
lsize.xy);
|
||||
threadgroup U shared_vals[BN * BM];
|
||||
U totals[n_reads];
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
const device T* row;
|
||||
|
||||
// Write out reduction results generated by threadgroups working on specific
|
||||
// output element, contiguously.
|
||||
if (lid.y == 0) {
|
||||
uint tgsize_y = ceildiv(gsize.y, lsize.y);
|
||||
uint tgsize_z = ceildiv(gsize.z, lsize.z);
|
||||
out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
short lid = simd_group_id * simd_size + simd_lane_id;
|
||||
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
|
||||
size_t column = BN * gid.x + offset.x;
|
||||
bool safe = column + n_reads <= reduction_stride;
|
||||
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
|
||||
in += in_idx + column;
|
||||
|
||||
size_t total = non_col_reductions * reduction_size;
|
||||
loop.next(offset.y, reduce_shape, reduce_strides);
|
||||
for (size_t r = offset.y; r < total; r += BM) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
|
||||
if (safe) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = op(static_cast<U>(row[i]), totals[i]);
|
||||
}
|
||||
} else {
|
||||
U vals[n_reads];
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
vals[i] =
|
||||
(column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
|
||||
}
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(BM, reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
// We can use a simd reduction to accumulate across BM so each thread writes
|
||||
// the partial output to SM and then each simdgroup does BN / n_simdgroups
|
||||
// accumulations.
|
||||
if (BM == 32) {
|
||||
constexpr int n_outputs = BN / n_simdgroups;
|
||||
static_assert(
|
||||
BM != 32 || n_outputs == n_reads,
|
||||
"The tile should be selected such that n_outputs == n_reads");
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
shared_vals[offset.y * BN + offset.x + i] = totals[i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
totals[i] =
|
||||
op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
|
||||
}
|
||||
|
||||
// Write the output.
|
||||
if (simd_lane_id == 0) {
|
||||
size_t out_column = BN * gid.x + out_offset.x;
|
||||
out += out_idx * reduction_stride + out_column;
|
||||
if (out_column + n_outputs <= reduction_stride) {
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; out_column + i < reduction_stride; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Each thread holds n_reads partial results. We write them all out to shared
|
||||
// memory and threads with offset.y == 0 aggregate the columns and write the
|
||||
// outputs.
|
||||
else {
|
||||
short x_block = offset.x / n_reads;
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (offset.y == 0) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
for (int j = 1; j < BM; j++) {
|
||||
totals[i] =
|
||||
op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write the output.
|
||||
if (offset.y == 0) {
|
||||
out += out_idx * reduction_stride + column;
|
||||
if (safe) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; column + i < reduction_stride; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
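For intuition, here is a minimal single-threaded C++ sketch of the BM x BN tiled column reduction that col_reduce_looped implements above: keep one running total per output column, sweep BM rows at a time, and handle the ragged last tile the way the kernel's `safe` path does. The function name, the Op callable and the plain std::vector buffer are illustrative assumptions, not the Metal code itself.

// Hypothetical host-side reference for the BM x BN tiled column reduction.
#include <algorithm>
#include <cstddef>
#include <vector>

template <typename U, typename Op>
std::vector<U> col_reduce_reference(
    const std::vector<U>& in, // row-major [n_rows, n_cols]
    size_t n_rows,
    size_t n_cols,
    Op op,
    U init) {
  constexpr size_t BM = 8;   // rows per tile
  constexpr size_t BN = 128; // columns per tile
  std::vector<U> out(n_cols, init);
  for (size_t c0 = 0; c0 < n_cols; c0 += BN) {
    size_t c1 = std::min(c0 + BN, n_cols);
    for (size_t r0 = 0; r0 < n_rows; r0 += BM) {
      size_t r1 = std::min(r0 + BM, n_rows);
      // Accumulate one tile into the running column totals.
      for (size_t r = r0; r < r1; r++) {
        for (size_t c = c0; c < c1; c++) {
          out[c] = op(out[c], in[r * n_cols + c]);
        }
      }
    }
  }
  return out;
}

Passing an addition functor with init 0 reproduces a sum over axis 0 of a [n_rows, n_cols] matrix, which is the case the kernel parallelizes across threadgroups.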
@ -1,287 +1,366 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Small row reductions
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Row reduction utilities
|
||||
// - `per_thread_row_reduce` collaborative partial reduction in the threadgroup
|
||||
// - `threadgroup_reduce` collaborative reduction in the threadgroup such that
|
||||
// lid.x == 0 holds the reduced value
|
||||
// - `thread_reduce` simple loop and reduce the row
|
||||
|
||||
// Each thread reduces for one output
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void row_reduce_general_small(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint lid [[thread_position_in_grid]]) {
|
||||
/**
|
||||
* The thread group collaboratively reduces across the rows with bounds
|
||||
* checking. In the end each thread holds a part of the reduction.
|
||||
*/
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N_READS = REDUCE_N_READS,
|
||||
int N_WRITES = REDUCE_N_WRITES>
|
||||
METAL_FUNC void per_thread_row_reduce(
|
||||
thread U totals[N_WRITES],
|
||||
const device T* inputs[N_WRITES],
|
||||
int blocks,
|
||||
int extra,
|
||||
uint lsize_x,
|
||||
uint lid_x) {
|
||||
Op op;
|
||||
|
||||
uint out_idx = lid;
|
||||
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
// Set up the accumulator registers
|
||||
for (int i = 0; i < N_WRITES; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
U total_val = Op::init;
|
||||
// Loop over the reduction size within thread group
|
||||
for (int i = 0; i < blocks; i++) {
|
||||
for (int j = 0; j < N_WRITES; j++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
|
||||
}
|
||||
|
||||
for (short r = 0; r < short(non_row_reductions); r++) {
|
||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||
const device T* in_row = in + in_idx;
|
||||
|
||||
for (short i = 0; i < short(reduction_size); i++) {
|
||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||
inputs[j] += lsize_x * N_READS;
|
||||
}
|
||||
}
|
||||
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
|
||||
// Each simdgroup reduces for one output
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void row_reduce_general_med(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint tid [[threadgroup_position_in_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
Op op;
|
||||
|
||||
uint out_idx = simd_per_group * tid + simd_group_id;
|
||||
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
U total_val = Op::init;
|
||||
|
||||
if (short(non_row_reductions) == 1) {
|
||||
uint in_idx = elem_to_loc(out_idx, shape, strides, ndim);
|
||||
const device T* in_row = in + in_idx;
|
||||
|
||||
for (short i = simd_lane_id; i < short(reduction_size); i += 32) {
|
||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||
}
|
||||
}
|
||||
|
||||
else if (short(non_row_reductions) >= 32) {
|
||||
for (short r = simd_lane_id; r < short(non_row_reductions); r += 32) {
|
||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||
const device T* in_row = in + in_idx;
|
||||
|
||||
for (short i = 0; i < short(reduction_size); i++) {
|
||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||
// Separate case for the last set as we close the reduction size
|
||||
int index = lid_x * N_READS;
|
||||
if (index + N_READS <= extra) {
|
||||
for (int j = 0; j < N_WRITES; j++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
else {
|
||||
const short n_reductions =
|
||||
short(reduction_size) * short(non_row_reductions);
|
||||
const short reductions_per_thread =
|
||||
(n_reductions + simd_size - 1) / simd_size;
|
||||
|
||||
const short r_st = simd_lane_id / reductions_per_thread;
|
||||
const short r_ed = short(non_row_reductions);
|
||||
const short r_jump = simd_size / reductions_per_thread;
|
||||
|
||||
const short i_st = simd_lane_id % reductions_per_thread;
|
||||
const short i_ed = short(reduction_size);
|
||||
const short i_jump = reductions_per_thread;
|
||||
|
||||
if (r_st < r_jump) {
|
||||
for (short r = r_st; r < r_ed; r += r_jump) {
|
||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||
const device T* in_row = in + in_idx;
|
||||
|
||||
for (short i = i_st; i < i_ed; i += i_jump) {
|
||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < N_WRITES; j++) {
|
||||
for (int i = 0; index + i < extra; i++) {
|
||||
totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
total_val = op.simd_reduce(total_val);
|
||||
|
||||
if (simd_lane_id == 0) {
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Large row reductions
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
METAL_FUNC U per_thread_row_reduce(
|
||||
/**
|
||||
* Consecutive rows in a contiguous array.
|
||||
*/
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N_READS = REDUCE_N_READS,
|
||||
int N_WRITES = REDUCE_N_WRITES>
|
||||
METAL_FUNC void per_thread_row_reduce(
|
||||
thread U totals[N_WRITES],
|
||||
const device T* in,
|
||||
const constant size_t& reduction_size,
|
||||
const constant size_t& out_size,
|
||||
int blocks,
|
||||
int extra,
|
||||
uint lsize_x,
|
||||
uint lid_x) {
|
||||
// Set up the input pointers
|
||||
const device T* inputs[N_WRITES];
|
||||
inputs[0] = in + lid_x * N_READS;
|
||||
for (int i = 1; i < N_WRITES; i++) {
|
||||
inputs[i] = inputs[i - 1] + reduction_size;
|
||||
}
|
||||
|
||||
per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
|
||||
totals, inputs, blocks, extra, lsize_x, lid_x);
|
||||
}
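A small host-side C++ sketch of the same pointer setup: N_WRITES consecutive rows of a contiguous buffer, with each pointer offset by the thread's lane so that lane lid_x starts N_READS elements into its slice. The names and sizes below are assumptions for illustration only, not the kernel's API.

// Hypothetical sketch of the consecutive-row pointer layout used above.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  constexpr int N_WRITES = 4;  // rows handled together
  constexpr int N_READS = 4;   // elements read per step
  const size_t row_size = 32;  // length of one reduction row
  const size_t lid_x = 2;      // lane id within the threadgroup row

  std::vector<float> buf(N_WRITES * row_size, 1.0f);
  const float* in = buf.data();

  // Same arithmetic as the kernel: row j starts one full row after row j - 1.
  const float* inputs[N_WRITES];
  inputs[0] = in + lid_x * N_READS;
  for (int j = 1; j < N_WRITES; j++) {
    inputs[j] = inputs[j - 1] + row_size;
  }

  for (int j = 0; j < N_WRITES; j++) {
    std::printf("row %d starts at offset %td\n", j, inputs[j] - in);
  }
  return 0;
}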
|
||||
|
||||
/**
|
||||
* Consecutive rows in an arbitrarily ordered array.
|
||||
*/
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N_READS = REDUCE_N_READS,
|
||||
int N_WRITES = REDUCE_N_WRITES>
|
||||
METAL_FUNC void per_thread_row_reduce(
|
||||
thread U totals[N_WRITES],
|
||||
const device T* in,
|
||||
const size_t row_idx,
|
||||
int blocks,
|
||||
int extra,
|
||||
const constant int* shape,
|
||||
const constant size_t* strides,
|
||||
const constant int& ndim,
|
||||
uint lsize_x,
|
||||
uint lid_x,
|
||||
uint2 tid) {
|
||||
Op op;
|
||||
|
||||
// Each threadgroup handles 1 reduction
|
||||
// TODO: Specializing elem_to_loc would be slightly faster
|
||||
int idx = tid.y * out_size + tid.x;
|
||||
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
|
||||
in += extra_offset + lid_x * N_READS;
|
||||
|
||||
// The reduction is accumulated here
|
||||
U total_val = Op::init;
|
||||
|
||||
// Loop over the reduction size within thread group
|
||||
int r = 0;
|
||||
for (; r < (int)ceildiv(reduction_size, N_READS * lsize_x) - 1; r++) {
|
||||
T vals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = in[i];
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total_val = op(static_cast<U>(vals[i]), total_val);
|
||||
}
|
||||
|
||||
in += lsize_x * N_READS;
|
||||
uint lid_x) {
|
||||
// Set up the input pointers
|
||||
const device T* inputs[N_WRITES];
|
||||
in += lid_x * N_READS;
|
||||
for (int i = 0; i < N_WRITES; i++) {
|
||||
inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim);
|
||||
}
|
||||
|
||||
// Separate case for the last set as we close the reduction size
|
||||
size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
|
||||
if (reduction_index < reduction_size) {
|
||||
int max_reads = reduction_size - reduction_index;
|
||||
|
||||
T vals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
int idx = min(i, max_reads - 1);
|
||||
vals[i] = static_cast<U>(in[idx]);
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
T val = i < max_reads ? vals[i] : Op::init;
|
||||
total_val = op(static_cast<U>(val), total_val);
|
||||
}
|
||||
}
|
||||
|
||||
return total_val;
|
||||
per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
|
||||
totals, inputs, blocks, extra, lsize_x, lid_x);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_general(
|
||||
const device T* in [[buffer(0)]],
|
||||
device mlx_atomic<U>* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
/**
|
||||
* Reduce within the threadgroup.
|
||||
*/
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N_READS = REDUCE_N_READS,
|
||||
int N_WRITES = REDUCE_N_WRITES>
|
||||
METAL_FUNC void threadgroup_reduce(
|
||||
thread U totals[N_WRITES],
|
||||
threadgroup U* shared_vals,
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
(void)non_row_reductions;
|
||||
|
||||
Op op;
|
||||
threadgroup U local_vals[simd_size];
|
||||
|
||||
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
reduction_size,
|
||||
out_size,
|
||||
shape,
|
||||
strides,
|
||||
ndim,
|
||||
lsize.x,
|
||||
lid.x,
|
||||
tid.xy);
|
||||
|
||||
total_val = op.simd_reduce(total_val);
|
||||
|
||||
// Prepare next level
|
||||
if (simd_lane_id == 0) {
|
||||
local_vals[simd_group_id] = total_val;
|
||||
// Simdgroup first
|
||||
for (int i = 0; i < N_WRITES; i++) {
|
||||
totals[i] = op.simd_reduce(totals[i]);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Reduction within thread group
|
||||
// Only needed if multiple simd groups
|
||||
if (reduction_size > simd_size) {
|
||||
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
|
||||
total_val = op.simd_reduce(total_val);
|
||||
}
|
||||
// Update output
|
||||
if (lid.x == 0) {
|
||||
op.atomic_update(out, total_val, tid.x);
|
||||
// Across simdgroups
|
||||
if (simd_per_group > 1) {
|
||||
if (simd_lane_id == 0) {
|
||||
for (int i = 0; i < N_WRITES; i++) {
|
||||
shared_vals[simd_group_id * N_WRITES + i] = totals[i];
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
U values[N_WRITES];
|
||||
for (int i = 0; i < N_WRITES; i++) {
|
||||
values[i] = (lid.x < simd_per_group) ? shared_vals[lid.x * N_WRITES + i]
|
||||
: op.init;
|
||||
}
|
||||
|
||||
for (int i = 0; i < N_WRITES; i++) {
|
||||
totals[i] = op.simd_reduce(values[i]);
|
||||
}
|
||||
}
|
||||
}
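A single-threaded C++ sketch of the two-level scheme above: each simdgroup first reduces its own lanes, the partials go to shared memory, and one final reduction combines the per-simdgroup values. The sizes, names, and the use of std::accumulate to stand in for op.simd_reduce are illustrative assumptions (32-wide simdgroups assumed).

// Hypothetical emulation of the simdgroup + threadgroup reduction above.
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  constexpr int simd_size = 32;
  constexpr int simd_per_group = 4; // a 128-thread threadgroup
  std::vector<float> per_thread(simd_size * simd_per_group, 1.0f);

  // Level 1: each simdgroup reduces its own 32 lanes (op.simd_reduce).
  std::vector<float> shared_vals(simd_per_group);
  for (int sg = 0; sg < simd_per_group; sg++) {
    auto first = per_thread.begin() + sg * simd_size;
    shared_vals[sg] = std::accumulate(first, first + simd_size, 0.0f);
  }

  // Level 2: one simdgroup reduces the per-simdgroup partials.
  float total = std::accumulate(shared_vals.begin(), shared_vals.end(), 0.0f);
  std::printf("threadgroup total = %g\n", total); // 128
  return 0;
}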
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_general_no_atomics(
|
||||
METAL_FUNC void
|
||||
thread_reduce(thread U& total, const device T* row, int blocks, int extra) {
|
||||
Op op;
|
||||
for (int i = 0; i < blocks; i++) {
|
||||
U vals[N_READS];
|
||||
for (int j = 0; j < N_READS; j++) {
|
||||
vals[j] = row[j];
|
||||
}
|
||||
for (int j = 0; j < N_READS; j++) {
|
||||
total = op(vals[j], total);
|
||||
}
|
||||
row += N_READS;
|
||||
}
|
||||
for (int i = 0; i < extra; i++) {
|
||||
total = op(*row++, total);
|
||||
}
|
||||
}
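The blocks/extra split that thread_reduce consumes is simply row_size = blocks * N_READS + extra; a tiny C++ check of that decomposition, with illustrative values:

// Hypothetical check of the blocks / extra decomposition used by thread_reduce.
#include <cassert>
#include <cstdio>

int main() {
  constexpr int N_READS = 4;
  for (int row_size : {1, 4, 7, 16, 1023}) {
    int blocks = row_size / N_READS; // full N_READS chunks
    int extra = row_size % N_READS;  // leftover tail elements
    assert(blocks * N_READS + extra == row_size);
    std::printf("row_size=%d -> blocks=%d extra=%d\n", row_size, blocks, extra);
  }
  return 0;
}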
|
||||
|
||||
// Reduction kernels
|
||||
// - `row_reduce_small` depending on the non-row reductions and row size it
|
||||
// either just loops over everything or a simd collaboratively reduces the
|
||||
// non_row reductions. In the first case one thread is responsible for one
|
||||
// output; in the second case one simdgroup is responsible for one output.
|
||||
// - `row_reduce_simple` simple contiguous row reduction
|
||||
// - `row_reduce_looped` simply loop and reduce each row for each non-row
|
||||
// reduction. One threadgroup is responsible for one output.
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIMS = 0,
|
||||
int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_small(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& row_size [[buffer(2)]],
|
||||
const constant size_t& non_row_reductions [[buffer(3)]],
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint3 tid [[thread_position_in_grid]],
|
||||
uint3 tsize [[threads_per_grid]]) {
|
||||
Op op;
|
||||
|
||||
U total_val = Op::init;
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
|
||||
// Precompute some row reduction numbers
|
||||
const device T* row;
|
||||
int blocks = row_size / N_READS;
|
||||
int extra = row_size % N_READS;
|
||||
|
||||
if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
|
||||
// Simple loop over non_row_reductions and reduce the row in the thread.
|
||||
size_t out_idx = tid.x + tsize.y * size_t(tid.y);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim);
|
||||
|
||||
for (uint r = 0; r < non_row_reductions; r++) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
|
||||
loop.next(reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
out[out_idx] = total_val;
|
||||
} else {
|
||||
// Collaboratively reduce over non_row_reductions in the simdgroup. Each
|
||||
// thread reduces every 32nd row and then a simple simd reduce.
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim);
|
||||
|
||||
loop.next(simd_lane_id, reduce_shape, reduce_strides);
|
||||
|
||||
for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
|
||||
loop.next(simd_size, reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
total_val = op.simd_reduce(total_val);
|
||||
|
||||
if (simd_lane_id == 0) {
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
}
|
||||
}
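The branch at the top of row_reduce_small chooses between the per-thread loop and the simdgroup-collaborative loop. Below is a hedged host-side sketch of that same selection predicate, purely for illustration; the function and enum names are assumptions and the real dispatch logic lives in the host code, not in the kernel.

// Hypothetical sketch of the path selection inside row_reduce_small.
#include <cstddef>
#include <cstdio>

enum class SmallRowPath { PerThread, SimdCollaborative };

SmallRowPath pick_small_row_path(size_t non_row_reductions, size_t row_size) {
  // Mirrors the kernel's condition: short rows with few non-row reductions
  // are looped by a single thread, otherwise a simdgroup shares the rows.
  if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
    return SmallRowPath::PerThread;
  }
  return SmallRowPath::SimdCollaborative;
}

int main() {
  std::printf("%d\n", pick_small_row_path(4, 1000) == SmallRowPath::PerThread);
  std::printf("%d\n", pick_small_row_path(64, 64) == SmallRowPath::SimdCollaborative);
  return 0;
}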
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N_READS = REDUCE_N_READS,
|
||||
int N_WRITES = REDUCE_N_WRITES>
|
||||
[[kernel]] void row_reduce_simple(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 gsize [[threads_per_grid]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
(void)non_row_reductions;
|
||||
threadgroup U shared_vals[simd_size * N_WRITES];
|
||||
U totals[N_WRITES];
|
||||
|
||||
Op op;
|
||||
|
||||
threadgroup U local_vals[simd_size];
|
||||
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
reduction_size,
|
||||
out_size,
|
||||
shape,
|
||||
strides,
|
||||
ndim,
|
||||
lsize.x,
|
||||
lid.x,
|
||||
tid.xy);
|
||||
|
||||
// Reduction within simd group - simd_add isn't supported for int64 types
|
||||
for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
|
||||
total_val = op(total_val, simd_shuffle_down(total_val, i));
|
||||
// Move to the row
|
||||
size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z));
|
||||
if (out_idx + N_WRITES > out_size) {
|
||||
out_idx = out_size - N_WRITES;
|
||||
}
|
||||
in += out_idx * reduction_size;
|
||||
out += out_idx;
|
||||
|
||||
// Prepare next level
|
||||
if (simd_lane_id == 0) {
|
||||
local_vals[simd_group_id] = total_val;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Each thread reduces across the row
|
||||
int blocks = reduction_size / (lsize.x * N_READS);
|
||||
int extra = reduction_size - blocks * (lsize.x * N_READS);
|
||||
per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
|
||||
totals, in, reduction_size, blocks, extra, lsize.x, lid.x);
|
||||
|
||||
// Reduction within thread group
|
||||
// Only needed if thread group has multiple simd groups
|
||||
if (ceildiv(reduction_size, N_READS) > simd_size) {
|
||||
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
|
||||
for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
|
||||
total_val = op(total_val, simd_shuffle_down(total_val, i));
|
||||
// Reduce across the threadgroup
|
||||
threadgroup_reduce<T, U, Op, N_READS, N_WRITES>(
|
||||
totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);
|
||||
|
||||
// Write the output
|
||||
if (lid.x == 0) {
|
||||
for (int i = 0; i < N_WRITES; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
}
|
||||
// Write row reduce output for threadgroup with 1st thread in thread group
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIMS = 0,
|
||||
int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_looped(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& row_size [[buffer(2)]],
|
||||
const constant size_t& non_row_reductions [[buffer(3)]],
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
Op op;
|
||||
threadgroup U shared_vals[simd_size];
|
||||
U total = Op::init;
|
||||
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
|
||||
// lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
|
||||
// needs a small refactor.
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS;
|
||||
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
const device T* row;
|
||||
int blocks = row_size / (lsize.x * N_READS);
|
||||
int extra = row_size - blocks * (lsize.x * N_READS);
|
||||
|
||||
for (size_t i = 0; i < non_row_reductions; i++) {
|
||||
row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim);
|
||||
|
||||
// Each thread reduces across the row
|
||||
U row_total;
|
||||
per_thread_row_reduce<T, U, Op, N_READS, 1>(
|
||||
&row_total, &row, blocks, extra, lsize.x, lid.x);
|
||||
|
||||
// Aggregate across rows
|
||||
total = op(total, row_total);
|
||||
|
||||
loop.next(reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
// Reduce across the threadgroup
|
||||
threadgroup_reduce<T, U, Op, N_READS, 1>(
|
||||
&total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);
|
||||
|
||||
// Write the output
|
||||
if (lid.x == 0) {
|
||||
out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
|
||||
out[out_idx] = total;
|
||||
}
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ METAL_FUNC void scatter_1d_index_impl(
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
Op op;
|
||||
|
||||
uint out_idx = 0;
|
||||
size_t out_idx = 0;
|
||||
for (int i = 0; i < NIDX; i++) {
|
||||
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
|
||||
out_idx += idx_val * out_strides[i];
|
||||
|
@ -64,6 +64,16 @@ struct Limits<bool> {
|
||||
static constexpr constant bool min = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Limits<complex64_t> {
|
||||
static constexpr constant complex64_t max = complex64_t(
|
||||
metal::numeric_limits<float>::infinity(),
|
||||
metal::numeric_limits<float>::infinity());
|
||||
static constexpr constant complex64_t min = complex64_t(
|
||||
-metal::numeric_limits<float>::infinity(),
|
||||
-metal::numeric_limits<float>::infinity());
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Indexing utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -101,6 +111,34 @@ METAL_FUNC stride_t elem_to_loc(
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t elem_to_loc(
|
||||
stride_t elem,
|
||||
device const int* shape,
|
||||
device const stride_t* strides,
|
||||
int ndim) {
|
||||
stride_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t elem_to_loc(
|
||||
stride_t elem,
|
||||
constant const int* shape,
|
||||
constant const stride_t* strides,
|
||||
int ndim) {
|
||||
stride_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
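elem_to_loc maps a flat row-major element index to a storage offset by peeling digits off the index with the (possibly permuted or broadcast) strides. A small host-side example for a 3 x 4 buffer viewed transposed as 4 x 3; the standalone function below is an illustrative standard C++ restatement, not the Metal overloads above.

// Hypothetical host-side elem_to_loc for a transposed view.
#include <cstddef>
#include <cstdio>

size_t elem_to_loc(size_t elem, const int* shape, const size_t* strides, int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // Transposed view: logical shape {4, 3}, strides {1, 4} into the original buffer.
  int shape[2] = {4, 3};
  size_t strides[2] = {1, 4};
  // Logical element 5 is (row 1, col 2) -> offset 1 * 1 + 2 * 4 = 9.
  std::printf("%zu\n", elem_to_loc(5, shape, strides, 2)); // prints 9
  return 0;
}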
|
||||
|
||||
// Non templated version to handle arbitrary dims
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t elem_to_loc(
|
||||
@ -288,12 +326,87 @@ METAL_FUNC uint3 elem_to_loc_3_nd(
|
||||
return loc;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Elem to loc in a loop utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int dim, typename offset_t = size_t>
|
||||
struct looped_elem_to_loc {
|
||||
looped_elem_to_loc<dim - 1, offset_t> inner_looper;
|
||||
offset_t offset{0};
|
||||
int index{0};
|
||||
|
||||
void next(const constant int* shape, const constant size_t* strides) {
|
||||
index++;
|
||||
offset += strides[dim - 1];
|
||||
|
||||
if (index >= shape[dim - 1]) {
|
||||
index = 0;
|
||||
inner_looper.next(shape, strides);
|
||||
offset = inner_looper.offset;
|
||||
}
|
||||
}
|
||||
|
||||
void next(int n, const constant int* shape, const constant size_t* strides) {
|
||||
index += n;
|
||||
offset += n * strides[dim - 1];
|
||||
|
||||
if (index >= shape[dim - 1]) {
|
||||
int extra = index - shape[dim - 1];
|
||||
index = 0;
|
||||
inner_looper.next(shape, strides);
|
||||
offset = inner_looper.offset;
|
||||
if (extra > 0) {
|
||||
next(extra, shape, strides);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
offset_t
|
||||
location(offset_t, const constant int*, const constant size_t*, int) {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename offset_t>
|
||||
struct looped_elem_to_loc<1, offset_t> {
|
||||
offset_t offset{0};
|
||||
|
||||
void next(const constant int*, const constant size_t* strides) {
|
||||
offset += strides[0];
|
||||
}
|
||||
|
||||
void next(int n, const constant int*, const constant size_t* strides) {
|
||||
offset += n * strides[0];
|
||||
}
|
||||
|
||||
offset_t
|
||||
location(offset_t, const constant int*, const constant size_t*, int) {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename offset_t>
|
||||
struct looped_elem_to_loc<0, offset_t> {
|
||||
void next(const constant int*, const constant size_t*) {}
|
||||
void next(int, const constant int*, const constant size_t*) {}
|
||||
|
||||
offset_t location(
|
||||
offset_t idx,
|
||||
const constant int* shape,
|
||||
const constant size_t* strides,
|
||||
int ndim) {
|
||||
return elem_to_loc(idx, shape, strides, ndim);
|
||||
}
|
||||
};
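The looped_elem_to_loc helpers amortize elem_to_loc across a loop: next() advances a multi-dimensional counter together with its offset, and location() returns the current offset (or recomputes it in the 0-dim fallback). Below is a simplified host-side analogue of the two-dimensional case; the struct name and the offset bookkeeping are illustrative simplifications rather than the Metal template above.

// Hypothetical host-side analogue of looped_elem_to_loc for two reduction dims.
#include <cstddef>
#include <cstdio>

struct Looped2d {
  size_t offset = 0;
  int index0 = 0; // counter for the outer (slower) axis
  int index1 = 0; // counter for the inner (faster) axis

  void next(const int* shape, const size_t* strides) {
    index1++;
    offset += strides[1];
    if (index1 >= shape[1]) {
      index1 = 0;
      index0++;
      offset = index0 * strides[0]; // restart the inner axis offset
    }
  }
};

int main() {
  // Walk a {2, 3} reduction sub-space of an array with strides {10, 1}.
  int shape[2] = {2, 3};
  size_t strides[2] = {10, 1};
  Looped2d loop;
  for (int r = 0; r < shape[0] * shape[1]; r++) {
    std::printf("%zu ", loop.offset); // prints 0 1 2 10 11 12
    loop.next(shape, strides);
  }
  std::printf("\n");
  return 0;
}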
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Calculation utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/** Compute ceil((float)N/(float)M) */
|
||||
inline size_t ceildiv(size_t N, size_t M) {
|
||||
template <typename T, typename U>
|
||||
inline T ceildiv(T N, U M) {
|
||||
return (N + M - 1) / M;
|
||||
}
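The templated ceildiv keeps the integer type of its arguments; it is the helper that turns a row size into a block count. A small illustrative check, under the assumption of 16-element reads:

// Hypothetical use of the integer ceildiv above.
#include <cassert>
#include <cstddef>

template <typename T, typename U>
inline T ceildiv(T N, U M) {
  return (N + M - 1) / M;
}

int main() {
  assert(ceildiv(size_t(1023), 16) == 64); // 1023 elements in 16-wide reads
  assert(ceildiv(size_t(1024), 16) == 64);
  assert(ceildiv(size_t(1025), 16) == 65);
  return 0;
}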
|
||||
|
||||
@ -339,3 +452,8 @@ inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
|
||||
inline bool simd_shuffle_down(bool data, uint16_t delta) {
|
||||
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
|
||||
}
|
||||
|
||||
inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) {
|
||||
return complex64_t(
|
||||
simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta));
|
||||
}
|
||||
|
@ -135,8 +135,8 @@ void RMSNormVJP::eval_gpu(
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and initialize the
|
||||
// gradient accumulator to 0.
|
||||
// Allocate the gradient accumulator gw and a temporary to store the
|
||||
// gradients before they are accumulated.
|
||||
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
|
||||
bool g_in_gw = false;
|
||||
if (!g_in_gx && g.is_donatable()) {
|
||||
@ -146,11 +146,7 @@ void RMSNormVJP::eval_gpu(
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
}
|
||||
copies.push_back(gw_temp);
|
||||
{
|
||||
array zero(0, gw.dtype());
|
||||
copy_gpu(zero, gw, CopyType::Scalar, s);
|
||||
copies.push_back(std::move(zero));
|
||||
}
|
||||
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
|
||||
|
||||
const int simd_size = 32;
|
||||
const int n_reads = RMS_N_READS;
|
||||
@ -330,8 +326,8 @@ void LayerNormVJP::eval_gpu(
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and initialize the
|
||||
// gradient accumulator to 0.
|
||||
// Allocate a temporary to store the gradients for w and allocate the output
|
||||
// gradient accumulators.
|
||||
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
|
||||
bool g_in_gw = false;
|
||||
if (!g_in_gx && g.is_donatable()) {
|
||||
@ -341,12 +337,8 @@ void LayerNormVJP::eval_gpu(
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
}
|
||||
copies.push_back(gw_temp);
|
||||
{
|
||||
array zero(0, gw.dtype());
|
||||
copy_gpu(zero, gw, CopyType::Scalar, s);
|
||||
copy_gpu(zero, gb, CopyType::Scalar, s);
|
||||
copies.push_back(std::move(zero));
|
||||
}
|
||||
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
|
||||
gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
|
||||
|
||||
// Finish with the gradient for b in case we had a b
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
File diff suppressed because it is too large
158 mlx/ops.cpp
@ -16,9 +16,11 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
|
||||
std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, bool>
|
||||
compute_reduce_shape(
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& shape) {
|
||||
bool is_noop = true;
|
||||
std::set<int> axes_set;
|
||||
auto ndim = shape.size();
|
||||
for (auto ax : axes) {
|
||||
@ -35,15 +37,18 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
|
||||
throw std::invalid_argument("Duplicate axes detected in reduction.");
|
||||
}
|
||||
std::vector<int> out_shape;
|
||||
std::vector<int> squeezed_shape;
|
||||
for (int i = 0; i < ndim; ++i) {
|
||||
if (axes_set.count(i) == 0) {
|
||||
out_shape.push_back(shape[i]);
|
||||
squeezed_shape.push_back(shape[i]);
|
||||
} else {
|
||||
out_shape.push_back(1);
|
||||
}
|
||||
is_noop &= (out_shape.back() == shape[i]);
|
||||
}
|
||||
std::vector<int> sorted_axes(axes_set.begin(), axes_set.end());
|
||||
return {out_shape, sorted_axes};
|
||||
return {out_shape, sorted_axes, squeezed_shape, is_noop};
|
||||
}
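To make the new return tuple concrete: for shape {4, 1, 5} reduced over axis 1, the output shape keeps the axis as 1, the squeezed shape drops it, and the reduction is a no-op because the axis already has size 1. The snippet below is an illustrative standalone restatement of that bookkeeping, not the mlx source itself.

// Hypothetical standalone restatement of compute_reduce_shape's outputs.
#include <cassert>
#include <set>
#include <vector>

int main() {
  std::vector<int> shape = {4, 1, 5};
  std::set<int> axes = {1};

  std::vector<int> out_shape, squeezed_shape;
  bool is_noop = true;
  for (int i = 0; i < (int)shape.size(); ++i) {
    if (axes.count(i) == 0) {
      out_shape.push_back(shape[i]);
      squeezed_shape.push_back(shape[i]);
    } else {
      out_shape.push_back(1);
    }
    is_noop &= (out_shape.back() == shape[i]);
  }

  assert((out_shape == std::vector<int>{4, 1, 5}));
  assert((squeezed_shape == std::vector<int>{4, 5}));
  assert(is_noop); // reducing a size-1 axis changes nothing
  return 0;
}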
|
||||
|
||||
Dtype at_least_float(const Dtype& d) {
|
||||
@ -1502,17 +1507,17 @@ array all(
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
if (axes.empty()) {
|
||||
return astype(a, bool_, s);
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
bool_,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out = (is_noop)
|
||||
? astype(a, bool_, s)
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
bool_,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@ -1536,17 +1541,17 @@ array any(
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
if (axes.empty()) {
|
||||
return astype(a, bool_, s);
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
bool_,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out = (is_noop)
|
||||
? astype(a, bool_, s)
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
bool_,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@ -1573,15 +1578,18 @@ array sum(
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
|
||||
auto out = array(
|
||||
out_shape,
|
||||
out_type,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
|
||||
{a});
|
||||
auto out = (is_noop)
|
||||
? astype(a, out_type, s)
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
out_type,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@ -1715,14 +1723,17 @@ array prod(
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out = (is_noop)
|
||||
? a
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@ -1749,17 +1760,17 @@ array max(
|
||||
if (a.size() == 0) {
|
||||
throw std::invalid_argument("[max] Cannot max reduce zero size array.");
|
||||
}
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out = (is_noop)
|
||||
? a
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@ -1789,14 +1800,17 @@ array min(
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out = (is_noop)
|
||||
? a
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@ -1829,15 +1843,18 @@ array argmin(
|
||||
throw std::invalid_argument(
|
||||
"[argmin] Cannot argmin reduce zero size array.");
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
uint32,
|
||||
std::make_shared<ArgReduce>(
|
||||
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape({axis}, a.shape());
|
||||
auto out = (is_noop)
|
||||
? zeros(out_shape, uint32, s)
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
uint32,
|
||||
std::make_shared<ArgReduce>(
|
||||
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@ -1862,15 +1879,18 @@ array argmax(
|
||||
throw std::invalid_argument(
|
||||
"[argmax] Cannot argmax reduce zero size array.");
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
uint32,
|
||||
std::make_shared<ArgReduce>(
|
||||
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape({axis}, a.shape());
|
||||
auto out = (is_noop)
|
||||
? zeros(out_shape, uint32, s)
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
uint32,
|
||||
std::make_shared<ArgReduce>(
|
||||
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
@ -1763,7 +1763,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
mat_t = mat.astype(t)
|
||||
out = mx.cumsum(a_t, axis=-1)
|
||||
expected = (mat_t * a_t[:, None, :]).sum(axis=-1)
|
||||
self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3))
|
||||
self.assertTrue(mx.allclose(out, expected, rtol=0.02, atol=1e-3))
|
||||
sizes = [1023, 1024, 1025, 2047, 2048, 2049]
|
||||
for s in sizes:
|
||||
a = mx.ones((s,), mx.int32)
|
||||
|
@ -43,6 +43,10 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
z_npy = np.sum(y_npy, axis=a) / 1000
|
||||
z_mlx = mx.sum(y_mlx, axis=a) / 1000
|
||||
mx.eval(z_mlx)
|
||||
if not np.allclose(z_npy, np.array(z_mlx), atol=1e-4):
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
self.assertTrue(
|
||||
np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
|
||||
)
|
||||
@ -57,6 +61,7 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
"uint32",
|
||||
"int64",
|
||||
"uint64",
|
||||
"complex64",
|
||||
]
|
||||
float_dtypes = ["float32"]
|
||||
|
||||
@ -114,6 +119,15 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
b = getattr(np, op)(data)
|
||||
self.assertEqual(a.item(), b)
|
||||
|
||||
def test_edge_case(self):
|
||||
x = (mx.random.normal((100, 1, 100, 100)) * 128).astype(mx.int32)
|
||||
x = x.transpose(0, 3, 1, 2)
|
||||
|
||||
y = x.sum((0, 2, 3))
|
||||
mx.eval(y)
|
||||
z = np.array(x).sum((0, 2, 3))
|
||||
self.assertTrue(np.all(z == y))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(failfast=True)
|
||||
|