mirror of https://github.com/ml-explore/mlx.git
Working 64-bit scans (#1506)

parent 32972a5924
commit c9b41d460f
@@ -1,26 +0,0 @@
-// Copyright © 2024 Apple Inc.
-
-constexpr std::string_view scan_kernels = R"(
-template [[host_name("contig_{0}")]] [[kernel]] void
-contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
-    const device {1}* in [[buffer(0)]],
-    device {2}* out [[buffer(1)]],
-    const constant size_t& axis_size [[buffer(2)]],
-    uint gid [[thread_position_in_grid]],
-    uint lid [[thread_position_in_threadgroup]],
-    uint lsize [[threads_per_threadgroup]],
-    uint simd_size [[threads_per_simdgroup]],
-    uint simd_lane_id [[thread_index_in_simdgroup]],
-    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
-
-template [[host_name("strided_{0}")]] [[kernel]] void
-strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
-    const device {1}* in [[buffer(0)]],
-    device {2}* out [[buffer(1)]],
-    const constant size_t& axis_size [[buffer(2)]],
-    const constant size_t& stride [[buffer(3)]],
-    uint2 gid [[thread_position_in_grid]],
-    uint2 lid [[thread_position_in_threadgroup]],
-    uint2 lsize [[threads_per_threadgroup]],
-    uint simd_size [[threads_per_simdgroup]]);
-)";
@@ -4,7 +4,6 @@
 #include "mlx/backend/metal/jit/arange.h"
 #include "mlx/backend/metal/jit/gemv_masked.h"
 #include "mlx/backend/metal/jit/includes.h"
-#include "mlx/backend/metal/jit/scan.h"
 #include "mlx/backend/metal/jit/softmax.h"
 #include "mlx/backend/metal/jit/steel_conv.h"
 #include "mlx/backend/metal/jit/steel_gemm.h"
@@ -224,18 +223,26 @@ MTL::ComputePipelineState* get_scan_kernel(
     const array& out) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name, [&]() {
-    std::string op_name = "Cum" + reduce_type;
-    op_name[3] = toupper(op_name[3]);
+    auto out_type = get_type_string(out.dtype());
+    std::string op = "Cum" + reduce_type + "<" + out_type + ">";
+    op[3] = toupper(op[3]);
     std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::scan()
-                  << fmt::format(
-                         scan_kernels,
-                         lib_name,
-                         get_type_string(in.dtype()),
-                         get_type_string(out.dtype()),
-                         op_name,
-                         inclusive,
-                         reverse);
+    kernel_source << metal::utils() << metal::scan();
+    const std::array<std::pair<std::string, std::string>, 2> scan_kernels = {{
+        {"contig_", "contiguous_scan"},
+        {"strided_", "strided_scan"},
+    }};
+    for (auto& [prefix, kernel] : scan_kernels) {
+      kernel_source << get_template_definition(
+          prefix + lib_name,
+          kernel,
+          get_type_string(in.dtype()),
+          get_type_string(out.dtype()),
+          op,
+          in.itemsize() <= 4 ? 4 : 2,
+          inclusive,
+          reverse);
+    }
     return kernel_source.str();
   });
   return d.get_kernel(kernel_name, lib);
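With the raw scan_kernels format string gone, the JIT library builder above stitches the two kernel instantiations together with get_template_definition, choosing 4 reads per thread for element types of at most 4 bytes and 2 for 8-byte types. As a rough illustration of the technique, the helper below is a hypothetical stand-in, not MLX's actual get_template_definition:

#include <sstream>
#include <string>

// Illustrative sketch only: a variadic helper that joins template
// arguments into an explicit [[host_name(...)]] instantiation, similar
// in spirit to the calls above (the real helper may differ in name
// and signature).
template <typename... Args>
std::string make_template_definition(
    const std::string& host_name,
    const std::string& func,
    const Args&... args) {
  std::ostringstream s;
  std::string sep;
  ((s << sep << args, sep = ", "), ...); // comma-join all template arguments
  std::string sig = func + "<" + s.str() + ">";
  return "template [[host_name(\"" + host_name + "\")]] [[kernel]] decltype(" +
      sig + ") " + sig + ";\n";
}

Called as make_template_definition("contig_scan_sum_int64_int64", "contiguous_scan", "int64_t", "int64_t", "CumSum<int64_t>", 2, true, false), it would emit one explicit instantiation line much like the entries the deleted header spelled out by hand.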
@@ -1,7 +1,38 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #pragma once
 
+#define DEFINE_SIMD_SCAN()                                                \
+  template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true>   \
+  T simd_scan(T val) {                                                    \
+    return simd_scan_impl(val);                                           \
+  }                                                                       \
+                                                                          \
+  template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true>  \
+  T simd_scan(T val) {                                                    \
+    for (int i = 1; i <= 16; i *= 2) {                                    \
+      val = operator()(val, simd_shuffle_and_fill_up(val, init, i));      \
+    }                                                                     \
+    return val;                                                           \
+  }
+
+#define DEFINE_SIMD_EXCLUSIVE_SCAN()                                      \
+  template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true>   \
+  T simd_exclusive_scan(T val) {                                          \
+    return simd_exclusive_scan_impl(val);                                 \
+  }                                                                       \
+                                                                          \
+  template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true>  \
+  T simd_exclusive_scan(T val) {                                          \
+    val = simd_scan(val);                                                 \
+    return simd_shuffle_and_fill_up(val, init, 1);                        \
+  }
+
 template <typename U>
 struct CumSum {
+  DEFINE_SIMD_SCAN()
+  DEFINE_SIMD_EXCLUSIVE_SCAN()
+
   static constexpr constant U init = static_cast<U>(0);
 
   template <typename T>
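The sizeof(T) == 8 branch above cannot use Metal's built-in simd_prefix_* intrinsics, which do not cover 64-bit types, so it builds the inclusive scan out of shuffles: a log-step (Hillis-Steele style) scan where each round combines a lane with the value a power-of-two number of lanes below it, and lanes that would read out of range are filled with the operator's identity. A CPU model of those five rounds for a 32-lane sum:

#include <array>
#include <cstdint>
#include <cstdio>

// CPU sketch of the log-step inclusive scan performed by the 8-byte
// branch of DEFINE_SIMD_SCAN. "Shuffle up and fill" pulls the value
// from `delta` lanes below, filling the bottom lanes with the op's
// identity (0 for a sum).
int main() {
  constexpr int kSimdSize = 32;
  std::array<int64_t, kSimdSize> lanes;
  for (int i = 0; i < kSimdSize; ++i) lanes[i] = 1; // all-ones input

  for (int delta = 1; delta <= 16; delta *= 2) {
    std::array<int64_t, kSimdSize> shifted;
    for (int i = 0; i < kSimdSize; ++i) {
      shifted[i] = (i >= delta) ? lanes[i - delta] : 0; // fill with init
    }
    for (int i = 0; i < kSimdSize; ++i) lanes[i] += shifted[i];
  }

  for (int i = 0; i < kSimdSize; ++i) printf("%lld ", (long long)lanes[i]);
  printf("\n"); // prints 1 2 3 ... 32: the inclusive prefix sum
  return 0;
}

After log2(32) = 5 doubling steps, every lane holds the reduction of all lanes at or below it, which is exactly the inclusive scan.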
@@ -9,17 +40,20 @@ struct CumSum {
     return a + b;
   }
 
-  U simd_scan(U x) {
+  U simd_scan_impl(U x) {
     return simd_prefix_inclusive_sum(x);
   }
 
-  U simd_exclusive_scan(U x) {
+  U simd_exclusive_scan_impl(U x) {
     return simd_prefix_exclusive_sum(x);
   }
 };
 
 template <typename U>
 struct CumProd {
+  DEFINE_SIMD_SCAN()
+  DEFINE_SIMD_EXCLUSIVE_SCAN()
+
   static constexpr constant U init = static_cast<U>(1.0f);
 
   template <typename T>
@@ -27,11 +61,11 @@ struct CumProd {
     return a * b;
   }
 
-  U simd_scan(U x) {
+  U simd_scan_impl(U x) {
     return simd_prefix_inclusive_product(x);
   }
 
-  U simd_exclusive_scan(U x) {
+  U simd_exclusive_scan_impl(U x) {
     return simd_prefix_exclusive_product(x);
   }
 };
@@ -47,7 +81,7 @@ struct CumProd<bool> {
 
   bool simd_scan(bool x) {
     for (int i = 1; i <= 16; i *= 2) {
-      bool other = simd_shuffle_up(x, i);
+      bool other = simd_shuffle_and_fill_up(x, init, i);
       x &= other;
     }
     return x;
@@ -70,7 +104,7 @@ struct CumMax {
 
   U simd_scan(U x) {
     for (int i = 1; i <= 16; i *= 2) {
-      U other = simd_shuffle_up(x, i);
+      U other = simd_shuffle_and_fill_up(x, init, i);
       x = (x >= other) ? x : other;
     }
     return x;
@@ -93,7 +127,7 @@ struct CumMin {
 
   U simd_scan(U x) {
     for (int i = 1; i <= 16; i *= 2) {
-      U other = simd_shuffle_up(x, i);
+      U other = simd_shuffle_and_fill_up(x, init, i);
       x = (x <= other) ? x : other;
    }
     return x;
@@ -178,20 +212,22 @@ template <
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
     const constant size_t& axis_size [[buffer(2)]],
-    uint gid [[thread_position_in_grid]],
-    uint lid [[thread_position_in_threadgroup]],
-    uint lsize [[threads_per_threadgroup]],
-    uint simd_size [[threads_per_simdgroup]],
+    uint3 gid [[threadgroup_position_in_grid]],
+    uint3 gsize [[threadgroups_per_grid]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 lsize [[threads_per_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  constexpr int simd_size = 32;
   Op op;
 
   // Position the pointers
-  in += (gid / lsize) * axis_size;
-  out += (gid / lsize) * axis_size;
+  size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size;
+  in += offset;
+  out += offset;
 
   // Compute the number of simd_groups
-  uint simd_groups = lsize / simd_size;
+  uint simd_groups = lsize.x / simd_size;
 
   // Allocate memory
   U prefix = Op::init;
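The signature change above is the other half of the 64-bit story: the flat thread_position_in_grid index, which overflows once an array exceeds 2^32 elements, is replaced by a 2-D threadgroup grid whose row index is rebuilt and widened to size_t before it is multiplied by axis_size. A small CPU model with made-up launch values:

#include <cstdint>
#include <cstdio>

// CPU model of the new row-offset computation. The launch values are
// hypothetical; the point is that the multiply happens in 64 bits.
int main() {
  uint32_t gid_y = 50000, gid_z = 2, gsize_y = 65535; // illustrative only
  size_t axis_size = 1 << 20;
  size_t row = gid_y + gsize_y * size_t(gid_z); // widen before multiplying
  size_t offset = row * axis_size;              // ~1.9e11: needs 64 bits
  printf("row=%zu offset=%zu\n", row, offset);
  return 0;
}

Here the offset is roughly 1.9e11, far past the 32-bit limit of about 4.3e9, so computing it in uint as the old code did would silently wrap.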
@@ -210,9 +246,9 @@ template <
   // value
   // Write block
 
-  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
+  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) {
     // Compute the block offset
-    uint offset = r * lsize * N_READS + lid * N_READS;
+    uint offset = r * lsize.x * N_READS + lid.x * N_READS;
 
     // Read the values
     if (reverse) {
@@ -275,7 +311,7 @@
           values, out + axis_size - offset - N_READS, offset, axis_size);
     }
   } else {
-    if (lid == 0 && offset == 0) {
+    if (lid.x == 0 && offset == 0) {
       out[axis_size - 1] = Op::init;
     }
     if ((offset + N_READS + 1) < axis_size) {
@@ -298,7 +334,7 @@
           values, out + offset, offset, axis_size);
     }
   } else {
-    if (lid == 0 && offset == 0) {
+    if (lid.x == 0 && offset == 0) {
       out[0] = Op::init;
     }
     if ((offset + N_READS + 1) < axis_size) {
@@ -332,86 +368,98 @@ template <
     device U* out [[buffer(1)]],
     const constant size_t& axis_size [[buffer(2)]],
     const constant size_t& stride [[buffer(3)]],
-    uint2 gid [[threadgroup_position_in_grid]],
-    uint2 lid [[thread_position_in_threadgroup]],
-    uint2 lsize [[threads_per_threadgroup]],
-    uint simd_size [[threads_per_simdgroup]]) {
+    const constant size_t& stride_blocks [[buffer(4)]],
+    uint3 gid [[threadgroup_position_in_grid]],
+    uint3 gsize [[threadgroups_per_grid]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint simd_lane_id [[thread_index_in_simdgroup]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  constexpr int simd_size = 32;
+  constexpr int BM = 32;
+  constexpr int BN = 32;
+  constexpr int BN_pad = 32 + 16 / sizeof(U);
+  constexpr int n_simds = BN / N_READS;
+  constexpr int n_scans = BN / n_simds;
   Op op;
 
   // Allocate memory
-  threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
-  U values[N_READS];
-  U prefix[N_READS];
-  for (int i = 0; i < N_READS; i++) {
+  threadgroup U read_buffer[BM * BN_pad];
+  U values[n_scans];
+  U prefix[n_scans];
+  for (int i = 0; i < n_scans; i++) {
     prefix[i] = Op::init;
   }
 
   // Compute offsets
-  int offset = gid.y * axis_size * stride;
-  int global_index_x = gid.x * lsize.y * N_READS;
+  size_t full_gid = gid.y + gsize.y * size_t(gid.z);
+  size_t offset = full_gid / stride_blocks * axis_size * stride;
+  size_t global_index_x = full_gid % stride_blocks * BN;
+  uint read_offset_y = (lid.x * N_READS) / BN;
+  uint read_offset_x = (lid.x * N_READS) % BN;
+  uint scan_offset_y = simd_lane_id;
+  uint scan_offset_x = simd_group_id * n_scans;
 
-  for (uint j = 0; j < axis_size; j += simd_size) {
+  uint stride_limit = stride - global_index_x;
+  in += offset + global_index_x + read_offset_x;
+  out += offset + global_index_x + read_offset_x;
+  threadgroup U* read_into =
+      read_buffer + read_offset_y * BN_pad + read_offset_x;
+  threadgroup U* read_from =
+      read_buffer + scan_offset_y * BN_pad + scan_offset_x;
+
+  for (uint j = 0; j < axis_size; j += BM) {
     // Calculate the indices for the current thread
-    uint index_y = j + lid.y;
+    uint index_y = j + read_offset_y;
     uint check_index_y = index_y;
-    uint index_x = global_index_x + lid.x * N_READS;
     if (reverse) {
       index_y = axis_size - 1 - index_y;
     }
 
     // Read in SM
-    if (check_index_y < axis_size && (index_x + N_READS) < stride) {
+    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
       for (int i = 0; i < N_READS; i++) {
-        read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
-            in[offset + index_y * stride + index_x + i];
+        read_into[i] = in[index_y * stride + i];
      }
     } else {
       for (int i = 0; i < N_READS; i++) {
-        if (check_index_y < axis_size && (index_x + i) < stride) {
-          read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
-              in[offset + index_y * stride + index_x + i];
+        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
+          read_into[i] = in[index_y * stride + i];
         } else {
-          read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
-              Op::init;
+          read_into[i] = Op::init;
         }
       }
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     // Read strided into registers
-    for (int i = 0; i < N_READS; i++) {
-      values[i] =
-          read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
+    for (int i = 0; i < n_scans; i++) {
+      values[i] = read_from[i];
     }
     // Do we need the following barrier? Shouldn't all simd threads execute
     // simultaneously?
     simdgroup_barrier(mem_flags::mem_threadgroup);
 
     // Perform the scan
-    for (int i = 0; i < N_READS; i++) {
+    for (int i = 0; i < n_scans; i++) {
       values[i] = op.simd_scan(values[i]);
       values[i] = op(values[i], prefix[i]);
       prefix[i] = simd_shuffle(values[i], simd_size - 1);
     }
 
     // Write to SM
-    for (int i = 0; i < N_READS; i++) {
-      read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] =
-          values[i];
+    for (int i = 0; i < n_scans; i++) {
+      read_from[i] = values[i];
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     // Write to device memory
     if (!inclusive) {
       if (check_index_y == 0) {
-        if ((index_x + N_READS) < stride) {
+        if ((read_offset_x + N_READS) < stride_limit) {
           for (int i = 0; i < N_READS; i++) {
-            out[offset + index_y * stride + index_x + i] = Op::init;
+            out[index_y * stride + i] = Op::init;
           }
         } else {
           for (int i = 0; i < N_READS; i++) {
-            if ((index_x + i) < stride) {
-              out[offset + index_y * stride + index_x + i] = Op::init;
+            if ((read_offset_x + i) < stride_limit) {
+              out[index_y * stride + i] = Op::init;
             }
           }
         }
@@ -424,16 +472,14 @@ template <
         check_index_y += 1;
       }
     }
-    if (check_index_y < axis_size && (index_x + N_READS) < stride) {
+    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
       for (int i = 0; i < N_READS; i++) {
-        out[offset + index_y * stride + index_x + i] =
-            read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
+        out[index_y * stride + i] = read_into[i];
       }
     } else {
       for (int i = 0; i < N_READS; i++) {
-        if (check_index_y < axis_size && (index_x + i) < stride) {
-          out[offset + index_y * stride + index_x + i] =
-              read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
+        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
+          out[index_y * stride + i] = read_into[i];
         }
       }
     }
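The reworked strided kernel processes the scan in BM x BN (32 x 32) tiles: threads fill a padded threadgroup buffer row by row (BN_pad leaves a gap so threadgroup-memory accesses spread across banks), and each simdgroup then scans n_scans adjacent columns, so simd_scan runs along the axis direction. A much-simplified CPU model of the two addressing patterns, with tiny sizes so the pattern is visible:

#include <cstdio>

// CPU sketch of the tile addressing: threads fill a BM x BN tile row by
// row, then each "simdgroup" walks columns, which is the scan direction.
int main() {
  constexpr int BM = 4, BN = 4, N_READS = 2, BN_pad = BN + 1; // padded row
  int tile[BM * BN_pad] = {0};
  // Fill phase: thread t writes N_READS consecutive elements of a row.
  for (int t = 0; t < BM * BN / N_READS; ++t) {
    int y = (t * N_READS) / BN, x = (t * N_READS) % BN;
    for (int i = 0; i < N_READS; ++i)
      tile[y * BN_pad + x + i] = y * BN + x + i;
  }
  // Scan phase: simdgroup g owns n_scans = N_READS adjacent columns;
  // its lane l supplies row l of each of those columns.
  for (int g = 0; g < BN / N_READS; ++g)
    for (int l = 0; l < BM; ++l)
      for (int i = 0; i < N_READS; ++i)
        printf("group %d lane %d col %d -> %d\n",
               g, l, g * N_READS + i, tile[l * BN_pad + g * N_READS + i]);
  return 0;
}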
@@ -13,15 +13,15 @@ using namespace metal;
 
 #define instantiate_contiguous_scan(                                      \
     name, itype, otype, op, inclusive, reverse, nreads)                   \
   template [[host_name("contig_scan_" #name)]] [[kernel]] void           \
   contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>(   \
       const device itype* in [[buffer(0)]],                               \
       device otype* out [[buffer(1)]],                                    \
       const constant size_t& axis_size [[buffer(2)]],                     \
-      uint gid [[thread_position_in_grid]],                               \
-      uint lid [[thread_position_in_threadgroup]],                        \
-      uint lsize [[threads_per_threadgroup]],                             \
-      uint simd_size [[threads_per_simdgroup]],                           \
+      uint3 gid [[threadgroup_position_in_grid]],                         \
+      uint3 gsize [[threadgroups_per_grid]],                              \
+      uint3 lid [[thread_position_in_threadgroup]],                       \
+      uint3 lsize [[threads_per_threadgroup]],                            \
       uint simd_lane_id [[thread_index_in_simdgroup]],                    \
       uint simd_group_id [[simdgroup_index_in_threadgroup]]);
@@ -33,10 +33,12 @@ using namespace metal;
       device otype* out [[buffer(1)]],                                    \
       const constant size_t& axis_size [[buffer(2)]],                     \
       const constant size_t& stride [[buffer(3)]],                        \
-      uint2 gid [[thread_position_in_grid]],                              \
-      uint2 lid [[thread_position_in_threadgroup]],                       \
-      uint2 lsize [[threads_per_threadgroup]],                            \
-      uint simd_size [[threads_per_simdgroup]]);
+      const constant size_t& stride_blocks [[buffer(4)]],                 \
+      uint3 gid [[threadgroup_position_in_grid]],                         \
+      uint3 gsize [[threadgroups_per_grid]],                              \
+      uint3 lid [[thread_position_in_threadgroup]],                       \
+      uint simd_lane_id [[thread_index_in_simdgroup]],                    \
+      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
 
 #define instantiate_scan_helper(name, itype, otype, op, nreads)           \
   instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
@@ -52,51 +54,51 @@ instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
 instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
 instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
 instantiate_scan_helper(sum_uint32_uint32, uint32_t, uint32_t, CumSum, 4)
-//instantiate_scan_helper(sum_uint64_uint64, uint64_t, uint64_t, CumSum, 2)
+instantiate_scan_helper(sum_uint64_uint64, uint64_t, uint64_t, CumSum, 2)
 instantiate_scan_helper(sum_int8_int8, int8_t, int8_t, CumSum, 4)
 instantiate_scan_helper(sum_int16_int16, int16_t, int16_t, CumSum, 4)
 instantiate_scan_helper(sum_int32_int32, int32_t, int32_t, CumSum, 4)
-//instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
+instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
 instantiate_scan_helper(sum_float16_float16, half, half, CumSum, 4)
 instantiate_scan_helper(sum_float32_float32, float, float, CumSum, 4)
 instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
-//instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum)
-//instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
+instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum, 2)
+instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
 instantiate_scan_helper(prod_uint8_uint8, uint8_t, uint8_t, CumProd, 4)
 instantiate_scan_helper(prod_uint16_uint16, uint16_t, uint16_t, CumProd, 4)
 instantiate_scan_helper(prod_uint32_uint32, uint32_t, uint32_t, CumProd, 4)
-//instantiate_scan_helper(prod_uint64_uint64, uint64_t, uint64_t, CumProd, 2)
+instantiate_scan_helper(prod_uint64_uint64, uint64_t, uint64_t, CumProd, 2)
 instantiate_scan_helper(prod_int8_int8, int8_t, int8_t, CumProd, 4)
 instantiate_scan_helper(prod_int16_int16, int16_t, int16_t, CumProd, 4)
 instantiate_scan_helper(prod_int32_int32, int32_t, int32_t, CumProd, 4)
-//instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
+instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
 instantiate_scan_helper(prod_float16_float16, half, half, CumProd, 4)
 instantiate_scan_helper(prod_float32_float32, float, float, CumProd, 4)
 instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
-//instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd)
-//instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
+instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd, 2)
+instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
 instantiate_scan_helper(max_uint8_uint8, uint8_t, uint8_t, CumMax, 4)
 instantiate_scan_helper(max_uint16_uint16, uint16_t, uint16_t, CumMax, 4)
 instantiate_scan_helper(max_uint32_uint32, uint32_t, uint32_t, CumMax, 4)
-//instantiate_scan_helper(max_uint64_uint64, uint64_t, uint64_t, CumMax, 2)
+instantiate_scan_helper(max_uint64_uint64, uint64_t, uint64_t, CumMax, 2)
 instantiate_scan_helper(max_int8_int8, int8_t, int8_t, CumMax, 4)
 instantiate_scan_helper(max_int16_int16, int16_t, int16_t, CumMax, 4)
 instantiate_scan_helper(max_int32_int32, int32_t, int32_t, CumMax, 4)
-//instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
+instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
 instantiate_scan_helper(max_float16_float16, half, half, CumMax, 4)
 instantiate_scan_helper(max_float32_float32, float, float, CumMax, 4)
 instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
-//instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax)
-//instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
+instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax, 2)
+instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
 instantiate_scan_helper(min_uint8_uint8, uint8_t, uint8_t, CumMin, 4)
 instantiate_scan_helper(min_uint16_uint16, uint16_t, uint16_t, CumMin, 4)
 instantiate_scan_helper(min_uint32_uint32, uint32_t, uint32_t, CumMin, 4)
-//instantiate_scan_helper(min_uint64_uint64, uint64_t, uint64_t, CumMin, 2)
+instantiate_scan_helper(min_uint64_uint64, uint64_t, uint64_t, CumMin, 2)
 instantiate_scan_helper(min_int8_int8, int8_t, int8_t, CumMin, 4)
 instantiate_scan_helper(min_int16_int16, int16_t, int16_t, CumMin, 4)
 instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMin, 4)
-//instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
+instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
 instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
 instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
 instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
-//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // clang-format on
+instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on
|
@ -320,3 +320,63 @@ inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) {
|
||||
return complex64_t(
|
||||
simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta));
|
||||
}
|
||||
|
||||
inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) {
|
||||
return as_type<uint64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
|
||||
}
|
||||
|
||||
inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) {
|
||||
return as_type<int64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
|
||||
}
|
||||
|
||||
inline bool simd_shuffle_up(bool data, uint16_t delta) {
|
||||
return simd_shuffle_up(static_cast<uint32_t>(data), delta);
|
||||
}
|
||||
|
||||
inline complex64_t simd_shuffle_up(complex64_t data, uint16_t delta) {
|
||||
return complex64_t(
|
||||
simd_shuffle_up(data.real, delta), simd_shuffle_up(data.imag, delta));
|
||||
}
|
||||
|
||||
inline uint64_t
|
||||
simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) {
|
||||
return as_type<uint64_t>(metal::simd_shuffle_and_fill_up(
|
||||
as_type<uint2>(data), as_type<uint2>(filling), delta));
|
||||
}
|
||||
|
||||
inline int64_t
|
||||
simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) {
|
||||
return as_type<int64_t>(metal::simd_shuffle_and_fill_up(
|
||||
as_type<uint2>(data), as_type<uint2>(filling), delta));
|
||||
}
|
||||
|
||||
inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) {
|
||||
return simd_shuffle_and_fill_up(
|
||||
static_cast<uint32_t>(data), static_cast<uint32_t>(filling), delta);
|
||||
}
|
||||
|
||||
inline complex64_t simd_shuffle_and_fill_up(
|
||||
complex64_t data,
|
||||
complex64_t filling,
|
||||
uint16_t delta) {
|
||||
return complex64_t(
|
||||
simd_shuffle_and_fill_up(data.real, filling.real, delta),
|
||||
simd_shuffle_and_fill_up(data.imag, filling.imag, delta));
|
||||
}
|
||||
|
||||
inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) {
|
||||
return as_type<uint64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
|
||||
}
|
||||
|
||||
inline int64_t simd_shuffle(int64_t data, uint16_t lane) {
|
||||
return as_type<int64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
|
||||
}
|
||||
|
||||
inline bool simd_shuffle(bool data, uint16_t lane) {
|
||||
return simd_shuffle(static_cast<uint32_t>(data), lane);
|
||||
}
|
||||
|
||||
inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
|
||||
return complex64_t(
|
||||
simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));
|
||||
}
|
||||
|
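The Metal simd shuffles do not operate on 64-bit or complex types directly, so the overloads above reinterpret each 64-bit value as a uint2 (and shuffle real and imaginary parts of a complex separately), moving both halves through the hardware's 32-bit shuffle along the same lane path. The bit plumbing, modeled on the CPU with memcpy standing in for as_type:

#include <cstdint>
#include <cstdio>
#include <cstring>

// CPU sketch of the as_type<uint2> trick: split a 64-bit value into two
// 32-bit words, move each word with the (here identity) shuffle, and
// reassemble. Both words follow the same lane, so the value survives.
struct uint2 { uint32_t x, y; };

static uint2 shuffle32x2(uint2 v) { return v; } // stand-in for the SIMD move

int main() {
  uint64_t data = 0x0123456789abcdefULL;
  uint2 halves;
  std::memcpy(&halves, &data, sizeof(halves)); // as_type<uint2>(data)
  halves = shuffle32x2(halves);                // both words, same lane path
  uint64_t back;
  std::memcpy(&back, &halves, sizeof(back));   // as_type<uint64_t>(...)
  printf("%llx\n", (unsigned long long)back);  // round-trips exactly
  return 0;
}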
@@ -14,19 +14,27 @@ namespace mlx::core {
 void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
   auto& s = stream();
   auto& d = metal::device(s.device);
 
   // Ensure contiguity
   std::vector<array> copies;
   auto in = inputs[0];
-  if (!in.flags().row_contiguous) {
+  if (in.flags().contiguous && in.strides()[axis_] != 0) {
+    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+      out.move_shared_buffer(in);
+    } else {
+      out.set_data(
+          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+          in.data_size(),
+          in.strides(),
+          in.flags());
+    }
+  } else {
     array arr_copy(in.shape(), in.dtype(), nullptr, {});
     copy_gpu(in, arr_copy, CopyType::General, s);
     copies.push_back(arr_copy);
     in = arr_copy;
+    out.move_shared_buffer(in);
   }
 
   bool contiguous = in.strides()[axis_] == 1;
@@ -61,7 +69,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (contiguous) {
     auto& compute_encoder = d.get_command_encoder(s.index);
     compute_encoder->setComputePipelineState(kernel);
-    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_input_array(
+        in.data_shared_ptr() == nullptr ? out : in, 0);
     compute_encoder.set_output_array(out, 1);
     size_t size = in.shape(axis_);
     compute_encoder->setBytes(&size, sizeof(size_t), 2);
@@ -70,7 +79,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     int n_reads = (in.itemsize() <= 4) ? 4 : 2;
     constexpr int simd_size = 32;
     int elements_per_simd = n_reads * simd_size;
-    int thread_groups = in.size() / size;
     int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (size <= n_reads * 1024) {
       thread_group_size =
@@ -82,28 +90,41 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     thread_group_size = std::min(
         thread_group_size,
         static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));
-    MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
-    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
+    auto tmp_grid_dims =
+        get_2d_grid_dims(in.shape(), in.strides(), /** divisor= */ size);
+    MTL::Size grid_dims(
+        thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
+    MTL::Size group_dims(thread_group_size, 1, 1);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     auto& compute_encoder = d.get_command_encoder(s.index);
     compute_encoder->setComputePipelineState(kernel);
-    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_input_array(
+        in.data_shared_ptr() == nullptr ? out : in, 0);
     compute_encoder.set_output_array(out, 1);
     size_t size = in.shape(axis_);
     size_t stride = in.strides()[axis_];
+    int bm = 32;
+    int bn = 32;
+    size_t stride_blocks = (stride + bn - 1) / bn;
     compute_encoder->setBytes(&size, sizeof(size_t), 2);
     compute_encoder->setBytes(&stride, sizeof(size_t), 3);
+    compute_encoder->setBytes(&stride_blocks, sizeof(size_t), 4);
 
     // Compute the thread grid
     int n_reads = (in.itemsize() <= 4) ? 4 : 2;
-    int tile_x = 32;
-    int tile_y = 32;
-    int elements_per_tile_x = tile_x * n_reads;
-    int grid_y = in.size() / size / stride;
-    int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
-    MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
-    MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
+    int n_simdgroups = bn / n_reads;
+    int thread_group_size = n_simdgroups * 32;
+    auto tmp_grid_dims = get_2d_grid_dims(
+        in.shape(), in.strides(), /** divisor= */ size * stride);
+    if (tmp_grid_dims.width * stride_blocks <= UINT_MAX) {
+      tmp_grid_dims.width *= stride_blocks;
+    } else {
+      tmp_grid_dims.height *= stride_blocks;
+    }
+    MTL::Size grid_dims(
+        thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
+    MTL::Size group_dims(thread_group_size, 1, 1);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 
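Worked numbers for the strided dispatch arithmetic above, assuming a hypothetical 8-byte scan with stride 100 (so n_reads = 2): the stride splits into ceil(100 / 32) = 4 blocks of 32 columns, and 16 simdgroups of 32 threads cooperate per threadgroup:

#include <cstdio>

// Worked example of the strided dispatch math, with made-up sizes.
int main() {
  int n_reads = 2;  // 8-byte element type
  int bn = 32;
  size_t stride = 100;
  size_t stride_blocks = (stride + bn - 1) / bn; // 4 blocks of 32 columns
  int n_simdgroups = bn / n_reads;               // 16 simdgroups
  int thread_group_size = n_simdgroups * 32;     // 512 threads per group
  printf("stride_blocks=%zu threads=%d\n", stride_blocks, thread_group_size);
  return 0;
}

The stride_blocks factor is then folded into whichever grid dimension stays below UINT_MAX, which is what keeps very large scans dispatchable.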
@@ -107,6 +107,48 @@ MTL::Size get_2d_grid_dims(
       static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
 }
 
+MTL::Size get_2d_grid_dims(
+    const std::vector<int>& shape,
+    const std::vector<size_t>& strides,
+    size_t divisor) {
+  // Compute the 2d grid dimensions such that the total size of the grid is
+  // divided by divisor.
+  size_t grid_x = 1;
+  size_t grid_y = 1;
+  for (int i = 0; i < shape.size(); ++i) {
+    if (strides[i] == 0) {
+      continue;
+    }
+
+    // No need to add this shape, we can just remove it from the divisor.
+    if (divisor % shape[i] == 0) {
+      divisor /= shape[i];
+      continue;
+    }
+
+    if (grid_x * shape[i] < UINT32_MAX) {
+      grid_x *= shape[i];
+    } else {
+      grid_y *= shape[i];
+    }
+
+    if (divisor > 1) {
+      if (grid_x % divisor == 0) {
+        grid_x /= divisor;
+        divisor = 1;
+      } else if (grid_y % divisor == 0) {
+        grid_y /= divisor;
+        divisor = 1;
+      }
+    }
+  }
+  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
+    throw std::runtime_error("Unable to safely factor shape.");
+  }
+  return MTL::Size(
+      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
+}
+
 std::string get_primitive_string(Primitive* primitive) {
   std::ostringstream op_t;
   primitive->print(op_t);
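A CPU walk-through of the divisor logic above, using shape {8, 16, 32} with divisor 32 (as for a scan along the last axis). The real function also skips broadcast dimensions with stride 0 and can spill into grid_y near the 32-bit limit; this sketch keeps only the cancellation steps:

#include <cstdio>
#include <vector>

// The divisor is cancelled against the dimensions, so the grid covers
// exactly prod(shape) / divisor = 8 * 16 * 32 / 32 = 128 entries.
int main() {
  std::vector<int> shape = {8, 16, 32};
  size_t divisor = 32, grid_x = 1, grid_y = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (divisor % shape[i] == 0) { // drop the axis instead of adding it
      divisor /= shape[i];
      continue;
    }
    grid_x *= shape[i];
    if (divisor > 1 && grid_x % divisor == 0) { // late cancellation
      grid_x /= divisor;
      divisor = 1;
    }
  }
  printf("grid = %zu x %zu, divisor left = %zu\n", grid_x, grid_y, divisor);
  // prints: grid = 128 x 1, divisor left = 1
  return 0;
}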
@@ -42,6 +42,14 @@ MTL::Size get_2d_grid_dims(
     const std::vector<int>& shape,
     const std::vector<size_t>& strides);
 
+// Same as above, but we do an implicit division with divisor.
+// Basically, equivalent to factorizing
+// Prod(s \forall s in shape if strides[s] > 0) / divisor.
+MTL::Size get_2d_grid_dims(
+    const std::vector<int>& shape,
+    const std::vector<size_t>& strides,
+    size_t divisor);
+
 inline NS::String* make_string(std::ostringstream& os) {
   std::string string = os.str();
   return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
@@ -1758,6 +1758,18 @@ class TestOps(mlx_tests.MLXTestCase):
                 c_mlx = mxop(a_mlx, axis=axis)
                 self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
 
+        a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)
+        for dt in [mx.int32, mx.int64]:
+            mxx = a_mlx.astype(dt)
+            npx = np.array(mxx)
+            for op in ["cumsum", "cumprod"]:
+                npop = getattr(np, op)
+                mxop = getattr(mx, op)
+                for axis in (None, 0, 1, 2):
+                    c_npy = npop(npx, axis=axis, dtype=npx.dtype)
+                    c_mlx = mxop(mxx, axis=axis)
+                    self.assertTrue(np.array_equal(c_npy, c_mlx))
+
         a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)
         for op in ["cumsum", "cumprod", "cummax", "cummin"]:
             mxop = getattr(mx, op)