Working 64-bit scans (#1506)

Angelos Katharopoulos 2024-10-24 11:05:46 -07:00 committed by GitHub
parent 32972a5924
commit c9b41d460f
9 changed files with 309 additions and 137 deletions

@@ -1,26 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view scan_kernels = R"(
template [[host_name("contig_{0}")]] [[kernel]] void
contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("strided_{0}")]] [[kernel]] void
strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[thread_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]);
)";

@@ -4,7 +4,6 @@
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
@@ -224,18 +223,26 @@ MTL::ComputePipelineState* get_scan_kernel(
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::string op_name = "Cum" + reduce_type;
op_name[3] = toupper(op_name[3]);
auto out_type = get_type_string(out.dtype());
std::string op = "Cum" + reduce_type + "<" + out_type + ">";
op[3] = toupper(op[3]);
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::scan()
<< fmt::format(
scan_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name,
inclusive,
reverse);
kernel_source << metal::utils() << metal::scan();
const std::array<std::pair<std::string, std::string>, 2> scan_kernels = {{
{"contig_", "contiguous_scan"},
{"strided_", "strided_scan"},
}};
for (auto& [prefix, kernel] : scan_kernels) {
kernel_source << get_template_definition(
prefix + lib_name,
kernel,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op,
in.itemsize() <= 4 ? 4 : 2,
inclusive,
reverse);
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
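Two things change in this JIT builder: the operation string now carries the output type ("Cum" + op + "<out_type>"), and the number of reads per thread depends on the element width via in.itemsize() <= 4 ? 4 : 2. The intent, as I read it, is to hold the bytes moved per thread per iteration roughly constant, so 8-byte types halve the element count. A sketch of that arithmetic (assumed rationale, not MLX code):

#include <cstddef>

// 4 elements x 4 bytes and 2 elements x 8 bytes are both 16 bytes per
// thread per iteration, so register and threadgroup-memory pressure stay
// comparable across element widths.
constexpr int n_reads_for(std::size_t itemsize) {
  return itemsize <= 4 ? 4 : 2;
}
static_assert(n_reads_for(sizeof(float)) * sizeof(float) == 16);
static_assert(n_reads_for(sizeof(long long)) * sizeof(long long) == 16);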

@@ -1,7 +1,38 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#define DEFINE_SIMD_SCAN() \
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_scan(T val) { \
return simd_scan_impl(val); \
} \
\
template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true> \
T simd_scan(T val) { \
for (int i = 1; i <= 16; i *= 2) { \
val = operator()(val, simd_shuffle_and_fill_up(val, init, i)); \
} \
return val; \
}
#define DEFINE_SIMD_EXCLUSIVE_SCAN() \
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_exclusive_scan(T val) { \
return simd_exclusive_scan_impl(val); \
} \
\
template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true> \
T simd_exclusive_scan(T val) { \
val = simd_scan(val); \
return simd_shuffle_and_fill_up(val, init, 1); \
}
template <typename U>
struct CumSum {
DEFINE_SIMD_SCAN()
DEFINE_SIMD_EXCLUSIVE_SCAN()
static constexpr constant U init = static_cast<U>(0);
template <typename T>
@@ -9,17 +40,20 @@ struct CumSum {
return a + b;
}
U simd_scan(U x) {
U simd_scan_impl(U x) {
return simd_prefix_inclusive_sum(x);
}
U simd_exclusive_scan(U x) {
U simd_exclusive_scan_impl(U x) {
return simd_prefix_exclusive_sum(x);
}
};
template <typename U>
struct CumProd {
DEFINE_SIMD_SCAN()
DEFINE_SIMD_EXCLUSIVE_SCAN()
static constexpr constant U init = static_cast<U>(1.0f);
template <typename T>
@@ -27,11 +61,11 @@ struct CumProd {
return a * b;
}
U simd_scan(U x) {
U simd_scan_impl(U x) {
return simd_prefix_inclusive_product(x);
}
U simd_exclusive_scan(U x) {
U simd_exclusive_scan_impl(U x) {
return simd_prefix_exclusive_product(x);
}
};
@@ -47,7 +81,7 @@ struct CumProd<bool> {
bool simd_scan(bool x) {
for (int i = 1; i <= 16; i *= 2) {
bool other = simd_shuffle_up(x, i);
bool other = simd_shuffle_and_fill_up(x, init, i);
x &= other;
}
return x;
@@ -70,7 +104,7 @@ struct CumMax {
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_up(x, i);
U other = simd_shuffle_and_fill_up(x, init, i);
x = (x >= other) ? x : other;
}
return x;
@@ -93,7 +127,7 @@ struct CumMin {
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_up(x, i);
U other = simd_shuffle_and_fill_up(x, init, i);
x = (x <= other) ? x : other;
}
return x;
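The pattern in the structs above is the heart of the 64-bit support. Metal's simd_prefix_inclusive_sum/product intrinsics do not accept 8-byte types, so DEFINE_SIMD_SCAN dispatches on sizeof(T): the native intrinsic for narrow types, and a log-step (Hillis-Steele) loop of simd_shuffle_and_fill_up calls for 8-byte ones. The fill variant also replaces simd_shuffle_up in CumProd<bool>, CumMax, and CumMin, so lanes that would read past lane 0 are guaranteed to receive the op's identity value init. A host-side emulation (a sketch; plain C++ standing in for a 32-lane simdgroup) shows why five doubling steps produce an inclusive scan:

#include <array>
#include <cassert>
#include <cstdint>

constexpr int kSimdSize = 32; // Apple GPUs execute 32-lane simdgroups

// Models Metal's simd_shuffle_and_fill_up: lane k reads lane k - delta,
// and lanes below delta receive the fill value (the op's identity).
std::array<int64_t, kSimdSize> shuffle_and_fill_up(
    const std::array<int64_t, kSimdSize>& v, int64_t fill, int delta) {
  std::array<int64_t, kSimdSize> out{};
  for (int k = 0; k < kSimdSize; ++k) {
    out[k] = (k >= delta) ? v[k - delta] : fill;
  }
  return out;
}

int main() {
  std::array<int64_t, kSimdSize> val{};
  for (int k = 0; k < kSimdSize; ++k) {
    val[k] = k + 1; // lane k holds k + 1
  }
  // Five doubling steps (1, 2, 4, 8, 16) cover all 32 lanes.
  for (int i = 1; i <= 16; i *= 2) {
    auto up = shuffle_and_fill_up(val, /*identity=*/0, i);
    for (int k = 0; k < kSimdSize; ++k) {
      val[k] += up[k];
    }
  }
  for (int k = 0; k < kSimdSize; ++k) {
    assert(val[k] == int64_t(k + 1) * (k + 2) / 2); // inclusive prefix sums
  }
  return 0;
}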
@@ -178,20 +212,22 @@ template <
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int simd_size = 32;
Op op;
// Position the pointers
in += (gid / lsize) * axis_size;
out += (gid / lsize) * axis_size;
size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size;
in += offset;
out += offset;
// Compute the number of simd_groups
uint simd_groups = lsize / simd_size;
uint simd_groups = lsize.x / simd_size;
// Allocate memory
U prefix = Op::init;
@@ -210,9 +246,9 @@ template <
// value
// Write block
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) {
// Compute the block offset
uint offset = r * lsize * N_READS + lid * N_READS;
uint offset = r * lsize.x * N_READS + lid.x * N_READS;
// Read the values
if (reverse) {
@@ -275,7 +311,7 @@ template <
values, out + axis_size - offset - N_READS, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
if (lid.x == 0 && offset == 0) {
out[axis_size - 1] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
@@ -298,7 +334,7 @@ template <
values, out + offset, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
if (lid.x == 0 && offset == 0) {
out[0] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
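In contiguous_scan, the flattened 1D launch (gid / lsize picking the row) becomes a 3D one: the x dimension holds a threadgroup's threads while y and z enumerate scan rows, and the row offset is computed in size_t as (gid.y + gsize.y * size_t(gid.z)) * axis_size. That is what makes very large arrays safe: the row index no longer has to fit a single 32-bit grid coordinate. The threads_per_simdgroup argument is also replaced by a hard-coded constexpr int simd_size = 32, which the compiler can fold. A simplified host-side model of the split-and-recombine (my sketch, with hypothetical numbers):

#include <cassert>
#include <cstdint>

int main() {
  // Suppose more scan rows than fit in one 32-bit grid dimension.
  uint64_t n_rows = (1ull << 33) + 5;
  uint32_t gsize_y = 1u << 31;                         // grid height picked by host
  uint32_t gsize_z = (n_rows + gsize_y - 1) / gsize_y; // remaining rows go to z
  // Inside the kernel: row = gid.y + gsize.y * size_t(gid.z), all 64-bit.
  uint32_t gid_y = 7, gid_z = 2;
  uint64_t row = gid_y + uint64_t(gsize_y) * gid_z;
  assert(row == 7 + (1ull << 32)); // well past 2^32, no overflow
  (void)gsize_z;
  return 0;
}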
@@ -332,86 +368,98 @@ template <
device U* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]) {
const constant size_t& stride_blocks [[buffer(4)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int simd_size = 32;
constexpr int BM = 32;
constexpr int BN = 32;
constexpr int BN_pad = 32 + 16 / sizeof(U);
constexpr int n_simds = BN / N_READS;
constexpr int n_scans = BN / n_simds;
Op op;
// Allocate memory
threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
U values[N_READS];
U prefix[N_READS];
for (int i = 0; i < N_READS; i++) {
threadgroup U read_buffer[BM * BN_pad];
U values[n_scans];
U prefix[n_scans];
for (int i = 0; i < n_scans; i++) {
prefix[i] = Op::init;
}
// Compute offsets
int offset = gid.y * axis_size * stride;
int global_index_x = gid.x * lsize.y * N_READS;
size_t full_gid = gid.y + gsize.y * size_t(gid.z);
size_t offset = full_gid / stride_blocks * axis_size * stride;
size_t global_index_x = full_gid % stride_blocks * BN;
uint read_offset_y = (lid.x * N_READS) / BN;
uint read_offset_x = (lid.x * N_READS) % BN;
uint scan_offset_y = simd_lane_id;
uint scan_offset_x = simd_group_id * n_scans;
for (uint j = 0; j < axis_size; j += simd_size) {
uint stride_limit = stride - global_index_x;
in += offset + global_index_x + read_offset_x;
out += offset + global_index_x + read_offset_x;
threadgroup U* read_into =
read_buffer + read_offset_y * BN_pad + read_offset_x;
threadgroup U* read_from =
read_buffer + scan_offset_y * BN_pad + scan_offset_x;
for (uint j = 0; j < axis_size; j += BM) {
// Calculate the indices for the current thread
uint index_y = j + lid.y;
uint index_y = j + read_offset_y;
uint check_index_y = index_y;
uint index_x = global_index_x + lid.x * N_READS;
if (reverse) {
index_y = axis_size - 1 - index_y;
}
// Read in SM
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
for (int i = 0; i < N_READS; i++) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
in[offset + index_y * stride + index_x + i];
read_into[i] = in[index_y * stride + i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
in[offset + index_y * stride + index_x + i];
if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
read_into[i] = in[index_y * stride + i];
} else {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
Op::init;
read_into[i] = Op::init;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read strided into registers
for (int i = 0; i < N_READS; i++) {
values[i] =
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
for (int i = 0; i < n_scans; i++) {
values[i] = read_from[i];
}
// Do we need the following barrier? Shouldn't all simd threads execute
// simultaneously?
simdgroup_barrier(mem_flags::mem_threadgroup);
// Perform the scan
for (int i = 0; i < N_READS; i++) {
for (int i = 0; i < n_scans; i++) {
values[i] = op.simd_scan(values[i]);
values[i] = op(values[i], prefix[i]);
prefix[i] = simd_shuffle(values[i], simd_size - 1);
}
// Write to SM
for (int i = 0; i < N_READS; i++) {
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] =
values[i];
for (int i = 0; i < n_scans; i++) {
read_from[i] = values[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write to device memory
if (!inclusive) {
if (check_index_y == 0) {
if ((index_x + N_READS) < stride) {
if ((read_offset_x + N_READS) < stride_limit) {
for (int i = 0; i < N_READS; i++) {
out[offset + index_y * stride + index_x + i] = Op::init;
out[index_y * stride + i] = Op::init;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] = Op::init;
if ((read_offset_x + i) < stride_limit) {
out[index_y * stride + i] = Op::init;
}
}
}
@@ -424,16 +472,14 @@ template <
check_index_y += 1;
}
}
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
for (int i = 0; i < N_READS; i++) {
out[offset + index_y * stride + index_x + i] =
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
out[index_y * stride + i] = read_into[i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] =
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
out[index_y * stride + i] = read_into[i];
}
}
}
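The strided kernel is reorganized around an explicit BM x BN = 32 x 32 tile. Each threadgroup owns a BN-wide block of columns (selected through the new stride_blocks argument), loads BM rows of it into threadgroup memory, scans the columns with simdgroup scans, and carries a per-column prefix into the next BM rows. The shared tile is stored with a padded row pitch, BN_pad = 32 + 16 / sizeof(U), so that the transposed reads done during the scan are staggered across memory banks (the bank-conflict rationale is my assumption; the diff itself does not state it). A small sketch of the pitch arithmetic:

#include <cstdio>

// Padded row pitch of the threadgroup tile, in elements: BN plus enough
// extra elements to offset consecutive rows by 16 bytes.
template <typename U>
constexpr int bn_pad() {
  return 32 + 16 / sizeof(U);
}

int main() {
  std::printf("float pitch: %d elements\n", bn_pad<float>());     // 36
  std::printf("int64 pitch: %d elements\n", bn_pad<long long>()); // 34
  return 0;
}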

@@ -13,15 +13,15 @@ using namespace metal;
#define instantiate_contiguous_scan( \
name, itype, otype, op, inclusive, reverse, nreads) \
template [[host_name("contig_scan_" #name)]] [[kernel]] void \
contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& axis_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 gsize [[threadgroups_per_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
@@ -33,10 +33,12 @@ using namespace metal;
device otype* out [[buffer(1)]], \
const constant size_t& axis_size [[buffer(2)]], \
const constant size_t& stride [[buffer(3)]], \
uint2 gid [[thread_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]]);
const constant size_t& stride_blocks [[buffer(4)]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 gsize [[threadgroups_per_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
@@ -52,51 +54,51 @@ instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
instantiate_scan_helper(sum_uint32_uint32, uint32_t, uint32_t, CumSum, 4)
//instantiate_scan_helper(sum_uint64_uint64, uint64_t, uint64_t, CumSum, 2)
instantiate_scan_helper(sum_uint64_uint64, uint64_t, uint64_t, CumSum, 2)
instantiate_scan_helper(sum_int8_int8, int8_t, int8_t, CumSum, 4)
instantiate_scan_helper(sum_int16_int16, int16_t, int16_t, CumSum, 4)
instantiate_scan_helper(sum_int32_int32, int32_t, int32_t, CumSum, 4)
//instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
instantiate_scan_helper(sum_float16_float16, half, half, CumSum, 4)
instantiate_scan_helper(sum_float32_float32, float, float, CumSum, 4)
instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
//instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum)
//instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum, 2)
instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
instantiate_scan_helper(prod_uint8_uint8, uint8_t, uint8_t, CumProd, 4)
instantiate_scan_helper(prod_uint16_uint16, uint16_t, uint16_t, CumProd, 4)
instantiate_scan_helper(prod_uint32_uint32, uint32_t, uint32_t, CumProd, 4)
//instantiate_scan_helper(prod_uint64_uint64, uint64_t, uint64_t, CumProd, 2)
instantiate_scan_helper(prod_uint64_uint64, uint64_t, uint64_t, CumProd, 2)
instantiate_scan_helper(prod_int8_int8, int8_t, int8_t, CumProd, 4)
instantiate_scan_helper(prod_int16_int16, int16_t, int16_t, CumProd, 4)
instantiate_scan_helper(prod_int32_int32, int32_t, int32_t, CumProd, 4)
//instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
instantiate_scan_helper(prod_float16_float16, half, half, CumProd, 4)
instantiate_scan_helper(prod_float32_float32, float, float, CumProd, 4)
instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
//instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd)
//instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd, 2)
instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
instantiate_scan_helper(max_uint8_uint8, uint8_t, uint8_t, CumMax, 4)
instantiate_scan_helper(max_uint16_uint16, uint16_t, uint16_t, CumMax, 4)
instantiate_scan_helper(max_uint32_uint32, uint32_t, uint32_t, CumMax, 4)
//instantiate_scan_helper(max_uint64_uint64, uint64_t, uint64_t, CumMax, 2)
instantiate_scan_helper(max_uint64_uint64, uint64_t, uint64_t, CumMax, 2)
instantiate_scan_helper(max_int8_int8, int8_t, int8_t, CumMax, 4)
instantiate_scan_helper(max_int16_int16, int16_t, int16_t, CumMax, 4)
instantiate_scan_helper(max_int32_int32, int32_t, int32_t, CumMax, 4)
//instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
instantiate_scan_helper(max_float16_float16, half, half, CumMax, 4)
instantiate_scan_helper(max_float32_float32, float, float, CumMax, 4)
instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
//instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax)
//instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax, 2)
instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
instantiate_scan_helper(min_uint8_uint8, uint8_t, uint8_t, CumMin, 4)
instantiate_scan_helper(min_uint16_uint16, uint16_t, uint16_t, CumMin, 4)
instantiate_scan_helper(min_uint32_uint32, uint32_t, uint32_t, CumMin, 4)
//instantiate_scan_helper(min_uint64_uint64, uint64_t, uint64_t, CumMin, 2)
instantiate_scan_helper(min_uint64_uint64, uint64_t, uint64_t, CumMin, 2)
instantiate_scan_helper(min_int8_int8, int8_t, int8_t, CumMin, 4)
instantiate_scan_helper(min_int16_int16, int16_t, int16_t, CumMin, 4)
instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMin, 4)
//instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // clang-format on
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on
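All of the previously commented-out 64-bit and complex64 instantiations are switched on, each with nreads = 2 to match the JIT path. For concreteness, this is approximately what one helper line expands to for the int64 inclusive sum (reconstructed from the macros above, so treat it as illustrative rather than verbatim):

template [[host_name("contig_scan_inclusive_sum_int64_int64")]] [[kernel]] void
contiguous_scan<int64_t, int64_t, CumSum<int64_t>, 2, true, false>(
    const device int64_t* in [[buffer(0)]],
    device int64_t* out [[buffer(1)]],
    const constant size_t& axis_size [[buffer(2)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);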

@@ -320,3 +320,63 @@ inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) {
return complex64_t(
simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta));
}
inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
}
inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) {
return as_type<int64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
}
inline bool simd_shuffle_up(bool data, uint16_t delta) {
return simd_shuffle_up(static_cast<uint32_t>(data), delta);
}
inline complex64_t simd_shuffle_up(complex64_t data, uint16_t delta) {
return complex64_t(
simd_shuffle_up(data.real, delta), simd_shuffle_up(data.imag, delta));
}
inline uint64_t
simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) {
return as_type<uint64_t>(metal::simd_shuffle_and_fill_up(
as_type<uint2>(data), as_type<uint2>(filling), delta));
}
inline int64_t
simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) {
return as_type<int64_t>(metal::simd_shuffle_and_fill_up(
as_type<uint2>(data), as_type<uint2>(filling), delta));
}
inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) {
return simd_shuffle_and_fill_up(
static_cast<uint32_t>(data), static_cast<uint32_t>(filling), delta);
}
inline complex64_t simd_shuffle_and_fill_up(
complex64_t data,
complex64_t filling,
uint16_t delta) {
return complex64_t(
simd_shuffle_and_fill_up(data.real, filling.real, delta),
simd_shuffle_and_fill_up(data.imag, filling.imag, delta));
}
inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) {
return as_type<uint64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
}
inline int64_t simd_shuffle(int64_t data, uint16_t lane) {
return as_type<int64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
}
inline bool simd_shuffle(bool data, uint16_t lane) {
return simd_shuffle(static_cast<uint32_t>(data), lane);
}
inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
return complex64_t(
simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));
}
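These overloads exist because Metal's simd_shuffle* intrinsics do not take 64-bit integers: each one bit-casts the value to a uint2 (two 32-bit halves), shuffles that, and casts back; bool is widened to uint32_t, and complex64 shuffles its real and imaginary parts independently. A host-side sketch of the round-trip (memcpy standing in for Metal's as_type; the lane shuffle itself is elided):

#include <cassert>
#include <cstdint>
#include <cstring>

struct uint2 {
  uint32_t x, y;
};

int main() {
  int64_t v = -0x0123456789abcdefLL;
  uint2 halves;
  std::memcpy(&halves, &v, sizeof(v)); // as_type<uint2>(v)
  // ...both 32-bit halves would be shuffled across lanes here...
  int64_t back;
  std::memcpy(&back, &halves, sizeof(back)); // as_type<int64_t>(halves)
  assert(back == v); // lossless round-trip
  return 0;
}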

@@ -14,19 +14,27 @@ namespace mlx::core {
void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
// Ensure contiguity
std::vector<array> copies;
auto in = inputs[0];
if (!in.flags().row_contiguous) {
if (in.flags().contiguous && in.strides()[axis_] != 0) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
in = arr_copy;
out.move_shared_buffer(in);
}
bool contiguous = in.strides()[axis_] == 1;
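The host-side change drops the blanket copy-to-row-contiguous rule. Any contiguous input whose scan axis has a real stride is now scanned directly: in place when the input buffer can be donated and the itemsizes match, otherwise into an output allocated with the input's own strides and flags; only genuinely non-contiguous inputs are copied first, and that copy is then donated to the output. This is also why the encoders below bind out as the input when in's buffer has been donated (data_shared_ptr() == nullptr). A toy model of the decision tree (hypothetical types, not the MLX API):

#include <cstdio>

enum class Plan { InPlace, AllocSameLayout, CopyThenInPlace };

Plan choose_plan(
    bool contiguous,      // in.flags().contiguous
    bool axis_has_stride, // in.strides()[axis] != 0
    bool donatable,       // in.is_donatable()
    bool same_itemsize) { // in.itemsize() == out.itemsize()
  if (contiguous && axis_has_stride) {
    return (donatable && same_itemsize) ? Plan::InPlace
                                        : Plan::AllocSameLayout;
  }
  return Plan::CopyThenInPlace; // copy to a temporary, then scan it in place
}

int main() {
  // A donatable, contiguous input is scanned in place:
  std::printf("%d\n", static_cast<int>(choose_plan(true, true, true, true)));
  return 0;
}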
@@ -61,7 +69,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
if (contiguous) {
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1);
size_t size = in.shape(axis_);
compute_encoder->setBytes(&size, sizeof(size_t), 2);
@@ -70,7 +79,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
constexpr int simd_size = 32;
int elements_per_simd = n_reads * simd_size;
int thread_groups = in.size() / size;
int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (size <= n_reads * 1024) {
thread_group_size =
@@ -82,28 +90,41 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
thread_group_size = std::min(
thread_group_size,
static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));
MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
auto tmp_grid_dims =
get_2d_grid_dims(in.shape(), in.strides(), /** divisor= */ size);
MTL::Size grid_dims(
thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
MTL::Size group_dims(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1);
size_t size = in.shape(axis_);
size_t stride = in.strides()[axis_];
int bm = 32;
int bn = 32;
size_t stride_blocks = (stride + bn - 1) / bn;
compute_encoder->setBytes(&size, sizeof(size_t), 2);
compute_encoder->setBytes(&stride, sizeof(size_t), 3);
compute_encoder->setBytes(&stride_blocks, sizeof(size_t), 4);
// Compute the thread grid
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
int tile_x = 32;
int tile_y = 32;
int elements_per_tile_x = tile_x * n_reads;
int grid_y = in.size() / size / stride;
int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
int n_simdgroups = bn / n_reads;
int thread_group_size = n_simdgroups * 32;
auto tmp_grid_dims = get_2d_grid_dims(
in.shape(), in.strides(), /** divisor= */ size * stride);
if (tmp_grid_dims.width * stride_blocks <= UINT_MAX) {
tmp_grid_dims.width *= stride_blocks;
} else {
tmp_grid_dims.height *= stride_blocks;
}
MTL::Size grid_dims(
thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
MTL::Size group_dims(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
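Dispatch for both cases now builds the grid as (threadgroup size, w, h), where w and h come from the new divisor-aware get_2d_grid_dims (next file), so huge row counts spread across y and z. In the strided case the threadgroup is sized from the tile, (bn / n_reads) * 32 threads, and the stride_blocks factor is multiplied into whichever grid dimension keeps the product under UINT_MAX. A worked example of the arithmetic (hypothetical shape; host math only):

#include <cstdint>
#include <cstdio>

int main() {
  // Scanning axis 1 of a row-contiguous [6, 1000, 700] float32 array.
  uint64_t size = 1000;  // axis_size
  uint64_t stride = 700; // elements between consecutive scan steps
  int bn = 32, n_reads = 4;
  uint64_t stride_blocks = (stride + bn - 1) / bn; // 22 column blocks
  int thread_group_size = (bn / n_reads) * 32;     // 8 simdgroups = 256 threads
  uint64_t planes = 6;                             // in.size() / (size * stride)
  uint64_t grid_y = planes * stride_blocks;        // 132 threadgroups on y
  std::printf("%llu groups x %d threads\n",
              static_cast<unsigned long long>(grid_y), thread_group_size);
  return 0;
}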

@@ -107,6 +107,48 @@ MTL::Size get_2d_grid_dims(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
MTL::Size get_2d_grid_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t divisor) {
// Compute the 2d grid dimensions such that the total size of the grid is
// divided by divisor.
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
// No need to add this shape; we can just divide it out of the divisor.
if (divisor % shape[i] == 0) {
divisor /= shape[i];
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
if (divisor > 1) {
if (grid_x % divisor == 0) {
grid_x /= divisor;
divisor = 1;
} else if (grid_y % divisor == 0) {
grid_y /= divisor;
divisor = 1;
}
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
throw std::runtime_error("Unable to safely factor shape.");
}
return MTL::Size(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
std::string get_primitive_string(Primitive* primitive) {
std::ostringstream op_t;
primitive->print(op_t);
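The new get_2d_grid_dims overload factors the array's shape into a 2D grid while implicitly dividing out divisor (the number of elements each scan covers), skipping broadcast dimensions with stride 0. A standalone restatement of the algorithm, checked against a worked example matching the tests (a [32, 32, 32] contiguous array scanned over its last axis):

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Mirrors the diff's algorithm: fold dims that divide `divisor` into it,
// pack the remaining dims into (gx, gy), and divide out any leftover factor.
std::pair<uint64_t, uint64_t> grid_2d(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides,
    size_t divisor) {
  uint64_t gx = 1, gy = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (strides[i] == 0) {
      continue; // broadcast dims occupy no grid space
    }
    if (divisor % shape[i] == 0) {
      divisor /= shape[i]; // whole dim absorbed by the divisor
      continue;
    }
    if (gx * shape[i] < UINT32_MAX) {
      gx *= shape[i];
    } else {
      gy *= shape[i];
    }
    if (divisor > 1) {
      if (gx % divisor == 0) {
        gx /= divisor;
        divisor = 1;
      } else if (gy % divisor == 0) {
        gy /= divisor;
        divisor = 1;
      }
    }
  }
  assert(divisor == 1 && gx < UINT32_MAX && gy < UINT32_MAX);
  return {gx, gy};
}

int main() {
  // divisor = axis_size = 32: the first dim is folded into it, leaving
  // 32 * 32 = 1024 rows on the x axis, one grid cell per scanned row.
  auto [gx, gy] = grid_2d({32, 32, 32}, {1024, 32, 1}, 32);
  assert(gx == 1024 && gy == 1);
  return 0;
}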

@@ -42,6 +42,14 @@ MTL::Size get_2d_grid_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides);
// Same as above, but with an implicit division by divisor. Roughly
// equivalent to factorizing
// Prod(shape[i] for all i with strides[i] > 0) / divisor.
MTL::Size get_2d_grid_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t divisor);
inline NS::String* make_string(std::ostringstream& os) {
std::string string = os.str();
return NS::String::string(string.c_str(), NS::UTF8StringEncoding);

@@ -1758,6 +1758,18 @@ class TestOps(mlx_tests.MLXTestCase):
c_mlx = mxop(a_mlx, axis=axis)
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)
for dt in [mx.int32, mx.int64]:
mxx = a_mlx.astype(dt)
npx = np.array(mxx)
for op in ["cumsum", "cumprod"]:
npop = getattr(np, op)
mxop = getattr(mx, op)
for axis in (None, 0, 1, 2):
c_npy = npop(npx, axis=axis, dtype=npx.dtype)
c_mlx = mxop(mxx, axis=axis)
self.assertTrue(np.array_equal(c_npy, c_mlx))
a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)
for op in ["cumsum", "cumprod", "cummax", "cummin"]:
mxop = getattr(mx, op)