Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-24 17:31:16 +08:00
More jitting (#1132)

* docs + circle min size build
* jit scan, arange, softmax
* add sort
* jit reductions
* remove print
* fix deps
* clean includes / nits
This commit is contained in: parent 9401507336, commit 0189ab6ab6
@@ -200,7 +200,7 @@ GGUF, you can do:
     -DBUILD_SHARED_LIBS=ON \
     -DMLX_BUILD_CPU=OFF \
     -DMLX_BUILD_SAFETENSORS=OFF \
-    -DMLX_BUILD_GGUF=OFF
+    -DMLX_BUILD_GGUF=OFF \
+    -DMLX_METAL_JIT=ON

+The `MLX_METAL_JIT` flag minimizes the size of the MLX Metal library which
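For reference, the complete configure invocation those flags belong to would look roughly like this; a sketch assuming the usual out-of-source CMake build from the MLX install docs, with only the flags above taken from the diff:

    cmake .. \
      -DBUILD_SHARED_LIBS=ON \
      -DMLX_BUILD_CPU=OFF \
      -DMLX_BUILD_SAFETENSORS=OFF \
      -DMLX_BUILD_GGUF=OFF \
      -DMLX_METAL_JIT=ON
    cmake --build . -j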
mlx/backend/metal/CMakeLists.txt
@@ -35,6 +35,7 @@ make_jit_source(
   utils
   kernels/bf16.h
   kernels/complex.h
   kernels/defines.h
 )
 make_jit_source(
   unary_ops
@@ -44,7 +45,7 @@ make_jit_source(
 make_jit_source(binary_ops)
 make_jit_source(ternary_ops)
 make_jit_source(
-  reduction
+  reduce_utils
   kernels/atomic.h
   kernels/reduction/ops.h
 )
@@ -57,11 +58,21 @@ if (MLX_METAL_JIT)
     PRIVATE
     ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp
   )
+  make_jit_source(arange)
   make_jit_source(copy)
   make_jit_source(unary)
   make_jit_source(binary)
   make_jit_source(binary_two)
   make_jit_source(ternary)
+  make_jit_source(softmax)
+  make_jit_source(scan)
+  make_jit_source(sort)
+  make_jit_source(
+    reduce
+    kernels/reduction/reduce_all.h
+    kernels/reduction/reduce_col.h
+    kernels/reduction/reduce_row.h
+  )
 else()
   target_sources(
     mlx
@@ -229,7 +229,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::reduction() << metal::scatter();
+    kernel_source << metal::utils() << metal::reduce_utils()
+                  << metal::scatter();

     std::string out_type_str = get_type_string(out.dtype());
     std::string idx_type_str =
mlx/backend/metal/jit/arange.h (new file, 9 lines)
@@ -0,0 +1,9 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view arange_kernels = R"(
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
    constant const {1}& start,
    constant const {1}& step,
    device {1}* out,
    uint index [[thread_position_in_grid]]);
)";
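A note on the `{0}`/`{1}` placeholders: jit_kernels.cpp (later in this diff) fills them with fmt, `{0}` being the library/kernel name and `{1}` the element type. A minimal standalone sketch of that substitution, with "arangefloat32" and "float" standing in for the runtime values:

    #include <fmt/format.h>
    #include <iostream>
    #include <string_view>

    constexpr std::string_view arange_kernels = R"(
    template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
        constant const {1}& start,
        constant const {1}& step,
        device {1}* out,
        uint index [[thread_position_in_grid]]);
    )";

    int main() {
      // Prints the Metal template instantiation that gets compiled at runtime.
      std::cout << fmt::format(arange_kernels, "arangefloat32", "float");
    }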
@@ -1,4 +1,4 @@
-// Copyright © 2023-24 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.

 #pragma once

@@ -8,14 +8,19 @@ const char* utils();
 const char* binary_ops();
 const char* unary_ops();
 const char* ternary_ops();
-const char* reduction();
+const char* reduce_utils();
 const char* gather();
 const char* scatter();

+const char* arange();
 const char* unary();
 const char* binary();
 const char* binary_two();
 const char* copy();
 const char* ternary();
+const char* scan();
+const char* softmax();
+const char* sort();
+const char* reduce();

 } // namespace mlx::core::metal
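A useful mental model for these getters, though the generation step itself is not shown in this diff: each CMake `make_jit_source(<name>)` call plausibly bakes the corresponding kernel source into a string-returning function with one of the signatures above. A hypothetical hand-written equivalent for `arange()`, using the kernel body from kernels/arange.h further down:

    // Hypothetical stand-in for the generated jit source of arange();
    // the real definition is produced at build time by make_jit_source.
    namespace mlx::core::metal {
    const char* arange() {
      return R"(
    template <typename T>
    [[kernel]] void arange(
        constant const T& start,
        constant const T& step,
        device T* out,
        uint index [[thread_position_in_grid]]) {
      out[index] = start + index * step;
    }
    )";
    }
    } // namespace mlx::core::metal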
mlx/backend/metal/jit/reduce.h (new file, 168 lines)
@@ -0,0 +1,168 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view reduce_init_kernels = R"(
[[kernel]] void {0}(
    device {1}* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {{
  out[tid] = {2}<{1}>::init;
}}
)";

constexpr std::string_view reduce_kernels = R"(
template [[host_name("all_{0}")]] [[kernel]] void
all_reduce<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device mlx_atomic<{2}>* out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("colGeneral_{0}")]] [[kernel]] void
col_reduce_general<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device mlx_atomic<{2}>* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    threadgroup {2}* local_data [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    const constant size_t& non_col_reductions [[buffer(8)]],
    const constant int* non_col_shapes [[buffer(9)]],
    const constant size_t* non_col_strides [[buffer(10)]],
    const constant int& non_col_ndim [[buffer(11)]],
    uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralMed_{0}")]] [[kernel]] void
row_reduce_general_med<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("rowGeneral_{0}")]] [[kernel]] void
row_reduce_general<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device mlx_atomic<{2}>* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

constexpr std::string_view reduce_non_atomic_kernels = R"(
template [[host_name("allNoAtomics_{0}")]] [[kernel]] void
all_reduce_no_atomics<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint thread_group_id [[threadgroup_position_in_grid]]);

template [[host_name("colGeneralNoAtomics_{0}")]] [[kernel]] void
col_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    threadgroup {2}* local_data [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 gid [[thread_position_in_grid]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    const constant size_t& non_col_reductions [[buffer(8)]],
    const constant int* non_col_shapes [[buffer(9)]],
    const constant size_t* non_col_strides [[buffer(10)]],
    const constant int& non_col_ndim [[buffer(11)]],
    uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralNoAtomics_{0}")]] [[kernel]] void
row_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";
mlx/backend/metal/jit/scan.h (new file, 26 lines)
@@ -0,0 +1,26 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view scan_kernels = R"(
template [[host_name("contig_{0}")]] [[kernel]] void
contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& axis_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_size [[threads_per_simdgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);

template [[host_name("strided_{0}")]] [[kernel]] void
strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& axis_size [[buffer(2)]],
    const constant size_t& stride [[buffer(3)]],
    uint2 gid [[thread_position_in_grid]],
    uint2 lid [[thread_position_in_threadgroup]],
    uint2 lsize [[threads_per_threadgroup]],
    uint simd_size [[threads_per_simdgroup]]);
)";
mlx/backend/metal/jit/softmax.h (new file, 23 lines)
@@ -0,0 +1,23 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view softmax_kernels = R"(
template [[host_name("block_{0}")]] [[kernel]] void
softmax_single_row<{1}, {2}>(
    const device {1}* in,
    device {1}* out,
    constant int& axis_size,
    uint gid [[thread_position_in_grid]],
    uint _lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("looped_{0}")]] [[kernel]] void
softmax_looped<{1}, {2}>(
    const device {1}* in,
    device {1}* out,
    constant int& axis_size,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";
mlx/backend/metal/jit/sort.h (new file, 81 lines)
@@ -0,0 +1,81 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view block_sort_kernels = R"(
template [[host_name("carg_{0}")]] [[kernel]] void
block_sort<{1}, {2}, true, {3}, {4}>(
    const device {1}* inp [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& stride_sorted_axis [[buffer(3)]],
    const constant int& stride_segment_axis [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("ncarg_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, true, {3}, {4}>(
    const device {1}* inp [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& stride_sorted_axis [[buffer(3)]],
    const constant int& nc_dim [[buffer(4)]],
    const device int* nc_shape [[buffer(5)]],
    const device size_t* nc_strides [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("c_{0}")]] [[kernel]] void
block_sort<{1}, {2}, false, {3}, {4}>(
    const device {1}* inp [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& stride_sorted_axis [[buffer(3)]],
    const constant int& stride_segment_axis [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("nc_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, false, {3}, {4}>(
    const device {1}* inp [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& stride_sorted_axis [[buffer(3)]],
    const constant int& nc_dim [[buffer(4)]],
    const device int* nc_shape [[buffer(5)]],
    const device size_t* nc_strides [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)";

constexpr std::string_view multiblock_sort_kernels = R"(
template [[host_name("sort_{0}")]] [[kernel]] void
mb_block_sort<{1}, {2}, true, {3}, {4}>(
    const device {1}* inp [[buffer(0)]],
    device {1}* out_vals [[buffer(1)]],
    device {2}* out_idxs [[buffer(2)]],
    const constant int& size_sorted_axis [[buffer(3)]],
    const constant int& stride_sorted_axis [[buffer(4)]],
    const constant int& nc_dim [[buffer(5)]],
    const device int* nc_shape [[buffer(6)]],
    const device size_t* nc_strides [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("partition_{0}")]] [[kernel]] void
mb_block_partition<{1}, {2}, true, {3}, {4}>(
    device {2}* block_partitions [[buffer(0)]],
    const device {1}* dev_vals [[buffer(1)]],
    const device {2}* dev_idxs [[buffer(2)]],
    const constant int& size_sorted_axis [[buffer(3)]],
    const constant int& merge_tiles [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_dims [[threads_per_threadgroup]]);
template [[host_name("merge_{0}")]] [[kernel]] void
mb_block_merge<{1}, {2}, true, {3}, {4}>(
    const device {2}* block_partitions [[buffer(0)]],
    const device {1}* dev_vals_in [[buffer(1)]],
    const device {2}* dev_idxs_in [[buffer(2)]],
    device {1}* dev_vals_out [[buffer(3)]],
    device {2}* dev_idxs_out [[buffer(4)]],
    const constant int& size_sorted_axis [[buffer(5)]],
    const constant int& merge_tiles [[buffer(6)]],
    const constant int& num_tiles [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)";
@@ -3,10 +3,15 @@
 #include <fmt/format.h>

 #include "mlx/backend/common/compiled.h"
+#include "mlx/backend/metal/jit/arange.h"
 #include "mlx/backend/metal/jit/binary.h"
 #include "mlx/backend/metal/jit/binary_two.h"
 #include "mlx/backend/metal/jit/copy.h"
 #include "mlx/backend/metal/jit/includes.h"
+#include "mlx/backend/metal/jit/reduce.h"
+#include "mlx/backend/metal/jit/scan.h"
+#include "mlx/backend/metal/jit/softmax.h"
+#include "mlx/backend/metal/jit/sort.h"
 #include "mlx/backend/metal/jit/ternary.h"
 #include "mlx/backend/metal/jit/unary.h"
 #include "mlx/backend/metal/kernels.h"
@@ -20,6 +25,22 @@ std::string op_name(const array& arr) {
   return op_t.str();
 }

+MTL::ComputePipelineState* get_arange_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source
+        << metal::utils() << metal::arange()
+        << fmt::format(arange_kernels, lib_name, get_type_string(out.dtype()));
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
 MTL::ComputePipelineState* get_unary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
@@ -121,4 +142,138 @@ MTL::ComputePipelineState* get_copy_kernel(
   return d.get_kernel(kernel_name, lib);
 }

+MTL::ComputePipelineState* get_softmax_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    bool precise,
+    const array& out) {
+  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::softmax()
+                  << fmt::format(
+                         softmax_kernels,
+                         lib_name,
+                         get_type_string(out.dtype()),
+                         get_type_string(precise ? float32 : out.dtype()));
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
+MTL::ComputePipelineState* get_scan_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    bool reverse,
+    bool inclusive,
+    const array& in,
+    const array& out) {
+  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::scan()
+                  << fmt::format(
+                         scan_kernels,
+                         lib_name,
+                         get_type_string(in.dtype()),
+                         get_type_string(out.dtype()),
+                         op_name(out),
+                         inclusive,
+                         reverse);
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
+MTL::ComputePipelineState* get_sort_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& in,
+    const array& out,
+    int bn,
+    int tn) {
+  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::sort()
+                  << fmt::format(
+                         block_sort_kernels,
+                         lib_name,
+                         get_type_string(in.dtype()),
+                         get_type_string(out.dtype()),
+                         bn,
+                         tn);
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
+MTL::ComputePipelineState* get_mb_sort_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& in,
+    const array& idx,
+    int bn,
+    int tn) {
+  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::sort()
+                  << fmt::format(
+                         multiblock_sort_kernels,
+                         lib_name,
+                         get_type_string(in.dtype()),
+                         get_type_string(idx.dtype()),
+                         bn,
+                         tn);
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
+MTL::ComputePipelineState* get_reduce_init_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out) {
+  auto lib = d.get_library(kernel_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::reduce_utils()
+                  << fmt::format(
+                         reduce_init_kernels,
+                         kernel_name,
+                         get_type_string(out.dtype()),
+                         op_name(out));
+    lib = d.get_library(kernel_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
+MTL::ComputePipelineState* get_reduce_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& in,
+    const array& out) {
+  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
+                  << fmt::format(
+                         non_atomic ? reduce_non_atomic_kernels
+                                    : reduce_kernels,
+                         lib_name,
+                         get_type_string(in.dtype()),
+                         get_type_string(out.dtype()),
+                         op_name(out));
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
 } // namespace mlx::core
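Every getter above follows the same get-or-build discipline: derive a library name, return the cached library if the device already has it, otherwise assemble headers plus fmt-formatted instantiations and compile once. A dependency-free sketch of that pattern (the types here are stand-ins, not MLX API):

    #include <functional>
    #include <string>
    #include <unordered_map>

    struct JitCache {
      // Stands in for metal::Device's name -> MTL::Library* map.
      std::unordered_map<std::string, std::string> libs;

      const std::string& get_or_build(
          const std::string& lib_name,
          const std::function<std::string()>& build_source) {
        auto it = libs.find(lib_name);
        if (it == libs.end()) {
          // The source is assembled and registered only on the first miss.
          it = libs.emplace(lib_name, build_source()).first;
        }
        return it->second;
      }
    };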
@@ -5,6 +5,11 @@

 namespace mlx::core {

+MTL::ComputePipelineState* get_arange_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out);
+
 MTL::ComputePipelineState* get_unary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
@@ -33,4 +38,45 @@ MTL::ComputePipelineState* get_copy_kernel(
     const array& in,
     const array& out);

+MTL::ComputePipelineState* get_softmax_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    bool precise,
+    const array& out);
+
+MTL::ComputePipelineState* get_scan_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    bool reverse,
+    bool inclusive,
+    const array& in,
+    const array& out);
+
+MTL::ComputePipelineState* get_sort_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& in,
+    const array& out,
+    int bn,
+    int tn);
+
+MTL::ComputePipelineState* get_mb_sort_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& in,
+    const array& idx,
+    int bn,
+    int tn);
+
+MTL::ComputePipelineState* get_reduce_init_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out);
+
+MTL::ComputePipelineState* get_reduce_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& in,
+    const array& out);
+
 } // namespace mlx::core
mlx/backend/metal/kernels/CMakeLists.txt
@@ -8,9 +8,9 @@ set(
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
 )

 set(
   KERNELS
-  "arange"
   "arg_reduce"
   "conv"
   "fft"
@@ -20,31 +20,42 @@ set(
   "rms_norm"
   "layer_norm"
   "rope"
-  "scan"
   "scaled_dot_product_attention"
-  "softmax"
-  "sort"
 )

 if (NOT MLX_METAL_JIT)
   set(
     KERNELS
     ${KERNELS}
+    "arange"
     "binary"
     "binary_two"
     "unary"
     "ternary"
     "copy"
+    "softmax"
+    "sort"
+    "scan"
+    "reduce"
   )
   set(
     HEADERS
     ${HEADERS}
+    arange.h
     unary_ops.h
     unary.h
     binary_ops.h
     binary.h
     ternary.h
     copy.h
+    softmax.h
+    sort.h
+    scan.h
+    reduction/ops.h
+    reduction/reduce_init.h
+    reduction/reduce_all.h
+    reduction/reduce_col.h
+    reduction/reduce_row.h
   )
 endif()

@@ -87,15 +98,6 @@ foreach(KERNEL ${STEEL_KERNELS})
   set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
 endforeach()

-file(GLOB_RECURSE REDUCE_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.metal)
-file(GLOB_RECURSE REDUCE_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.h)
-
-foreach(KERNEL ${REDUCE_KERNELS})
-  cmake_path(GET KERNEL STEM TARGET)
-  build_kernel_base(${TARGET} ${KERNEL} "${REDUCE_HEADERS}")
-  set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
-endforeach()

 add_custom_command(
   OUTPUT ${MLX_METAL_PATH}/mlx.metallib
   COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib
mlx/backend/metal/kernels/arange.h (new file, 9 lines)
@@ -0,0 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
template <typename T>
[[kernel]] void arange(
    constant const T& start,
    constant const T& step,
    device T* out,
    uint index [[thread_position_in_grid]]) {
  out[index] = start + index * step;
}
@@ -1,15 +1,8 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.

+// clang-format off
 #include "mlx/backend/metal/kernels/bf16.h"

-template <typename T>
-[[kernel]] void arange(
-    constant const T& start,
-    constant const T& step,
-    device T* out,
-    uint index [[thread_position_in_grid]]) {
-  out[index] = start + index * step;
-}
+#include "mlx/backend/metal/kernels/arange.h"

 #define instantiate_arange(tname, type) \
   template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
@@ -18,7 +11,6 @@ template <typename T>
       device type* out, \
       uint index [[thread_position_in_grid]]);

-// clang-format off
 instantiate_arange(uint8, uint8_t)
 instantiate_arange(uint16, uint16_t)
 instantiate_arange(uint32, uint32_t)
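For illustration, here is a hand expansion of the macro above for one instantiation; assuming the parameter lines elided by the hunk are the start/step arguments from kernels/arange.h, `instantiate_arange(float32, float)` would produce:

    template [[host_name("arangefloat32")]] [[kernel]] void arange<float>(
        constant const float& start,
        constant const float& step,
        device float* out,
        uint index [[thread_position_in_grid]]);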
@@ -2,11 +2,8 @@

 #pragma once

+#ifndef MLX_METAL_JIT
 #include <metal_atomic>
 #include <metal_stdlib>
 #include "mlx/backend/metal/kernels/bf16.h"
+#endif

 using namespace metal;
@@ -2,7 +2,7 @@

 #pragma once

-#ifdef __METAL__
+#if defined __METAL__ || defined MLX_METAL_JIT
 #define MTL_CONST constant
 #else
 #define MTL_CONST
@@ -11,6 +11,5 @@
 static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
 static MTL_CONST constexpr int REDUCE_N_READS = 16;
 static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
 static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
 static MTL_CONST constexpr int RMS_N_READS = 4;
 static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
mlx/backend/metal/kernels/reduce.h (new file, 4 lines)
@@ -0,0 +1,4 @@
#pragma once
#include "mlx/backend/metal/kernels/reduction/reduce_all.h"
#include "mlx/backend/metal/kernels/reduction/reduce_col.h"
#include "mlx/backend/metal/kernels/reduction/reduce_row.h"
mlx/backend/metal/kernels/reduce.metal (new file, 293 lines)
@@ -0,0 +1,293 @@
// Copyright © 2024 Apple Inc.

#include <metal_atomic>
#include <metal_simdgroup>

// clang-format off
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_init.h"
#include "mlx/backend/metal/kernels/reduce.h"

#define instantiate_reduce_helper_floats(inst_f, name, op) \
  inst_f(name, float16, half, op) \
  inst_f(name, float32, float, op) \
  inst_f(name, bfloat16, bfloat16_t, op)

#define instantiate_reduce_helper_uints(inst_f, name, op) \
  inst_f(name, uint8, uint8_t, op) \
  inst_f(name, uint16, uint16_t, op) \
  inst_f(name, uint32, uint32_t, op)

#define instantiate_reduce_helper_ints(inst_f, name, op) \
  inst_f(name, int8, int8_t, op) \
  inst_f(name, int16, int16_t, op) \
  inst_f(name, int32, int32_t, op)

#define instantiate_reduce_helper_64b(inst_f, name, op) \
  inst_f(name, int64, int64_t, op) \
  inst_f(name, uint64, uint64_t, op)

#define instantiate_reduce_helper_types(inst_f, name, op) \
  instantiate_reduce_helper_floats(inst_f, name, op) \
  instantiate_reduce_helper_uints(inst_f, name, op) \
  instantiate_reduce_helper_ints(inst_f, name, op)

#define instantiate_reduce_ops(inst_f, type_f) \
  type_f(inst_f, sum, Sum) \
  type_f(inst_f, prod, Prod) \
  type_f(inst_f, min_, Min) \
  type_f(inst_f, max_, Max)

// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
    inst_f, name, tname, itype, otype, op) \
  inst_f(name##tname, itype, otype, op)

#define instantiate_reduce_from_types(inst_f, name, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, bool_, bool, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint8, uint8_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint16, uint16_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint32, uint32_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint64, uint64_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int8, int8_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int16, int16_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int32, int32_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int64, int64_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, float16, half, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, \
      name, \
      float32, \
      float, \
      otype, \
      op) \
  instantiate_reduce_from_types_helper( \
      inst_f, \
      name, \
      bfloat16, \
      bfloat16_t, \
      otype, \
      op)

#define instantiate_init_reduce(name, otype, op) \
  template [[host_name("i_reduce_" #name)]] [[kernel]] void \
  init_reduce<otype, op>( \
      device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);

#define instantiate_init_reduce_helper(name, tname, type, op) \
  instantiate_init_reduce(name##tname, type, op<type>)

instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)

instantiate_init_reduce(andbool_, bool, And<bool>)
instantiate_init_reduce(orbool_, bool, Or<bool>)

#define instantiate_all_reduce(name, itype, otype, op) \
  template [[host_name("all_reduce_" #name)]] [[kernel]] void \
  all_reduce<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device mlx_atomic<otype>* out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
  template [[host_name("allNoAtomics_reduce_" #name)]] [[kernel]] void \
  all_reduce_no_atomics<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint thread_group_id [[threadgroup_position_in_grid]]);

#define instantiate_same_all_reduce_helper(name, tname, type, op) \
  instantiate_all_reduce(name##tname, type, type, op<type>)

#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
  instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)

instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)

// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)

#define instantiate_col_reduce_general(name, itype, otype, op) \
  template [[host_name("colGeneral_reduce_" #name)]] [[kernel]] void \
  col_reduce_general<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device mlx_atomic<otype>* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      threadgroup otype* local_data [[threadgroup(0)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]]);

#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
  template \
      [[host_name("colGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
      col_reduce_general_no_atomics<itype, otype, op>( \
          const device itype* in [[buffer(0)]], \
          device otype* out [[buffer(1)]], \
          const constant size_t& reduction_size [[buffer(2)]], \
          const constant size_t& reduction_stride [[buffer(3)]], \
          const constant size_t& out_size [[buffer(4)]], \
          const constant int* shape [[buffer(5)]], \
          const constant size_t* strides [[buffer(6)]], \
          const constant int& ndim [[buffer(7)]], \
          threadgroup otype* local_data [[threadgroup(0)]], \
          uint3 tid [[threadgroup_position_in_grid]], \
          uint3 lid [[thread_position_in_threadgroup]], \
          uint3 gid [[thread_position_in_grid]], \
          uint3 lsize [[threads_per_threadgroup]], \
          uint3 gsize [[threads_per_grid]]);

#define instantiate_col_reduce_small(name, itype, otype, op) \
  template [[host_name("colSmall_reduce_" #name)]] [[kernel]] void \
  col_reduce_small<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      const constant size_t& non_col_reductions [[buffer(8)]], \
      const constant int* non_col_shapes [[buffer(9)]], \
      const constant size_t* non_col_strides [[buffer(10)]], \
      const constant int& non_col_ndim [[buffer(11)]], \
      uint tid [[thread_position_in_grid]]);

#define instantiate_same_col_reduce_helper(name, tname, type, op) \
  instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
  instantiate_col_reduce_general(name ##tname, type, type, op<type>)

#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
  instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
  instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)

instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)

instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or<bool>)

#define instantiate_row_reduce_small(name, itype, otype, op) \
  template [[host_name("rowGeneralSmall_reduce_" #name)]] [[kernel]] void \
  row_reduce_general_small<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant size_t& non_row_reductions [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      uint lid [[thread_position_in_grid]]); \
  template [[host_name("rowGeneralMed_reduce_" #name)]] [[kernel]] void \
  row_reduce_general_med<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant size_t& non_row_reductions [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      uint tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_row_reduce_general(name, itype, otype, op) \
  instantiate_row_reduce_small(name, itype, otype, op) \
  template \
      [[host_name("rowGeneral_reduce_" #name)]] [[kernel]] void \
      row_reduce_general<itype, otype, op>( \
          const device itype* in [[buffer(0)]], \
          device mlx_atomic<otype>* out [[buffer(1)]], \
          const constant size_t& reduction_size [[buffer(2)]], \
          const constant size_t& out_size [[buffer(3)]], \
          const constant size_t& non_row_reductions [[buffer(4)]], \
          const constant int* shape [[buffer(5)]], \
          const constant size_t* strides [[buffer(6)]], \
          const constant int& ndim [[buffer(7)]], \
          uint3 lid [[thread_position_in_threadgroup]], \
          uint3 lsize [[threads_per_threadgroup]], \
          uint3 tid [[threadgroup_position_in_grid]], \
          uint simd_lane_id [[thread_index_in_simdgroup]], \
          uint simd_per_group [[simdgroups_per_threadgroup]], \
          uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
  instantiate_row_reduce_small(name, itype, otype, op) \
  template \
      [[host_name("rowGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
      row_reduce_general_no_atomics<itype, otype, op>( \
          const device itype* in [[buffer(0)]], \
          device otype* out [[buffer(1)]], \
          const constant size_t& reduction_size [[buffer(2)]], \
          const constant size_t& out_size [[buffer(3)]], \
          const constant size_t& non_row_reductions [[buffer(4)]], \
          const constant int* shape [[buffer(5)]], \
          const constant size_t* strides [[buffer(6)]], \
          const constant int& ndim [[buffer(7)]], \
          uint3 lid [[thread_position_in_threadgroup]], \
          uint3 lsize [[threads_per_threadgroup]], \
          uint3 gsize [[threads_per_grid]], \
          uint3 tid [[threadgroup_position_in_grid]], \
          uint simd_lane_id [[thread_index_in_simdgroup]], \
          uint simd_per_group [[simdgroups_per_threadgroup]], \
          uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_same_row_reduce_helper(name, tname, type, op) \
  instantiate_row_reduce_general(name##tname, type, type, op<type>)

#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
  instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)

instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)

instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
// clang-format on
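As a worked example of the naming scheme (my expansion, not a line from the file), `instantiate_init_reduce(andbool_, bool, And<bool>)` above expands to:

    template [[host_name("i_reduce_andbool_")]] [[kernel]] void
    init_reduce<bool, And<bool>>(
        device bool * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);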
(deleted file, 32 lines)
@@ -1,32 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Reduce init
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename Op>
[[kernel]] void init_reduce(
    device T* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {
  out[tid] = Op::init;
}

#define instantiate_init_reduce(name, otype, op) \
  template [[host_name("i" #name)]] [[kernel]] void init_reduce<otype, op>( \
      device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);

#define instantiate_init_reduce_helper(name, tname, type, op) \
  instantiate_init_reduce(name##tname, type, op<type>)

// clang-format off
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)

instantiate_init_reduce(andbool_, bool, And)
instantiate_init_reduce(orbool_, bool, Or) // clang-format on
@@ -5,11 +5,7 @@
 #include <metal_atomic>
 #include <metal_simdgroup>

+#ifndef MLX_METAL_JIT
 #include "mlx/backend/metal/kernels/atomic.h"
 #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/utils.h"
+#endif
 static constant constexpr const uint8_t simd_size = 32;

 union bool4_or_uint {
   bool4 b;
@@ -23,6 +19,7 @@ struct None {
   }
 };

+template <typename U = bool>
 struct And {
   bool simd_reduce(bool val) {
     return simd_all(val);
@@ -60,6 +57,7 @@ struct And {
   }
 };

+template <typename U = bool>
 struct Or {
   bool simd_reduce(bool val) {
     return simd_any(val);
@@ -1,11 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.

-#include "mlx/backend/metal/kernels/reduction/ops.h"
-#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
-#include "mlx/backend/metal/kernels/reduction/utils.h"
-
-using namespace metal;
-
 ///////////////////////////////////////////////////////////////////////////////
 // All reduce helper
 ///////////////////////////////////////////////////////////////////////////////
@@ -139,50 +133,3 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
     out[thread_group_id] = total_val;
   }
 }
-
-#define instantiate_all_reduce(name, itype, otype, op) \
-  template [[host_name("all_reduce_" #name)]] [[kernel]] void \
-  all_reduce<itype, otype, op>( \
-      const device itype* in [[buffer(0)]], \
-      device mlx_atomic<otype>* out [[buffer(1)]], \
-      const device size_t& in_size [[buffer(2)]], \
-      uint gid [[thread_position_in_grid]], \
-      uint lid [[thread_position_in_threadgroup]], \
-      uint grid_size [[threads_per_grid]], \
-      uint simd_per_group [[simdgroups_per_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
-
-#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
-  template [[host_name("all_reduce_no_atomics_" #name)]] [[kernel]] void \
-  all_reduce_no_atomics<itype, otype, op>( \
-      const device itype* in [[buffer(0)]], \
-      device otype* out [[buffer(1)]], \
-      const device size_t& in_size [[buffer(2)]], \
-      uint gid [[thread_position_in_grid]], \
-      uint lid [[thread_position_in_threadgroup]], \
-      uint grid_size [[threads_per_grid]], \
-      uint simd_per_group [[simdgroups_per_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
-      uint thread_group_id [[threadgroup_position_in_grid]]);
-
-///////////////////////////////////////////////////////////////////////////////
-// Instantiations
-///////////////////////////////////////////////////////////////////////////////
-
-#define instantiate_same_all_reduce_helper(name, tname, type, op) \
-  instantiate_all_reduce(name##tname, type, type, op<type>)
-
-#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
-  instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)
-
-// clang-format off
-instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
-instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
-
-instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And)
-instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)
-
-// special case bool with larger output type
-instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on
@@ -1,11 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.

-#include "mlx/backend/metal/kernels/reduction/ops.h"
-#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
-#include "mlx/backend/metal/kernels/reduction/utils.h"
-
-using namespace metal;
-
 ///////////////////////////////////////////////////////////////////////////////
 // Small column reduce kernel
 ///////////////////////////////////////////////////////////////////////////////
@@ -52,23 +46,6 @@ template <typename T, typename U, typename Op>
   out[out_idx] = total_val;
 }

-#define instantiate_col_reduce_small(name, itype, otype, op) \
-  template [[host_name("col_reduce_small_" #name)]] [[kernel]] void \
-  col_reduce_small<itype, otype, op>( \
-      const device itype* in [[buffer(0)]], \
-      device otype* out [[buffer(1)]], \
-      const constant size_t& reduction_size [[buffer(2)]], \
-      const constant size_t& reduction_stride [[buffer(3)]], \
-      const constant size_t& out_size [[buffer(4)]], \
-      const constant int* shape [[buffer(5)]], \
-      const constant size_t* strides [[buffer(6)]], \
-      const constant int& ndim [[buffer(7)]], \
-      const constant size_t& non_col_reductions [[buffer(8)]], \
-      const constant int* non_col_shapes [[buffer(9)]], \
-      const constant size_t* non_col_strides [[buffer(10)]], \
-      const constant int& non_col_ndim [[buffer(11)]], \
-      uint tid [[thread_position_in_grid]]);
-
 ///////////////////////////////////////////////////////////////////////////////
 // Column reduce helper
 ///////////////////////////////////////////////////////////////////////////////
@@ -186,64 +163,3 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
     }
   }
 }
-
-#define instantiate_col_reduce_general(name, itype, otype, op) \
-  template [[host_name("col_reduce_general_" #name)]] [[kernel]] void \
-  col_reduce_general<itype, otype, op>( \
-      const device itype* in [[buffer(0)]], \
-      device mlx_atomic<otype>* out [[buffer(1)]], \
-      const constant size_t& reduction_size [[buffer(2)]], \
-      const constant size_t& reduction_stride [[buffer(3)]], \
-      const constant size_t& out_size [[buffer(4)]], \
-      const constant int* shape [[buffer(5)]], \
-      const constant size_t* strides [[buffer(6)]], \
-      const constant int& ndim [[buffer(7)]], \
-      threadgroup otype* local_data [[threadgroup(0)]], \
-      uint3 tid [[threadgroup_position_in_grid]], \
-      uint3 lid [[thread_position_in_threadgroup]], \
-      uint3 lsize [[threads_per_threadgroup]]);
-
-#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
-  template \
-      [[host_name("col_reduce_general_no_atomics_" #name)]] [[kernel]] void \
-      col_reduce_general_no_atomics<itype, otype, op>( \
-          const device itype* in [[buffer(0)]], \
-          device otype* out [[buffer(1)]], \
-          const constant size_t& reduction_size [[buffer(2)]], \
-          const constant size_t& reduction_stride [[buffer(3)]], \
-          const constant size_t& out_size [[buffer(4)]], \
-          const constant int* shape [[buffer(5)]], \
-          const constant size_t* strides [[buffer(6)]], \
-          const constant int& ndim [[buffer(7)]], \
-          threadgroup otype* local_data [[threadgroup(0)]], \
-          uint3 tid [[threadgroup_position_in_grid]], \
-          uint3 lid [[thread_position_in_threadgroup]], \
-          uint3 gid [[thread_position_in_grid]], \
-          uint3 lsize [[threads_per_threadgroup]], \
-          uint3 gsize [[threads_per_grid]]);
-
-///////////////////////////////////////////////////////////////////////////////
-// Instantiations
-///////////////////////////////////////////////////////////////////////////////
-
-// clang-format off
-#define instantiate_same_col_reduce_helper(name, tname, type, op) \
-  instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
-  instantiate_col_reduce_general(name ##tname, type, type, op<type>) // clang-format on
-
-// clang-format off
-#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
-  instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
-  instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>) // clang-format on
-
-// clang-format off
-instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
-instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
-
-instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
-instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And)
-instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
-
-instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
-instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
-instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) // clang-format on
mlx/backend/metal/kernels/reduction/reduce_init.h (new file, 8 lines)
@@ -0,0 +1,8 @@
// Copyright © 2023-2024 Apple Inc.

template <typename T, typename Op>
[[kernel]] void init_reduce(
    device T* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {
  out[tid] = Op::init;
}
@ -1,74 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_atomic>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_reduce_helper_floats(inst_f, name, op) \
|
||||
inst_f(name, float16, half, op) inst_f(name, float32, float, op) \
|
||||
inst_f(name, bfloat16, bfloat16_t, op)
|
||||
|
||||
#define instantiate_reduce_helper_uints(inst_f, name, op) \
|
||||
inst_f(name, uint8, uint8_t, op) inst_f(name, uint16, uint16_t, op) \
|
||||
inst_f(name, uint32, uint32_t, op)
|
||||
|
||||
#define instantiate_reduce_helper_ints(inst_f, name, op) \
|
||||
inst_f(name, int8, int8_t, op) inst_f(name, int16, int16_t, op) \
|
||||
inst_f(name, int32, int32_t, op)
|
||||
|
||||
#define instantiate_reduce_helper_64b(inst_f, name, op) \
|
||||
inst_f(name, int64, int64_t, op) inst_f(name, uint64, uint64_t, op)
|
||||
|
||||
#define instantiate_reduce_helper_types(inst_f, name, op) \
|
||||
instantiate_reduce_helper_floats(inst_f, name, op) \
|
||||
instantiate_reduce_helper_uints(inst_f, name, op) \
|
||||
instantiate_reduce_helper_ints(inst_f, name, op)
|
||||
|
||||
#define instantiate_reduce_ops(inst_f, type_f) \
|
||||
type_f(inst_f, sum, Sum) type_f(inst_f, prod, Prod) \
|
||||
type_f(inst_f, min_, Min) type_f(inst_f, max_, Max)
|
||||
|
||||
// Special case for bool reductions
|
||||
#define instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, tname, itype, otype, op) \
|
||||
inst_f(name##tname, itype, otype, op)
|
||||
|
||||
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, bool_, bool, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, uint8, uint8_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, uint16, uint16_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, uint32, uint32_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, int8, int8_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, int16, int16_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, int32, int32_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, int64, int64_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, float16, half, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, \
|
||||
name, \
|
||||
float32, \
|
||||
float, \
|
||||
otype, \
|
||||
op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, \
|
||||
name, \
|
||||
bfloat16, \
|
||||
bfloat16_t, \
|
||||
otype, \
|
||||
op)
|
||||
// clang-format on
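To make the removed instantiation scheme concrete, here is one traced expansion against the row-reduce macros that appear later in this diff:

// instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
//   -> instantiate_reduce_helper_types(instantiate_same_row_reduce_helper, sum, Sum) ...
//   -> instantiate_same_row_reduce_helper(sum, float32, float, Sum)
//   -> instantiate_row_reduce_general(sumfloat32, float, float, Sum<float>)
// which registers the Metal kernels named
//   "row_reduce_general_small_sumfloat32"
//   "row_reduce_general_med_sumfloat32"
//   "row_reduce_general_sumfloat32"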
@ -1,11 +1,5 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Small row reductions
///////////////////////////////////////////////////////////////////////////////
@ -123,33 +117,6 @@ template <typename T, typename U, typename Op>
  }
}

#define instantiate_row_reduce_small(name, itype, otype, op) \
  template [[host_name("row_reduce_general_small_" #name)]] [[kernel]] void \
  row_reduce_general_small<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant size_t& non_row_reductions [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      uint lid [[thread_position_in_grid]]); \
  template [[host_name("row_reduce_general_med_" #name)]] [[kernel]] void \
  row_reduce_general_med<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant size_t& non_row_reductions [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      uint tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

///////////////////////////////////////////////////////////////////////////////
// Large row reductions
///////////////////////////////////////////////////////////////////////////////
@ -318,61 +285,3 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
      out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
    }
}

#define instantiate_row_reduce_general(name, itype, otype, op) \
  instantiate_row_reduce_small(name, itype, otype, op) template \
  [[host_name("row_reduce_general_" #name)]] [[kernel]] void \
  row_reduce_general<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device mlx_atomic<otype>* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant size_t& non_row_reductions [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
  instantiate_row_reduce_small(name, itype, otype, op) template \
  [[host_name("row_reduce_general_no_atomics_" #name)]] [[kernel]] void \
  row_reduce_general_no_atomics<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant size_t& non_row_reductions [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]], \
      uint3 gsize [[threads_per_grid]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_same_row_reduce_helper(name, tname, type, op) \
  instantiate_row_reduce_general(name##tname, type, type, op<type>)

#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
  instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)

// clang-format off
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or)

instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on
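The special cases above exist because bool changes the output dtype: summing booleans counts true values, so it widens to uint32_t, while and/or collapse any input dtype down to bool. Reference semantics as a plain C++ sketch (the `!= 0` truthiness for non-bool inputs is an assumption about the And/Or ops, whose definitions this diff does not show):

#include <cstdint>
#include <vector>

uint32_t sum_bool(const std::vector<bool>& v) { // "sumbool_": bool -> uint32_t
  uint32_t acc = 0; // widened so large counts cannot overflow
  for (bool b : v) {
    acc += b ? 1u : 0u;
  }
  return acc;
}

bool reduce_and(const std::vector<float>& v) { // "and": any dtype -> bool
  bool acc = true;
  for (float x : v) {
    acc = acc && (x != 0.0f);
  }
  return acc;
}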
@ -1,14 +0,0 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_atomic>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/utils.h"

#include "mlx/backend/metal/kernels/reduction/ops.h"

static constant constexpr const uint8_t simd_size = 32;
440
mlx/backend/metal/kernels/scan.h
Normal file
@ -0,0 +1,440 @@
// Copyright © 2023-2024 Apple Inc.

template <typename U>
struct CumSum {
  static constexpr constant U init = static_cast<U>(0);

  template <typename T>
  U operator()(U a, T b) {
    return a + b;
  }

  U simd_scan(U x) {
    return simd_prefix_inclusive_sum(x);
  }

  U simd_exclusive_scan(U x) {
    return simd_prefix_exclusive_sum(x);
  }
};

template <typename U>
struct CumProd {
  static constexpr constant U init = static_cast<U>(1.0f);

  template <typename T>
  U operator()(U a, T b) {
    return a * b;
  }

  U simd_scan(U x) {
    return simd_prefix_inclusive_product(x);
  }

  U simd_exclusive_scan(U x) {
    return simd_prefix_exclusive_product(x);
  }
};

template <>
struct CumProd<bool> {
  static constexpr constant bool init = true;

  template <typename T>
  bool operator()(bool a, T b) {
    return a & static_cast<bool>(b);
  }

  bool simd_scan(bool x) {
    for (int i = 1; i <= 16; i *= 2) {
      bool other = simd_shuffle_up(x, i);
      x &= other;
    }
    return x;
  }

  bool simd_exclusive_scan(bool x) {
    x = simd_scan(x);
    return simd_shuffle_and_fill_up(x, init, 1);
  }
};
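The `simd_shuffle_up` loop in `CumProd<bool>` (and in `CumMax`/`CumMin` below) is a Hillis-Steele inclusive scan over a 32-lane simdgroup: after the pass with offset `i`, each lane holds the combination of itself and the `i` lanes before it, so log2(32) = 5 passes cover all lanes. A small CPU model of the same pattern (a sketch, assuming 32 lanes and that an out-of-range shuffle leaves a lane's value unchanged, which is harmless here because &, max, and min are idempotent):

#include <array>

// CPU model of the simd_shuffle_up scan loop used by the idempotent ops.
template <typename U, typename Combine>
std::array<U, 32> simd_inclusive_scan_model(std::array<U, 32> x, Combine op) {
  for (int delta = 1; delta <= 16; delta *= 2) {
    std::array<U, 32> other = x;
    for (int lane = 0; lane < 32; ++lane) {
      // simd_shuffle_up: take the value delta lanes below, if it exists.
      other[lane] = (lane >= delta) ? x[lane - delta] : x[lane];
    }
    for (int lane = 0; lane < 32; ++lane) {
      x[lane] = op(x[lane], other[lane]);
    }
  }
  return x; // x[lane] == op-fold of inputs 0..lane
}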
template <typename U>
struct CumMax {
  static constexpr constant U init = Limits<U>::min;

  template <typename T>
  U operator()(U a, T b) {
    return (a >= b) ? a : b;
  }

  U simd_scan(U x) {
    for (int i = 1; i <= 16; i *= 2) {
      U other = simd_shuffle_up(x, i);
      x = (x >= other) ? x : other;
    }
    return x;
  }

  U simd_exclusive_scan(U x) {
    x = simd_scan(x);
    return simd_shuffle_and_fill_up(x, init, 1);
  }
};

template <typename U>
struct CumMin {
  static constexpr constant U init = Limits<U>::max;

  template <typename T>
  U operator()(U a, T b) {
    return (a <= b) ? a : b;
  }

  U simd_scan(U x) {
    for (int i = 1; i <= 16; i *= 2) {
      U other = simd_shuffle_up(x, i);
      x = (x <= other) ? x : other;
    }
    return x;
  }

  U simd_exclusive_scan(U x) {
    x = simd_scan(x);
    return simd_shuffle_and_fill_up(x, init, 1);
  }
};

template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T* input) {
  if (reverse) {
    for (int i = 0; i < N_READS; i++) {
      values[N_READS - i - 1] = input[i];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      values[i] = input[i];
    }
  }
}

template <typename T, typename U, int N_READS, bool reverse>
inline void load_safe(
    U values[N_READS],
    const device T* input,
    int start,
    int total,
    U init) {
  if (reverse) {
    for (int i = 0; i < N_READS; i++) {
      values[N_READS - i - 1] =
          (start + N_READS - i - 1 < total) ? input[i] : init;
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      values[i] = (start + i < total) ? input[i] : init;
    }
  }
}

template <typename U, int N_READS, bool reverse>
inline void write_unsafe(U values[N_READS], device U* out) {
  if (reverse) {
    for (int i = 0; i < N_READS; i++) {
      out[i] = values[N_READS - i - 1];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      out[i] = values[i];
    }
  }
}

template <typename U, int N_READS, bool reverse>
inline void write_safe(U values[N_READS], device U* out, int start, int total) {
  if (reverse) {
    for (int i = 0; i < N_READS; i++) {
      if (start + N_READS - i - 1 < total) {
        out[i] = values[N_READS - i - 1];
      }
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if (start + i < total) {
        out[i] = values[i];
      }
    }
  }
}

template <
    typename T,
    typename U,
    typename Op,
    int N_READS,
    bool inclusive,
    bool reverse>
[[kernel]] void contiguous_scan(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& axis_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_size [[threads_per_simdgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;

  // Position the pointers
  in += (gid / lsize) * axis_size;
  out += (gid / lsize) * axis_size;

  // Compute the number of simd_groups
  uint simd_groups = lsize / simd_size;

  // Allocate memory
  U prefix = Op::init;
  U values[N_READS];
  threadgroup U simdgroup_sums[32];

  // Loop over the reduced axis in blocks of size ceildiv(axis_size,
  // N_READS*lsize)
  //    Read block
  //    Compute inclusive scan of the block
  //      Compute inclusive scan per thread
  //      Compute exclusive scan of thread sums in simdgroup
  //      Write simdgroup sums in SM
  //      Compute exclusive scan of simdgroup sums
  //      Compute the output by scanning prefix, prev_simdgroup, prev_thread,
  //      value
  //    Write block

  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
    // Compute the block offset
    uint offset = r * lsize * N_READS + lid * N_READS;

    // Read the values
    if (reverse) {
      if ((offset + N_READS) < axis_size) {
        load_unsafe<T, U, N_READS, reverse>(
            values, in + axis_size - offset - N_READS);
      } else {
        load_safe<T, U, N_READS, reverse>(
            values,
            in + axis_size - offset - N_READS,
            offset,
            axis_size,
            Op::init);
      }
    } else {
      if ((offset + N_READS) < axis_size) {
        load_unsafe<T, U, N_READS, reverse>(values, in + offset);
      } else {
        load_safe<T, U, N_READS, reverse>(
            values, in + offset, offset, axis_size, Op::init);
      }
    }

    // Compute an inclusive scan per thread
    for (int i = 1; i < N_READS; i++) {
      values[i] = op(values[i], values[i - 1]);
    }

    // Compute exclusive scan of thread sums
    U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);

    // Write simdgroup_sums to SM
    if (simd_lane_id == simd_size - 1) {
      simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Compute exclusive scan of simdgroup_sums
    if (simd_group_id == 0) {
      U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
      simdgroup_sums[simd_lane_id] = prev_simdgroup;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Compute the output
    for (int i = 0; i < N_READS; i++) {
      values[i] = op(values[i], prefix);
      values[i] = op(values[i], simdgroup_sums[simd_group_id]);
      values[i] = op(values[i], prev_thread);
    }

    // Write the values
    if (reverse) {
      if (inclusive) {
        if ((offset + N_READS) < axis_size) {
          write_unsafe<U, N_READS, reverse>(
              values, out + axis_size - offset - N_READS);
        } else {
          write_safe<U, N_READS, reverse>(
              values, out + axis_size - offset - N_READS, offset, axis_size);
        }
      } else {
        if (lid == 0 && offset == 0) {
          out[axis_size - 1] = Op::init;
        }
        if ((offset + N_READS + 1) < axis_size) {
          write_unsafe<U, N_READS, reverse>(
              values, out + axis_size - offset - 1 - N_READS);
        } else {
          write_safe<U, N_READS, reverse>(
              values,
              out + axis_size - offset - 1 - N_READS,
              offset + 1,
              axis_size);
        }
      }
    } else {
      if (inclusive) {
        if ((offset + N_READS) < axis_size) {
          write_unsafe<U, N_READS, reverse>(values, out + offset);
        } else {
          write_safe<U, N_READS, reverse>(
              values, out + offset, offset, axis_size);
        }
      } else {
        if (lid == 0 && offset == 0) {
          out[0] = Op::init;
        }
        if ((offset + N_READS + 1) < axis_size) {
          write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
        } else {
          write_safe<U, N_READS, reverse>(
              values, out + offset + 1, offset + 1, axis_size);
        }
      }
    }

    // Share the prefix
    if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
      simdgroup_sums[0] = values[N_READS - 1];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    prefix = simdgroup_sums[0];
  }
}
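The block scan composes four partial results: the running `prefix` carried over from earlier blocks, the exclusive scan of simdgroup totals, the exclusive scan `prev_thread` of thread totals within a simdgroup, and each thread's local inclusive scan. A numeric trace with toy sizes (an illustration with N_READS = 2, two threads per simdgroup, and two simdgroups; the real kernel uses 32-lane simdgroups) for a CumSum over [1, 2, 3, 4, 5, 6, 7, 8]:

// per-thread inclusive scans:   t0: [1, 3]  t1: [3, 7]  t2: [5, 11]  t3: [7, 15]
// thread totals:                    3           7            11           15
// exclusive scan per simdgroup: t0: 0       t1: 3       t2: 0        t3: 11
// simdgroup totals:             g0: 10                  g1: 26
// exclusive scan of those:      g0: 0                   g1: 10
// output = local + prev_thread + prev_simdgroup + prefix (prefix = 0):
//   t0: [1, 3]  t1: [6, 10]  t2: [15, 21]  t3: [28, 36]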
template <
    typename T,
    typename U,
    typename Op,
    int N_READS,
    bool inclusive,
    bool reverse>
[[kernel]] void strided_scan(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& axis_size [[buffer(2)]],
    const constant size_t& stride [[buffer(3)]],
    uint2 gid [[threadgroup_position_in_grid]],
    uint2 lid [[thread_position_in_threadgroup]],
    uint2 lsize [[threads_per_threadgroup]],
    uint simd_size [[threads_per_simdgroup]]) {
  Op op;

  // Allocate memory
  threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
  U values[N_READS];
  U prefix[N_READS];
  for (int i = 0; i < N_READS; i++) {
    prefix[i] = Op::init;
  }

  // Compute offsets
  int offset = gid.y * axis_size * stride;
  int global_index_x = gid.x * lsize.y * N_READS;

  for (uint j = 0; j < axis_size; j += simd_size) {
    // Calculate the indices for the current thread
    uint index_y = j + lid.y;
    uint check_index_y = index_y;
    uint index_x = global_index_x + lid.x * N_READS;
    if (reverse) {
      index_y = axis_size - 1 - index_y;
    }

    // Read in SM
    if (check_index_y < axis_size && (index_x + N_READS) < stride) {
      for (int i = 0; i < N_READS; i++) {
        read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
            in[offset + index_y * stride + index_x + i];
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if (check_index_y < axis_size && (index_x + i) < stride) {
          read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
              in[offset + index_y * stride + index_x + i];
        } else {
          read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
              Op::init;
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Read strided into registers
    for (int i = 0; i < N_READS; i++) {
      values[i] =
          read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
    }
    // Do we need the following barrier? Shouldn't all simd threads execute
    // simultaneously?
    simdgroup_barrier(mem_flags::mem_threadgroup);

    // Perform the scan
    for (int i = 0; i < N_READS; i++) {
      values[i] = op.simd_scan(values[i]);
      values[i] = op(values[i], prefix[i]);
      prefix[i] = simd_shuffle(values[i], simd_size - 1);
    }

    // Write to SM
    for (int i = 0; i < N_READS; i++) {
      read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] =
          values[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Write to device memory
    if (!inclusive) {
      if (check_index_y == 0) {
        if ((index_x + N_READS) < stride) {
          for (int i = 0; i < N_READS; i++) {
            out[offset + index_y * stride + index_x + i] = Op::init;
          }
        } else {
          for (int i = 0; i < N_READS; i++) {
            if ((index_x + i) < stride) {
              out[offset + index_y * stride + index_x + i] = Op::init;
            }
          }
        }
      }
      if (reverse) {
        index_y -= 1;
        check_index_y += 1;
      } else {
        index_y += 1;
        check_index_y += 1;
      }
    }
    if (check_index_y < axis_size && (index_x + N_READS) < stride) {
      for (int i = 0; i < N_READS; i++) {
        out[offset + index_y * stride + index_x + i] =
            read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if (check_index_y < axis_size && (index_x + i) < stride) {
          out[offset + index_y * stride + index_x + i] =
              read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
        }
      }
    }
  }
}
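Note the index swap between staging and the scan: tiles are written into `read_buffer` at [lid.y][lid.x] (coalesced along the strided axis) but loaded back into registers at [lid.x][lid.y]. After that transpose, the 32 lanes of a simdgroup hold 32 consecutive positions along the scan axis for a fixed column, so the hardware `op.simd_scan` applies directly, and `simd_shuffle(values[i], simd_size - 1)` carries lane 31's running total into `prefix` for the next tile. In index terms (a reading of the code above, not extra machinery):

// stage: read_buffer[lid.y * 32 * N_READS + lid.x * N_READS + i]
//          = in[axis position j + lid.y, column index_x + i]
// load:  values[i] = read_buffer[lid.x * 32 * N_READS + lid.y * N_READS + i]
//          = in[axis position j + lid.x, a fixed column per (lid.y, i)]
// so lanes (lid.x = 0..31) now span the scan axis and simd_scan does the work.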
@ -1,455 +1,19 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <metal_math>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
// clang-format off

using namespace metal;
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/scan.h"

#define instantiate_contiguous_scan( \
    name, itype, otype, op, inclusive, reverse, nreads) \
  template [[host_name("contiguous_scan_" #name)]] [[kernel]] void \
  template [[host_name("contig_scan_" #name)]] [[kernel]] void \
  contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
@ -474,7 +38,6 @@ template <
      uint2 lsize [[threads_per_threadgroup]], \
      uint simd_size [[threads_per_simdgroup]]);

// clang-format off
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
  instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
  instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
@ -483,9 +46,8 @@ template <
  instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
  instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
  instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
  instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) // clang-format on
  instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)

// clang-format off
instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
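Tracing the macros above, the renamed contiguous instantiation now registers host names with the "contig_scan_" prefix, so the last helper line yields, for example:

// contig_scan_inclusive_sum_uint16_uint16
// contig_scan_exclusive_sum_uint16_uint16
// plus the reverse_ variants, and strided_scan counterparts whose prefix is
// set by instantiate_strided_scan (its definition is not shown in this diff).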
190
mlx/backend/metal/kernels/softmax.h
Normal file
@ -0,0 +1,190 @@
// Copyright © 2023-2024 Apple Inc.

template <typename T>
inline T softmax_exp(T x) {
  // Softmax doesn't need a high-precision exponential because x is in
  // (-oo, 0] and the result is subsequently divided by sum(exp(x_i)).
  return fast::exp(x);
}
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_single_row(
    const device T* in,
    device T* out,
    constant int& axis_size,
    uint gid [[threadgroup_position_in_grid]],
    uint _lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  int lid = _lid;

  constexpr int SIMD_SIZE = 32;

  threadgroup AccT local_max[SIMD_SIZE];
  threadgroup AccT local_normalizer[SIMD_SIZE];

  AccT ld[N_READS];

  in += gid * axis_size + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      ld[i] = AccT(in[i]);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
                                                : Limits<AccT>::finite_min;
    }
  }
  if (simd_group_id == 0) {
    local_max[simd_lane_id] = Limits<AccT>::finite_min;
    local_normalizer[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Get the max
  AccT maxval = Limits<AccT>::finite_min;
  for (int i = 0; i < N_READS; i++) {
    maxval = (maxval < ld[i]) ? ld[i] : maxval;
  }
  maxval = simd_max(maxval);
  if (simd_lane_id == 0) {
    local_max[simd_group_id] = maxval;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    maxval = simd_max(local_max[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_max[0] = maxval;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  maxval = local_max[0];

  // Compute exp(x_i - maxval) and store the partial sums in local_normalizer
  AccT normalizer = 0;
  for (int i = 0; i < N_READS; i++) {
    AccT exp_x = softmax_exp(ld[i] - maxval);
    ld[i] = exp_x;
    normalizer += exp_x;
  }
  normalizer = simd_sum(normalizer);
  if (simd_lane_id == 0) {
    local_normalizer[simd_group_id] = normalizer;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    normalizer = simd_sum(local_normalizer[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_normalizer[0] = normalizer;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  normalizer = 1 / local_normalizer[0];

  // Normalize and write to the output
  out += gid * axis_size + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      out[i] = T(ld[i] * normalizer);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        out[i] = T(ld[i] * normalizer);
      }
    }
  }
}
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_looped(
    const device T* in,
    device T* out,
    constant int& axis_size,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  in += gid * axis_size;

  constexpr int SIMD_SIZE = 32;

  threadgroup AccT local_max[SIMD_SIZE];
  threadgroup AccT local_normalizer[SIMD_SIZE];

  // Get the max and the normalizer in one go
  AccT prevmax;
  AccT maxval = Limits<AccT>::finite_min;
  AccT normalizer = 0;
  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
       r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    AccT vals[N_READS];
    if (offset + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        vals[i] = AccT(in[offset + i]);
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
                                           : Limits<AccT>::finite_min;
      }
    }
    prevmax = maxval;
    for (int i = 0; i < N_READS; i++) {
      maxval = (maxval < vals[i]) ? vals[i] : maxval;
    }
    normalizer *= softmax_exp(prevmax - maxval);
    for (int i = 0; i < N_READS; i++) {
      normalizer += softmax_exp(vals[i] - maxval);
    }
  }
  // Now we have a partial normalizer over N_READS * ceildiv(axis_size,
  // N_READS * lsize) parts that we need to combine:
  //   1. We start by finding the max across simd groups
  //   2. We then change the partial normalizers to account for a possible
  //      change in max
  //   3. We sum all normalizers
  prevmax = maxval;
  maxval = simd_max(maxval);
  normalizer *= softmax_exp(prevmax - maxval);
  normalizer = simd_sum(normalizer);

  // Now the normalizer and max value are correct for each simdgroup. We
  // write them to shared memory and combine them.
  prevmax = maxval;
  if (simd_lane_id == 0) {
    local_max[simd_group_id] = maxval;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  maxval = simd_max(local_max[simd_lane_id]);
  normalizer *= softmax_exp(prevmax - maxval);
  if (simd_lane_id == 0) {
    local_normalizer[simd_group_id] = normalizer;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  normalizer = simd_sum(local_normalizer[simd_lane_id]);
  normalizer = 1 / normalizer;
  // Finally given the normalizer and max value we can directly write the
  // softmax output
  out += gid * axis_size;
  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
       r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    if (offset + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if (offset + i < axis_size) {
          out[offset + i] =
              T(softmax_exp(in[offset + i] - maxval) * normalizer);
        }
      }
    }
  }
}
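The running rescale in `softmax_looped` is the standard online-softmax recurrence: when the running max m grows to m', the partial sum d of exp(x_i - m) is multiplied by exp(m - m') so every accumulated term stays relative to the new max. A self-contained C++ reference of the same recurrence (a sketch for clarity, not the MLX API):

#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// One pass to accumulate (max, normalizer) online, one pass to write out.
std::vector<float> online_softmax(const std::vector<float>& x) {
  float m = -std::numeric_limits<float>::infinity();
  float d = 0.0f;
  for (float v : x) {
    float m_new = (v > m) ? v : m;
    d = d * std::exp(m - m_new) + std::exp(v - m_new); // rescale, then add
    m = m_new;
  }
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = std::exp(x[i] - m) / d;
  }
  return out;
}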
@ -1,205 +1,18 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <metal_common>
#include <metal_simdgroup>

using namespace metal;

// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;
#include "mlx/backend/metal/kernels/softmax.h"

#define instantiate_softmax(name, itype) \
  template [[host_name("softmax_" #name)]] [[kernel]] void \
  template [[host_name("block_softmax_" #name)]] [[kernel]] void \
  softmax_single_row<itype>( \
      const device itype* in, \
      device itype* out, \
@ -208,7 +21,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
      uint _lid [[thread_position_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
  template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
  template [[host_name("looped_softmax_" #name)]] [[kernel]] void \
  softmax_looped<itype>( \
      const device itype* in, \
      device itype* out, \
@ -220,7 +33,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_softmax_precise(name, itype) \
  template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
  template [[host_name("block_softmax_precise_" #name)]] [[kernel]] void \
  softmax_single_row<itype, float>( \
      const device itype* in, \
      device itype* out, \
@ -229,7 +42,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
      uint _lid [[thread_position_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
  template [[host_name("softmax_looped_precise_" #name)]] [[kernel]] void \
  template [[host_name("looped_softmax_precise_" #name)]] [[kernel]] void \
  softmax_looped<itype, float>( \
      const device itype* in, \
      device itype* out, \
@ -240,7 +53,6 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

// clang-format off
instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)
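The diff swaps the host-name prefixes (softmax_/block_softmax_ and softmax_looped_/looped_softmax_; the deletion-before-addition ordering suggests block_* and looped_* are the new names), and the `_precise` variants instantiate the same kernels with AccT = float so half and bfloat16 inputs accumulate the max and normalizer in float. With the new naming, the float16 instantiations register as:

// block_softmax_float16           (AccT = half)
// looped_softmax_float16
// block_softmax_precise_float16   (AccT = float)
// looped_softmax_precise_float16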
674
mlx/backend/metal/kernels/sort.h
Normal file
@ -0,0 +1,674 @@
// Copyright © 2023-2024 Apple Inc.

#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")

using namespace metal;

// Based on GPU merge sort algorithm at
// https://github.com/NVIDIA/cccl/tree/main/cub/cub

///////////////////////////////////////////////////////////////////////////////
// Thread-level sort
///////////////////////////////////////////////////////////////////////////////

template <typename T>
METAL_FUNC void thread_swap(thread T& a, thread T& b) {
  T w = a;
  a = b;
  b = w;
}

template <typename T>
struct LessThan {
  static constexpr constant T init = Limits<T>::max;

  METAL_FUNC bool operator()(T a, T b) {
    return a < b;
  }
};

template <
    typename val_t,
    typename idx_t,
    bool ARG_SORT,
    short N_PER_THREAD,
    typename CompareOp>
struct ThreadSort {
  static METAL_FUNC void sort(
      thread val_t (&vals)[N_PER_THREAD],
      thread idx_t (&idxs)[N_PER_THREAD]) {
    CompareOp op;

    MLX_MTL_LOOP_UNROLL
    for (short i = 0; i < N_PER_THREAD; ++i) {
      MLX_MTL_LOOP_UNROLL
      for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
        if (op(vals[j + 1], vals[j])) {
          thread_swap(vals[j + 1], vals[j]);
          thread_swap(idxs[j + 1], idxs[j]);
        }
      }
    }
  }
};
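`ThreadSort` is an odd-even transposition sort over the N_PER_THREAD values held in registers: pass i compares adjacent pairs starting at index i & 1, and N_PER_THREAD passes guarantee a sorted sequence, which suits full unrolling with no data-dependent control flow. A tiny CPU model of the same network (a sketch, ignoring the index array):

#include <utility>

// Odd-even transposition sort of a small fixed-size array; mirrors
// ThreadSort::sort with LessThan as the comparator.
template <typename T, int N>
void odd_even_sort(T (&vals)[N]) {
  for (int i = 0; i < N; ++i) {
    for (int j = i & 1; j < N - 1; j += 2) {
      if (vals[j + 1] < vals[j]) {
        std::swap(vals[j + 1], vals[j]);
      }
    }
  }
}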
///////////////////////////////////////////////////////////////////////////////
// Threadgroup-level sort
///////////////////////////////////////////////////////////////////////////////

template <
    typename val_t,
    typename idx_t,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp>
struct BlockMergeSort {
  using thread_sort_t =
      ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
  static METAL_FUNC int merge_partition(
      const threadgroup val_t* As,
      const threadgroup val_t* Bs,
      short A_sz,
      short B_sz,
      short sort_md) {
    CompareOp op;

    short A_st = max(0, sort_md - B_sz);
    short A_ed = min(sort_md, A_sz);

    while (A_st < A_ed) {
      short md = A_st + (A_ed - A_st) / 2;
      auto a = As[md];
      auto b = Bs[sort_md - 1 - md];

      if (op(b, a)) {
        A_ed = md;
      } else {
        A_st = md + 1;
      }
    }

    return A_ed;
  }
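merge_partition binary-searches the "merge path": it returns the count `partition` of elements taken from As such that the first sort_md outputs of merge(As, Bs) are exactly As[0:partition] and Bs[0:sort_md - partition]. A worked trace of the search above:

// merge_partition(As = {1, 4, 7}, Bs = {2, 3, 9}, A_sz = 3, B_sz = 3, sort_md = 4)
//   A_st = max(0, 4 - 3) = 1, A_ed = min(4, 3) = 3
//   md = 2: As[2] = 7, Bs[1] = 3, 3 < 7    -> A_ed = 2
//   md = 1: As[1] = 4, Bs[2] = 9, !(9 < 4) -> A_st = 2
//   returns 2: the first 4 merged values are {1, 4} from As plus {2, 3} from Bs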
|
||||
|
||||
static METAL_FUNC void merge_step(
|
||||
const threadgroup val_t* As,
|
||||
const threadgroup val_t* Bs,
|
||||
const threadgroup idx_t* As_idx,
|
||||
const threadgroup idx_t* Bs_idx,
|
||||
short A_sz,
|
||||
short B_sz,
|
||||
thread val_t (&vals)[N_PER_THREAD],
|
||||
thread idx_t (&idxs)[N_PER_THREAD]) {
|
||||
CompareOp op;
|
||||
short a_idx = 0;
|
||||
short b_idx = 0;
|
||||
|
||||
for (int i = 0; i < N_PER_THREAD; ++i) {
|
||||
auto a = As[a_idx];
|
||||
auto b = Bs[b_idx];
|
||||
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
|
||||
|
||||
vals[i] = pred ? b : a;
|
||||
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
|
||||
|
||||
b_idx += short(pred);
|
||||
a_idx += short(!pred);
|
||||
}
|
||||
}
|
||||
|
||||
  static METAL_FUNC void sort(
      threadgroup val_t* tgp_vals [[threadgroup(0)]],
      threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
      int size_sorted_axis,
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Get thread location
    int idx = lid.x * N_PER_THREAD;

    // Load from shared memory
    thread val_t thread_vals[N_PER_THREAD];
    thread idx_t thread_idxs[N_PER_THREAD];
    for (int i = 0; i < N_PER_THREAD; ++i) {
      thread_vals[i] = tgp_vals[idx + i];
      if (ARG_SORT) {
        thread_idxs[i] = tgp_idxs[idx + i];
      }
    }

    // Per thread sort
    if (idx < size_sorted_axis) {
      thread_sort_t::sort(thread_vals, thread_idxs);
    }

    // Do merges using threadgroup memory
    for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;
         merge_threads *= 2) {
      // Update threadgroup memory
      threadgroup_barrier(mem_flags::mem_threadgroup);
      for (int i = 0; i < N_PER_THREAD; ++i) {
        tgp_vals[idx + i] = thread_vals[i];
        if (ARG_SORT) {
          tgp_idxs[idx + i] = thread_idxs[i];
        }
      }
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Find location in merge step
      int merge_group = lid.x / merge_threads;
      int merge_lane = lid.x % merge_threads;

      int sort_sz = N_PER_THREAD * merge_threads;
      int sort_st = N_PER_THREAD * merge_threads * merge_group;

      // As = tgp_vals[A_st:A_ed] is sorted
      // Bs = tgp_vals[B_st:B_ed] is sorted
      int A_st = sort_st;
      int A_ed = sort_st + sort_sz / 2;
      int B_st = sort_st + sort_sz / 2;
      int B_ed = sort_st + sort_sz;

      const threadgroup val_t* As = tgp_vals + A_st;
      const threadgroup val_t* Bs = tgp_vals + B_st;
      int A_sz = A_ed - A_st;
      int B_sz = B_ed - B_st;

      // Find a partition of merge elements
      // Ci = merge(As[partition:], Bs[sort_md - partition:])
      // of size N_PER_THREAD for each merge lane i
      // C = [Ci] is sorted
      int sort_md = N_PER_THREAD * merge_lane;
      int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);

      As += partition;
      Bs += sort_md - partition;

      A_sz -= partition;
      B_sz -= sort_md - partition;

      const threadgroup idx_t* As_idx =
          ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
      const threadgroup idx_t* Bs_idx =
          ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;

      // Merge starting at the partition and store results in thread registers
      merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
    }

    // Write out to shared memory
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int i = 0; i < N_PER_THREAD; ++i) {
      tgp_vals[idx + i] = thread_vals[i];
      if (ARG_SORT) {
        tgp_idxs[idx + i] = thread_idxs[i];
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
// Kernel sort
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    typename U,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp = LessThan<T>>
struct KernelMergeSort {
  using val_t = T;
  using idx_t = uint;
  using block_merge_sort_t = BlockMergeSort<
      val_t,
      idx_t,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD,
      CompareOp>;

  MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;

  static METAL_FUNC void block_sort(
      const device T* inp,
      device U* out,
      const constant int& size_sorted_axis,
      const constant int& stride_sorted_axis,
      const constant int& stride_segment_axis,
      threadgroup val_t* tgp_vals,
      threadgroup idx_t* tgp_idxs,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // tid.y tells us the segment index
    inp += tid.y * stride_segment_axis;
    out += tid.y * stride_segment_axis;

    // Copy into threadgroup memory
    for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
      tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
                                         : val_t(CompareOp::init);
      if (ARG_SORT) {
        tgp_idxs[i] = i;
      }
    }

    // Sort elements within the block
    threadgroup_barrier(mem_flags::mem_threadgroup);

    block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Write output
    for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
      if (ARG_SORT) {
        out[i * stride_sorted_axis] = tgp_idxs[i];
      } else {
        out[i * stride_sorted_axis] = tgp_vals[i];
      }
    }
  }
};

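// Host-visible entry point: threadgroup tid.y sorts one segment of the input
// in a single pass; the index scratch space is only allocated when ARG_SORT
// is set.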
template <
    typename T,
    typename U,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort(
    const device T* inp [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& stride_sorted_axis [[buffer(3)]],
    const constant int& stride_segment_axis [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel =
      KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
  using val_t = typename sort_kernel::val_t;
  using idx_t = typename sort_kernel::idx_t;

  if (ARG_SORT) {
    threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
    threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
    sort_kernel::block_sort(
        inp,
        out,
        size_sorted_axis,
        stride_sorted_axis,
        stride_segment_axis,
        tgp_vals,
        tgp_idxs,
        tid,
        lid);
  } else {
    threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
    sort_kernel::block_sort(
        inp,
        out,
        size_sorted_axis,
        stride_sorted_axis,
        stride_segment_axis,
        tgp_vals,
        nullptr,
        tid,
        lid);
  }
}

constant constexpr const int zero_helper = 0;

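// Non-contiguous variant: the per-segment offset is computed from the
// broadcasted shape and strides via elem_to_loc rather than a fixed segment
// stride (zero_helper is passed in its place).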
template <
    typename T,
    typename U,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc(
    const device T* inp [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& stride_sorted_axis [[buffer(3)]],
    const constant int& nc_dim [[buffer(4)]],
    const device int* nc_shape [[buffer(5)]],
    const device size_t* nc_strides [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel =
      KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
  using val_t = typename sort_kernel::val_t;
  using idx_t = typename sort_kernel::idx_t;

  auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
  inp += block_idx;
  out += block_idx;

  if (ARG_SORT) {
    threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
    threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
    sort_kernel::block_sort(
        inp,
        out,
        size_sorted_axis,
        stride_sorted_axis,
        zero_helper,
        tgp_vals,
        tgp_idxs,
        tid,
        lid);
  } else {
    threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
    sort_kernel::block_sort(
        inp,
        out,
        size_sorted_axis,
        stride_sorted_axis,
        zero_helper,
        tgp_vals,
        nullptr,
        tid,
        lid);
  }
}

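// Multi-block merge sort: rows longer than a single threadgroup tile are
// sorted tile by tile, then merged over multiple rounds; each round doubles
// the length of the sorted runs until the whole row is sorted.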
template <
    typename val_t,
    typename idx_t,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp = LessThan<val_t>>
struct KernelMultiBlockMergeSort {
  using block_merge_sort_t = BlockMergeSort<
      val_t,
      idx_t,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD,
      CompareOp>;

  MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;

  static METAL_FUNC void block_sort(
      const device val_t* inp,
      device val_t* out_vals,
      device idx_t* out_idxs,
      const constant int& size_sorted_axis,
      const constant int& stride_sorted_axis,
      threadgroup val_t* tgp_vals,
      threadgroup idx_t* tgp_idxs,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // tid.y tells us the segment index
    int base_idx = tid.x * N_PER_BLOCK;

    // Copy into threadgroup memory
    for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
      int idx = base_idx + i;
      tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
                                           : val_t(CompareOp::init);
      tgp_idxs[i] = idx;
    }

    // Sort elements within the block
    threadgroup_barrier(mem_flags::mem_threadgroup);

    block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Write output
    for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
      int idx = base_idx + i;
      if (idx < size_sorted_axis) {
        out_vals[idx] = tgp_vals[i];
        out_idxs[idx] = tgp_idxs[i];
      }
    }
  }

  static METAL_FUNC int merge_partition(
      const device val_t* As,
      const device val_t* Bs,
      int A_sz,
      int B_sz,
      int sort_md) {
    CompareOp op;

    int A_st = max(0, sort_md - B_sz);
    int A_ed = min(sort_md, A_sz);

    while (A_st < A_ed) {
      int md = A_st + (A_ed - A_st) / 2;
      auto a = As[md];
      auto b = Bs[sort_md - 1 - md];

      if (op(b, a)) {
        A_ed = md;
      } else {
        A_st = md + 1;
      }
    }

    return A_ed;
  }
};

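// Stage 1 of the multi-block path: sort each tile of N_PER_BLOCK elements and
// write the sorted values and source indices to intermediate buffers.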
template <
    typename val_t,
    typename idx_t,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort(
    const device val_t* inp [[buffer(0)]],
    device val_t* out_vals [[buffer(1)]],
    device idx_t* out_idxs [[buffer(2)]],
    const constant int& size_sorted_axis [[buffer(3)]],
    const constant int& stride_sorted_axis [[buffer(4)]],
    const constant int& nc_dim [[buffer(5)]],
    const device int* nc_shape [[buffer(6)]],
    const device size_t* nc_strides [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      val_t,
      idx_t,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD>;

  auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
  inp += block_idx;
  out_vals += tid.y * size_sorted_axis;
  out_idxs += tid.y * size_sorted_axis;

  threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
  threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];

  sort_kernel::block_sort(
      inp,
      out_vals,
      out_idxs,
      size_sorted_axis,
      stride_sorted_axis,
      tgp_vals,
      tgp_idxs,
      tid,
      lid);
}

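// Stage 2: compute merge-path partitions so that, in the next stage, every
// output tile can merge an independent chunk of the two sorted runs.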
template <
    typename val_t,
    typename idx_t,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_partition(
    device idx_t* block_partitions [[buffer(0)]],
    const device val_t* dev_vals [[buffer(1)]],
    const device idx_t* dev_idxs [[buffer(2)]],
    const constant int& size_sorted_axis [[buffer(3)]],
    const constant int& merge_tiles [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_dims [[threads_per_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      val_t,
      idx_t,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD>;

  block_partitions += tid.y * tgp_dims.x;
  dev_vals += tid.y * size_sorted_axis;
  dev_idxs += tid.y * size_sorted_axis;

  // Find location in merge step
  int merge_group = lid.x / merge_tiles;
  int merge_lane = lid.x % merge_tiles;

  int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
  int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;

  int A_st = min(size_sorted_axis, sort_st);
  int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
  int B_st = A_ed;
  int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);

  int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
  int partition = sort_kernel::merge_partition(
      dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);

  block_partitions[lid.x] = A_st + partition;
}

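// Stage 3: each threadgroup loads its chunk of the two runs, merges it in
// threadgroup memory using the precomputed partitions, and writes the merged
// values and indices back out.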
template <
    typename val_t,
    typename idx_t,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp = LessThan<val_t>>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_merge(
    const device idx_t* block_partitions [[buffer(0)]],
    const device val_t* dev_vals_in [[buffer(1)]],
    const device idx_t* dev_idxs_in [[buffer(2)]],
    device val_t* dev_vals_out [[buffer(3)]],
    device idx_t* dev_idxs_out [[buffer(4)]],
    const constant int& size_sorted_axis [[buffer(5)]],
    const constant int& merge_tiles [[buffer(6)]],
    const constant int& num_tiles [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      val_t,
      idx_t,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD,
      CompareOp>;

  using block_sort_t = typename sort_kernel::block_merge_sort_t;

  block_partitions += tid.y * (num_tiles + 1);
  dev_vals_in += tid.y * size_sorted_axis;
  dev_idxs_in += tid.y * size_sorted_axis;
  dev_vals_out += tid.y * size_sorted_axis;
  dev_idxs_out += tid.y * size_sorted_axis;

  int block_idx = tid.x;
  int merge_group = block_idx / merge_tiles;
  int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
  int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
  int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;

  int A_st = block_partitions[block_idx + 0];
  int A_ed = block_partitions[block_idx + 1];
  int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
  int B_ed = min(
      size_sorted_axis,
      2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);

  if ((block_idx % merge_tiles) == merge_tiles - 1) {
    A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
    B_ed = min(size_sorted_axis, sort_st + sort_sz);
  }

  int A_sz = A_ed - A_st;
  int B_sz = B_ed - B_st;

  // Load from global memory
  thread val_t thread_vals[N_PER_THREAD];
  thread idx_t thread_idxs[N_PER_THREAD];
  for (int i = 0; i < N_PER_THREAD; i++) {
    int idx = BLOCK_THREADS * i + lid.x;
    if (idx < (A_sz + B_sz)) {
      thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
                                    : dev_vals_in[B_st + idx - A_sz];
      thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
                                    : dev_idxs_in[B_st + idx - A_sz];
    } else {
      thread_vals[i] = CompareOp::init;
      thread_idxs[i] = 0;
    }
  }

  // Write to shared memory
  threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
  threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
  threadgroup_barrier(mem_flags::mem_threadgroup);
  for (int i = 0; i < N_PER_THREAD; i++) {
    int idx = BLOCK_THREADS * i + lid.x;
    tgp_vals[idx] = thread_vals[i];
    tgp_idxs[idx] = thread_idxs[i];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Merge
  int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));

  int A_st_local = block_sort_t::merge_partition(
      tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
  int A_ed_local = A_sz;

  int B_st_local = sort_md_local - A_st_local;
  int B_ed_local = B_sz;

  int A_sz_local = A_ed_local - A_st_local;
  int B_sz_local = B_ed_local - B_st_local;

  // Do merge
  block_sort_t::merge_step(
      tgp_vals + A_st_local,
      tgp_vals + A_ed_local + B_st_local,
      tgp_idxs + A_st_local,
      tgp_idxs + A_ed_local + B_st_local,
      A_sz_local,
      B_sz_local,
      thread_vals,
      thread_idxs);

  threadgroup_barrier(mem_flags::mem_threadgroup);
  for (int i = 0; i < N_PER_THREAD; ++i) {
    int idx = lid.x * N_PER_THREAD;
    tgp_vals[idx + i] = thread_vals[i];
    tgp_idxs[idx + i] = thread_idxs[i];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Write output
  int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
  for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
    int idx = base_idx + i;
    if (idx < size_sorted_axis) {
      dev_vals_out[idx] = tgp_vals[i];
      dev_idxs_out[idx] = tgp_idxs[i];
    }
  }
}

@ -1,392 +1,16 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <metal_stdlib>

// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"

#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#include "mlx/backend/metal/kernels/sort.h"

#define instantiate_block_sort( \
    name, itname, itype, otname, otype, arg_sort, bn, tn) \
  template [[host_name(#name "_" #itname "_" #otname "_bn" #bn \
  template [[host_name("c" #name "_" #itname "_" #otname "_bn" #bn \
                       "_tn" #tn)]] [[kernel]] void \
  block_sort<itype, otype, arg_sort, bn, tn>( \
      const device itype* inp [[buffer(0)]], \
@ -396,8 +20,8 @@ template <
      const constant int& stride_segment_axis [[buffer(4)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]); \
  template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
                       "_nc")]] [[kernel]] void \
  template [[host_name("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
                       )]] [[kernel]] void \
  block_sort_nc<itype, otype, arg_sort, bn, tn>( \
      const device itype* inp [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
@ -411,18 +35,16 @@ template <

#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
  instantiate_block_sort( \
      arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
      arg_block_sort, itname, itype, uint32, uint32_t, true, bn, tn)

#define instantiate_block_sort_base(itname, itype, bn, tn) \
  instantiate_block_sort( \
      block_merge_sort, itname, itype, itname, itype, false, bn, tn)
      _block_sort, itname, itype, itname, itype, false, bn, tn)

// clang-format off
#define instantiate_block_sort_tn(itname, itype, bn) \
  instantiate_block_sort_base(itname, itype, bn, 8) \
  instantiate_arg_block_sort_base(itname, itype, bn, 8) // clang-format on
  instantiate_arg_block_sort_base(itname, itype, bn, 8)

// clang-format off
#define instantiate_block_sort_bn(itname, itype) \
  instantiate_block_sort_tn(itname, itype, 128) \
  instantiate_block_sort_tn(itname, itype, 256) \
@ -436,321 +58,18 @@ instantiate_block_sort_bn(int16, int16_t)
instantiate_block_sort_bn(int32, int32_t)
instantiate_block_sort_bn(float16, half)
instantiate_block_sort_bn(float32, float)
instantiate_block_sort_bn(bfloat16, bfloat16_t) // clang-format on
// clang-format off
instantiate_block_sort_bn(bfloat16, bfloat16_t)

#define instantiate_block_sort_long(itname, itype) \
  instantiate_block_sort_tn(itname, itype, 128) \
  instantiate_block_sort_tn(itname, itype, 256)

instantiate_block_sort_long(uint64, uint64_t)
instantiate_block_sort_long(int64, int64_t) // clang-format on

///////////////////////////////////////////////////////////////////////////////
// Multi block merge sort
///////////////////////////////////////////////////////////////////////////////

#define instantiate_multi_block_sort( \
    vtname, vtype, itname, itype, arg_sort, bn, tn) \
  template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn \
  template [[host_name("sort_mbsort_" #vtname "_" #itname "_bn" #bn \
                       "_tn" #tn)]] [[kernel]] void \
  mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
      const device vtype* inp [[buffer(0)]], \
@ -763,7 +82,7 @@ mb_block_merge(
      const device size_t* nc_strides [[buffer(7)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]); \
  template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn \
  template [[host_name("partition_mbsort_" #vtname "_" #itname "_bn" #bn \
                       "_tn" #tn)]] [[kernel]] void \
  mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
      device itype * block_partitions [[buffer(0)]], \
@ -774,7 +93,7 @@ mb_block_merge(
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 tgp_dims [[threads_per_threadgroup]]); \
  template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn \
  template [[host_name("merge_mbsort_" #vtname "_" #itname "_bn" #bn \
                       "_tn" #tn)]] [[kernel]] void \
  mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
      const device itype* block_partitions [[buffer(0)]], \
@ -788,7 +107,6 @@ mb_block_merge(
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);

// clang-format off
#define instantiate_multi_block_sort_base(vtname, vtype) \
  instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)

@ -800,9 +118,8 @@ instantiate_multi_block_sort_base(int16, int16_t)
instantiate_multi_block_sort_base(int32, int32_t)
instantiate_multi_block_sort_base(float16, half)
instantiate_multi_block_sort_base(float32, float)
instantiate_multi_block_sort_base(bfloat16, bfloat16_t) // clang-format on
instantiate_multi_block_sort_base(bfloat16, bfloat16_t)

// clang-format off
#define instantiate_multi_block_sort_long(vtname, vtype) \
  instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)
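// Note: the new host names (sort_mbsort_*, partition_mbsort_*, merge_mbsort_*)
// match the kernel names assembled at dispatch time in
// mlx/backend/metal/sort.cpp below.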

@ -5,6 +5,7 @@
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/complex.h"
#include "mlx/backend/metal/kernels/defines.h"

typedef half float16_t;

@ -5,6 +5,13 @@

namespace mlx::core {

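// In the non-JIT build these getters return kernels precompiled into the
// Metal library; with MLX_METAL_JIT enabled, jit_kernels.cpp provides
// versions that assemble the kernel source on demand instead.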
MTL::ComputePipelineState* get_arange_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_unary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
@ -15,31 +22,84 @@ MTL::ComputePipelineState* get_unary_kernel(
MTL::ComputePipelineState* get_binary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
    const array&,
    const array&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_binary_two_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
    const array&,
    const array&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_ternary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out) {
    const array&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_copy_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
    const array&,
    const array&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_softmax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    bool,
    const array&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_scan_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    bool,
    bool,
    const array&,
    const array&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_sort_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    const array&,
    int,
    int) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_mb_sort_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    const array&,
    int,
    int) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_reduce_init_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_reduce_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    const array&) {
  return d.get_kernel(kernel_name);
}

@ -6,6 +6,7 @@

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@ -27,7 +28,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
  }
  auto& s = stream();
  auto& d = metal::device(s.device);
  auto kernel = d.get_kernel("arange" + type_to_name(out));
  auto kernel = get_arange_kernel(d, "arange" + type_to_name(out), out);
  size_t nthreads = out.size();
  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
  MTL::Size group_dims = MTL::Size(
@ -6,6 +6,7 @@

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/utils.h"
@ -40,9 +41,12 @@ void all_reduce_dispatch(
    const Stream& s) {
  Dtype out_dtype = out.dtype();
  bool is_out_64b_int = is_64b_int(out_dtype);
  auto kernel = (is_out_64b_int)
      ? d.get_kernel("all_reduce_no_atomics_" + op_name + type_to_name(in))
      : d.get_kernel("all_reduce_" + op_name + type_to_name(in));
  std::string kernel_name = "all";
  if (is_out_64b_int) {
    kernel_name += "NoAtomics";
  }
  kernel_name += "_reduce_" + op_name + type_to_name(in);
  auto kernel = get_reduce_kernel(d, kernel_name, in, out);

  compute_encoder->setComputePipelineState(kernel);

@ -158,18 +162,20 @@ void row_reduce_general_dispatch(
  bool is_med = non_row_reductions * reduction_size <= 256;
  is_out_64b_int &= !is_small && !is_med;

  std::string small_desc = "_";
  if (is_small) {
    small_desc = "_small_";
  std::string small_desc;
  if (is_out_64b_int) {
    small_desc = "NoAtomics";
  } else if (is_small) {
    small_desc = "Small";
  } else if (is_med) {
    small_desc = "_med_";
    small_desc = "Med";
  } else {
    small_desc = "";
  }
  kname << "rowGeneral" << small_desc << "_reduce_" << op_name
        << type_to_name(in);

  small_desc = is_out_64b_int ? "_no_atomics_" : small_desc;

  kname << "row_reduce_general" << small_desc << op_name << type_to_name(in);

  auto kernel = d.get_kernel(kname.str());
  auto kernel = get_reduce_kernel(d, kname.str(), in, out);
  compute_encoder->setComputePipelineState(kernel);

  // Get dispatch grid dims
@ -335,8 +341,8 @@ void strided_reduce_general_dispatch(
  // Specialize for small dims
  if (reduction_size * non_col_reductions < 16) {
    // Select kernel
    auto kernel =
        d.get_kernel("col_reduce_small_" + op_name + type_to_name(in));
    auto kernel = get_reduce_kernel(
        d, "colSmall_reduce_" + op_name + type_to_name(in), in, out);
    compute_encoder->setComputePipelineState(kernel);

    // Select block dims
@ -373,10 +379,12 @@ void strided_reduce_general_dispatch(

  // Select kernel
  bool is_out_64b_int = is_64b_int(out_dtype);
  auto kernel = (is_out_64b_int)
      ? d.get_kernel(
            "col_reduce_general_no_atomics_" + op_name + type_to_name(in))
      : d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
  std::string kernel_name = "colGeneral";
  if (is_out_64b_int) {
    kernel_name += "NoAtomics";
  }
  kernel_name += "_reduce_" + op_name + type_to_name(in);
  auto kernel = get_reduce_kernel(d, kernel_name, in, out);

  compute_encoder->setComputePipelineState(kernel);

@ -490,9 +498,11 @@ void strided_reduce_general_dispatch(
  }
  ndim = new_shape.size();

  auto row_reduce_kernel = d.get_kernel(
      "row_reduce_general_no_atomics_" + op_name +
      type_to_name(intermediate));
  std::string kernel_name =
      "rowGeneralNoAtomics_reduce_" + op_name + type_to_name(intermediate);
  auto row_reduce_kernel =
      get_reduce_kernel(d, kernel_name, intermediate, out);

  compute_encoder->setComputePipelineState(row_reduce_kernel);
  compute_encoder.set_input_array(intermediate, 0);
  compute_encoder.set_output_array(out, 1);
@ -575,7 +585,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& d = metal::device(s.device);
  auto& compute_encoder = d.get_command_encoder(s.index);
  {
    auto kernel = d.get_kernel("i" + op_name + type_to_name(out));
    auto kernel = get_reduce_init_kernel(
        d, "i_reduce_" + op_name + type_to_name(out), out);
    size_t nthreads = out.size();
    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@ -5,6 +5,7 @@

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

@ -28,9 +29,11 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
    in = arr_copy;
  }

  bool contiguous = in.strides()[axis_] == 1;

  std::ostringstream kname;
  if (in.strides()[axis_] == 1) {
    kname << "contiguous_scan_";
  kname << (contiguous ? "contig_" : "strided_");
  kname << "scan_";
  if (reverse_) {
    kname << "reverse_";
  }
@ -50,8 +53,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
      break;
  }
  kname << type_to_name(in) << "_" << type_to_name(out);
  auto kernel = get_scan_kernel(d, kname.str(), reverse_, inclusive_, in, out);

  auto kernel = d.get_kernel(kname.str());
  if (contiguous) {
    auto& compute_encoder = d.get_command_encoder(s.index);
    compute_encoder->setComputePipelineState(kernel);
    compute_encoder.set_input_array(in, 0);
@ -79,28 +83,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  } else {
    kname << "strided_scan_";
    if (reverse_) {
      kname << "reverse_";
    }
    kname << ((inclusive_) ? "inclusive_" : "exclusive_");
    switch (reduce_type_) {
      case Scan::Sum:
        kname << "sum_";
        break;
      case Scan::Prod:
        kname << "prod_";
        break;
      case Scan::Max:
        kname << "max_";
        break;
      case Scan::Min:
        kname << "min_";
        break;
    }
    kname << type_to_name(in) << "_" << type_to_name(out);

    auto kernel = d.get_kernel(kname.str());
    auto& compute_encoder = d.get_command_encoder(s.index);
    compute_encoder->setComputePipelineState(kernel);
    compute_encoder.set_input_array(in, 0);
@ -1,15 +1,18 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
|
||||
|
||||
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
@ -52,18 +55,17 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
const int simd_size = 32;
const int n_reads = SOFTMAX_N_READS;
const int looped_limit = SOFTMAX_LOOPED_LIMIT;
std::string op_name = "softmax_";
if (axis_size > looped_limit) {
op_name += "looped_";
}

std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
kernel_name += "softmax_";
if (in.dtype() != float32 && precise_) {
op_name += "precise_";
kernel_name += "precise_";
}
op_name += type_to_name(out);
kernel_name += type_to_name(out);

auto kernel = get_softmax_kernel(d, kernel_name, precise_, out);
auto& compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name);

MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
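Note on the softmax change above: the kernel name is now chosen up front ("block_" for rows up to SOFTMAX_LOOPED_LIMIT, "looped_" beyond it, with an optional "precise_" tag) and handed to get_softmax_kernel for the JIT build. A small self-contained sketch of that selection; the helper name is hypothetical:

#include <string>

// Hypothetical helper mirroring the selection in Softmax::eval_gpu above.
// Rows longer than the looped limit (4096 in this commit) use the "looped_"
// kernel; "precise_" is added for non-float32 inputs when precise_ is set.
std::string softmax_kernel_name(int axis_size, bool precise, bool input_is_float32,
                                const std::string& out_type) {
  std::string name = (axis_size > 4096) ? "looped_" : "block_";
  name += "softmax_";
  if (!input_is_float32 && precise) {
    name += "precise_";
  }
  return name + out_type;  // e.g. "block_softmax_precise_float16"
}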
@ -4,6 +4,7 @@

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

@ -11,7 +12,6 @@ namespace mlx::core {

namespace {

template <bool ARGSORT>
void single_block_sort(
const Stream& s,
metal::Device& d,
@ -19,7 +19,8 @@ void single_block_sort(
array& out,
int axis,
int bn,
int tn) {
int tn,
bool argsort) {
// Prepare shapes
int n_rows = in.size() / in.shape(axis);
@ -46,19 +47,17 @@ void single_block_sort(

// Prepare kernel name
std::ostringstream kname;
if (ARGSORT) {
kname << "arg_";
kname << (contiguous_write ? "c" : "nc");
if (argsort) {
kname << "arg";
}
kname << "block_merge_sort_" << type_to_name(in) << "_" << type_to_name(out)
<< "_bn" << bn << "_tn" << tn;

if (!contiguous_write) {
kname << "_nc";
}
kname << "_block_sort_" << type_to_name(in) << "_" << type_to_name(out)
<< "_bn" << bn << "_tn" << tn;
auto kernel = get_sort_kernel(d, kname.str(), in, out, bn, tn);

// Prepare command encoder
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);

// Set inputs
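Note on the renaming above: single-block sort kernels now encode the contiguous-write flag as a "c"/"nc" prefix and argsort as an "arg" tag, in that order, before "_block_sort_". A sketch with a hypothetical helper; the bn/tn values shown are illustrative:

#include <sstream>
#include <string>

// Hypothetical helper: reproduces the name format built above, e.g.
// "carg_block_sort_float32_uint32_bn512_tn8".
std::string block_sort_kernel_name(bool contiguous_write, bool argsort,
                                   const std::string& in_type,
                                   const std::string& out_type,
                                   int bn, int tn) {
  std::ostringstream kname;
  kname << (contiguous_write ? "c" : "nc");
  if (argsort) {
    kname << "arg";
  }
  kname << "_block_sort_" << in_type << "_" << out_type
        << "_bn" << bn << "_tn" << tn;
  return kname.str();
}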
@ -81,7 +80,6 @@ void single_block_sort(
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}

template <bool ARGSORT>
void multi_block_sort(
const Stream& s,
metal::Device& d,
@ -90,7 +88,8 @@ void multi_block_sort(
int axis,
int bn,
int tn,
int n_blocks) {
int n_blocks,
bool argsort) {
// Prepare shapes
int n_rows = in.size() / in.shape(axis);
@ -136,10 +135,10 @@ void multi_block_sort(
// Do blockwise sort
{
std::ostringstream kname;
kname << "mb_block_sort_" << type_to_name(dev_vals_0) << "_"
kname << "sort_mbsort_" << type_to_name(dev_vals_0) << "_"
<< type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;

auto kernel = d.get_kernel(kname.str());
auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel);

compute_encoder.set_input_array(in, 0);
@ -175,10 +174,11 @@ void multi_block_sort(
// Do partition
{
std::ostringstream kname;
kname << "mb_block_partition_" << type_to_name(dev_vals_in) << "_"
kname << "partition_mbsort_" << type_to_name(dev_vals_in) << "_"
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;

auto kernel = d.get_kernel(kname.str());
auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel);

compute_encoder.set_output_array(block_partitions, 0);
@ -196,10 +196,11 @@ void multi_block_sort(
// Do merge
{
std::ostringstream kname;
kname << "mb_block_merge_" << type_to_name(dev_vals_in) << "_"
kname << "merge_mbsort_" << type_to_name(dev_vals_in) << "_"
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;

auto kernel = d.get_kernel(kname.str());
auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel);

compute_encoder.set_input_array(block_partitions, 0);
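The three multi-block stages above now share one "<stage>_mbsort_" naming pattern and all resolve through get_mb_sort_kernel. A compact sketch, again with a hypothetical helper and illustrative bn/tn values:

#include <sstream>
#include <string>

// Hypothetical helper: stage is "sort", "partition", or "merge", matching the
// renamed kernels above, e.g. "partition_mbsort_float32_uint32_bn512_tn8".
std::string mbsort_kernel_name(const std::string& stage,
                               const std::string& val_type,
                               const std::string& idx_type,
                               int bn, int tn) {
  std::ostringstream kname;
  kname << stage << "_mbsort_" << val_type << "_" << idx_type
        << "_bn" << bn << "_tn" << tn;
  return kname.str();
}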
@ -219,7 +220,7 @@ void multi_block_sort(
}

// Copy outputs with appropriate strides
array strided_out_arr = ARGSORT ? dev_idxs_out : dev_vals_out;
array strided_out_arr = argsort ? dev_idxs_out : dev_vals_out;

if (axis == strided_out_arr.ndim() - 1) {
copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s);
@ -252,13 +253,13 @@ void multi_block_sort(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}

template <bool ARGSORT>
void gpu_merge_sort(
const Stream& s,
metal::Device& d,
const array& in,
array& out,
int axis_) {
int axis_,
bool argsort) {
// Get size info
int axis = axis_ < 0 ? axis_ + in.ndim() : axis_;
int size_sorted_axis = in.shape(axis);
@ -284,9 +285,9 @@ void gpu_merge_sort(
int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block;

if (n_blocks > 1) {
return multi_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn, n_blocks);
return multi_block_sort(s, d, in, out, axis, bn, tn, n_blocks, argsort);
} else {
return single_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn);
return single_block_sort(s, d, in, out, axis, bn, tn, argsort);
}
}
@ -301,7 +302,7 @@ void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
auto& in = inputs[0];

gpu_merge_sort<true>(s, d, in, out, axis_);
gpu_merge_sort(s, d, in, out, axis_, true);
}

void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -313,7 +314,7 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
auto& in = inputs[0];

gpu_merge_sort<false>(s, d, in, out, axis_);
gpu_merge_sort(s, d, in, out, axis_, false);
}

void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -326,7 +327,7 @@ void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
auto& in = inputs[0];

gpu_merge_sort<true>(s, d, in, out, axis_);
gpu_merge_sort(s, d, in, out, axis_, true);
}

void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -339,7 +340,7 @@ void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
auto& in = inputs[0];

gpu_merge_sort<false>(s, d, in, out, axis_);
gpu_merge_sort(s, d, in, out, axis_, false);
}

} // namespace mlx::core
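A pattern worth noting across the sort changes above: the compile-time template <bool ARGSORT> parameter becomes a runtime bool argsort argument, since with the JIT the variant lives in the kernel-name string rather than in a C++ instantiation. A self-contained toy illustrating the before/after shape of the refactor; all names here are illustrative, not from the diff:

#include <iostream>
#include <string>

// Before: one host-side instantiation per variant.
template <bool ARGSORT>
void dispatch_old(const std::string& base) {
  std::cout << (ARGSORT ? "arg" : "") << base << "\n";
}

// After: a single function; the variant is folded into the kernel name,
// so only the Metal kernel that gets JIT-compiled differs.
void dispatch_new(const std::string& base, bool argsort) {
  std::cout << (argsort ? "arg" : "") << base << "\n";
}

int main() {
  dispatch_old<true>("_block_sort_float32_uint32_bn512_tn8");
  dispatch_new("_block_sort_float32_uint32_bn512_tn8", true);
  return 0;
}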
@ -1565,8 +1565,9 @@ class Reduce : public UnaryPrimitive {
switch (reduce_type_) {
case And:
os << "And";
break;
case Or:
os << "And";
os << "Or";
break;
case Sum:
os << "Sum";
@ -1581,7 +1582,6 @@ class Reduce : public UnaryPrimitive {
os << "Max";
break;
}
os << " Reduce";
}
bool is_equivalent(const Primitive& other) const override;

@ -1647,7 +1647,6 @@ class Scan : public UnaryPrimitive {
os << "Max";
break;
}
os << " Reduce";
}
bool is_equivalent(const Primitive& other) const override;