More jitting (#1132)

* docs + circle min size build

* jit scan, arange, softmax

* add sort

* jit reductions

* remove print

* fix deps

* clean includes / nits
This commit is contained in:
Awni Hannun 2024-05-23 16:23:44 -07:00 committed by GitHub
parent 9401507336
commit 0189ab6ab6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
41 changed files with 2377 additions and 1846 deletions

View File

@ -200,7 +200,7 @@ GGUF, you can do:
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
THE `MLX_METAL_JIT` flag minimizes the size of the MLX Metal library which

View File

@ -35,6 +35,7 @@ make_jit_source(
utils
kernels/bf16.h
kernels/complex.h
kernels/defines.h
)
make_jit_source(
unary_ops
@ -44,7 +45,7 @@ make_jit_source(
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(
reduction
reduce_utils
kernels/atomic.h
kernels/reduction/ops.h
)
@ -57,11 +58,21 @@ if (MLX_METAL_JIT)
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp
)
make_jit_source(arange)
make_jit_source(copy)
make_jit_source(unary)
make_jit_source(binary)
make_jit_source(binary_two)
make_jit_source(ternary)
make_jit_source(softmax)
make_jit_source(scan)
make_jit_source(sort)
make_jit_source(
reduce
kernels/reduction/reduce_all.h
kernels/reduction/reduce_col.h
kernels/reduction/reduce_row.h
)
else()
target_sources(
mlx

View File

@ -229,7 +229,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduction() << metal::scatter();
kernel_source << metal::utils() << metal::reduce_utils()
<< metal::scatter();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str =

View File

@ -0,0 +1,9 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view arange_kernels = R"(
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
constant const {1}& start,
constant const {1}& step,
device {1}* out,
uint index [[thread_position_in_grid]]);
)";

View File

@ -1,4 +1,4 @@
// Copyright © 2023-24 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
@ -8,14 +8,19 @@ const char* utils();
const char* binary_ops();
const char* unary_ops();
const char* ternary_ops();
const char* reduction();
const char* reduce_utils();
const char* gather();
const char* scatter();
const char* arange();
const char* unary();
const char* binary();
const char* binary_two();
const char* copy();
const char* ternary();
const char* scan();
const char* softmax();
const char* sort();
const char* reduce();
} // namespace mlx::core::metal

View File

@ -0,0 +1,168 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view reduce_init_kernels = R"(
[[kernel]] void {0}(
device {1}* out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {{
out[tid] = {2}<{1}>::init;
}}
)";
constexpr std::string_view reduce_kernels = R"(
template [[host_name("all_{0}")]] [[kernel]] void
all_reduce<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("colGeneral_{0}")]] [[kernel]] void
col_reduce_general<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup {2}* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
const constant size_t& non_col_reductions [[buffer(8)]],
const constant int* non_col_shapes [[buffer(9)]],
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralMed_{0}")]] [[kernel]] void
row_reduce_general_med<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("rowGeneral_{0}")]] [[kernel]] void
row_reduce_general<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";
constexpr std::string_view reduce_non_atomic_kernels = R"(
template [[host_name("allNoAtomics_{0}")]] [[kernel]] void
all_reduce_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]);
template [[host_name("colGeneralNoAtomics_{0}")]] [[kernel]] void
col_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup {2}* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
const constant size_t& non_col_reductions [[buffer(8)]],
const constant int* non_col_shapes [[buffer(9)]],
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralNoAtomics_{0}")]] [[kernel]] void
row_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

View File

@ -0,0 +1,26 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view scan_kernels = R"(
template [[host_name("contig_{0}")]] [[kernel]] void
contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("strided_{0}")]] [[kernel]] void
strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[thread_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]);
)";

View File

@ -0,0 +1,23 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view softmax_kernels = R"(
template [[host_name("block_{0}")]] [[kernel]] void
softmax_single_row<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[thread_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("looped_{0}")]] [[kernel]] void
softmax_looped<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

View File

@ -0,0 +1,81 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view block_sort_kernels = R"(
template [[host_name("carg_{0}")]] [[kernel]] void
block_sort<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("ncarg_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("c_{0}")]] [[kernel]] void
block_sort<{1}, {2}, false, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("nc_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, false, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view multiblock_sort_kernels = R"(
template [[host_name("sort_{0}")]] [[kernel]] void
mb_block_sort<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {1}* out_vals [[buffer(1)]],
device {2}* out_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("partition_{0}")]] [[kernel]] void
mb_block_partition<{1}, {2}, true, {3}, {4}>(
device {2}* block_partitions [[buffer(0)]],
const device {1}* dev_vals [[buffer(1)]],
const device {2}* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]);
template [[host_name("merge_{0}")]] [[kernel]] void
mb_block_merge<{1}, {2}, true, {3}, {4}>(
const device {2}* block_partitions [[buffer(0)]],
const device {1}* dev_vals_in [[buffer(1)]],
const device {2}* dev_idxs_in [[buffer(2)]],
device {1}* dev_vals_out [[buffer(3)]],
device {2}* dev_idxs_out [[buffer(4)]],
const constant int& size_sorted_axis [[buffer(5)]],
const constant int& merge_tiles [[buffer(6)]],
const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";

View File

@ -3,10 +3,15 @@
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/binary.h"
#include "mlx/backend/metal/jit/binary_two.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/sort.h"
#include "mlx/backend/metal/jit/ternary.h"
#include "mlx/backend/metal/jit/unary.h"
#include "mlx/backend/metal/kernels.h"
@ -20,6 +25,22 @@ std::string op_name(const array& arr) {
return op_t.str();
}
MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source
<< metal::utils() << metal::arange()
<< fmt::format(arange_kernels, lib_name, get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
@ -121,4 +142,138 @@ MTL::ComputePipelineState* get_copy_kernel(
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_softmax_kernel(
metal::Device& d,
const std::string& kernel_name,
bool precise,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::softmax()
<< fmt::format(
softmax_kernels,
lib_name,
get_type_string(out.dtype()),
get_type_string(precise ? float32 : out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,
bool reverse,
bool inclusive,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::scan()
<< fmt::format(
scan_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out),
inclusive,
reverse);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
int bn,
int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort()
<< fmt::format(
block_sort_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
bn,
tn);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_mb_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& idx,
int bn,
int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort()
<< fmt::format(
multiblock_sort_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(idx.dtype()),
bn,
tn);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
auto lib = d.get_library(kernel_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
<< fmt::format(
reduce_init_kernels,
kernel_name,
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(kernel_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
<< fmt::format(
non_atomic ? reduce_non_atomic_kernels
: reduce_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
} // namespace mlx::core

View File

@ -5,6 +5,11 @@
namespace mlx::core {
MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
@ -33,4 +38,45 @@ MTL::ComputePipelineState* get_copy_kernel(
const array& in,
const array& out);
MTL::ComputePipelineState* get_softmax_kernel(
metal::Device& d,
const std::string& kernel_name,
bool precise,
const array& out);
MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,
bool reverse,
bool inclusive,
const array& in,
const array& out);
MTL::ComputePipelineState* get_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
int bn,
int tn);
MTL::ComputePipelineState* get_mb_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& idx,
int bn,
int tn);
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out);
} // namespace mlx::core

View File

@ -8,9 +8,9 @@ set(
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
)
set(
KERNELS
"arange"
"arg_reduce"
"conv"
"fft"
@ -20,31 +20,42 @@ set(
"rms_norm"
"layer_norm"
"rope"
"scan"
"scaled_dot_product_attention"
"softmax"
"sort"
)
if (NOT MLX_METAL_JIT)
set(
KERNELS
${KERNELS}
"arange"
"binary"
"binary_two"
"unary"
"ternary"
"copy"
"softmax"
"sort"
"scan"
"reduce"
)
set(
HEADERS
${HEADERS}
arange.h
unary_ops.h
unary.h
binary_ops.h
binary.h
ternary.h
copy.h
softmax.h
sort.h
scan.h
reduction/ops.h
reduction/reduce_init.h
reduction/reduce_all.h
reduction/reduce_col.h
reduction/reduce_row.h
)
endif()
@ -87,15 +98,6 @@ foreach(KERNEL ${STEEL_KERNELS})
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
endforeach()
file(GLOB_RECURSE REDUCE_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.metal)
file(GLOB_RECURSE REDUCE_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.h)
foreach(KERNEL ${REDUCE_KERNELS})
cmake_path(GET KERNEL STEM TARGET)
build_kernel_base(${TARGET} ${KERNEL} "${REDUCE_HEADERS}")
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
endforeach()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib

View File

@ -0,0 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
template <typename T>
[[kernel]] void arange(
constant const T& start,
constant const T& step,
device T* out,
uint index [[thread_position_in_grid]]) {
out[index] = start + index * step;
}

View File

@ -1,15 +1,8 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
template <typename T>
[[kernel]] void arange(
constant const T& start,
constant const T& step,
device T* out,
uint index [[thread_position_in_grid]]) {
out[index] = start + index * step;
}
#include "mlx/backend/metal/kernels/arange.h"
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
@ -18,7 +11,6 @@ template <typename T>
device type* out, \
uint index [[thread_position_in_grid]]);
// clang-format off
instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)
instantiate_arange(uint32, uint32_t)

View File

@ -2,11 +2,8 @@
#pragma once
#ifndef MLX_METAL_JIT
#include <metal_atomic>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#endif
using namespace metal;

View File

@ -2,7 +2,7 @@
#pragma once
#ifdef __METAL__
#if defined __METAL__ || defined MLX_METAL_JIT
#define MTL_CONST constant
#else
#define MTL_CONST
@ -11,6 +11,5 @@
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;

View File

@ -0,0 +1,4 @@
#pragma once
#include "mlx/backend/metal/kernels/reduction/reduce_all.h"
#include "mlx/backend/metal/kernels/reduction/reduce_col.h"
#include "mlx/backend/metal/kernels/reduction/reduce_row.h"

View File

@ -0,0 +1,293 @@
// Copyright © 2024 Apple Inc.
#include <metal_atomic>
#include <metal_simdgroup>
// clang-format off
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_init.h"
#include "mlx/backend/metal/kernels/reduce.h"
#define instantiate_reduce_helper_floats(inst_f, name, op) \
inst_f(name, float16, half, op) \
inst_f(name, float32, float, op) \
inst_f(name, bfloat16, bfloat16_t, op)
#define instantiate_reduce_helper_uints(inst_f, name, op) \
inst_f(name, uint8, uint8_t, op) \
inst_f(name, uint16, uint16_t, op) \
inst_f(name, uint32, uint32_t, op)
#define instantiate_reduce_helper_ints(inst_f, name, op) \
inst_f(name, int8, int8_t, op) \
inst_f(name, int16, int16_t, op) \
inst_f(name, int32, int32_t, op)
#define instantiate_reduce_helper_64b(inst_f, name, op) \
inst_f(name, int64, int64_t, op) \
inst_f(name, uint64, uint64_t, op)
#define instantiate_reduce_helper_types(inst_f, name, op) \
instantiate_reduce_helper_floats(inst_f, name, op) \
instantiate_reduce_helper_uints(inst_f, name, op) \
instantiate_reduce_helper_ints(inst_f, name, op)
#define instantiate_reduce_ops(inst_f, type_f) \
type_f(inst_f, sum, Sum) \
type_f(inst_f, prod, Prod) \
type_f(inst_f, min_, Min) \
type_f(inst_f, max_, Max)
// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
inst_f, name, tname, itype, otype, op) \
inst_f(name##tname, itype, otype, op)
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, bool_, bool, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint8, uint8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint16, uint16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint32, uint32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint64, uint64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int8, int8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int16, int16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int32, int32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int64, int64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, float16, half, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
float32, \
float, \
otype, \
op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
bfloat16, \
bfloat16_t, \
otype, \
op)
#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i_reduce_" #name)]] [[kernel]] void \
init_reduce<otype, op>( \
device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);
#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name##tname, type, op<type>)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
instantiate_init_reduce(andbool_, bool, And<bool>)
instantiate_init_reduce(orbool_, bool, Or<bool>)
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] [[kernel]] void \
all_reduce<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("allNoAtomics_reduce_" #name)]] [[kernel]] void \
all_reduce_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name##tname, type, type, op<type>)
#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("colGeneral_reduce_" #name)]] [[kernel]] void \
col_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template \
[[host_name("colGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
col_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
#define instantiate_col_reduce_small(name, itype, otype, op) \
template [[host_name("colSmall_reduce_" #name)]] [[kernel]] void \
col_reduce_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
const constant size_t& non_col_reductions [[buffer(8)]], \
const constant int* non_col_shapes [[buffer(9)]], \
const constant size_t* non_col_strides [[buffer(10)]], \
const constant int& non_col_ndim [[buffer(11)]], \
uint tid [[thread_position_in_grid]]);
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general(name ##tname, type, type, op<type>)
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or<bool>)
#define instantiate_row_reduce_small(name, itype, otype, op) \
template [[host_name("rowGeneralSmall_reduce_" #name)]] [[kernel]] void \
row_reduce_general_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint lid [[thread_position_in_grid]]); \
template [[host_name("rowGeneralMed_reduce_" #name)]] [[kernel]] void \
row_reduce_general_med<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template \
[[host_name("rowGeneral_reduce_" #name)]] [[kernel]] void \
row_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template \
[[host_name("rowGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
row_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
instantiate_row_reduce_general(name##tname, type, type, op<type>)
#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
// clang-format on

View File

@ -1,32 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Reduce init
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename Op>
[[kernel]] void init_reduce(
device T* out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {
out[tid] = Op::init;
}
#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i" #name)]] [[kernel]] void init_reduce<otype, op>( \
device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);
#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name##tname, type, op<type>)
// clang-format off
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
instantiate_init_reduce(andbool_, bool, And)
instantiate_init_reduce(orbool_, bool, Or) // clang-format on

View File

@ -5,11 +5,7 @@
#include <metal_atomic>
#include <metal_simdgroup>
#ifndef MLX_METAL_JIT
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#endif
static constant constexpr const uint8_t simd_size = 32;
union bool4_or_uint {
bool4 b;
@ -23,6 +19,7 @@ struct None {
}
};
template <typename U = bool>
struct And {
bool simd_reduce(bool val) {
return simd_all(val);
@ -60,6 +57,7 @@ struct And {
}
};
template <typename U = bool>
struct Or {
bool simd_reduce(bool val) {
return simd_any(val);

View File

@ -1,11 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// All reduce helper
///////////////////////////////////////////////////////////////////////////////
@ -139,50 +133,3 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
out[thread_group_id] = total_val;
}
}
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] [[kernel]] void \
all_reduce<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("all_reduce_no_atomics_" #name)]] [[kernel]] void \
all_reduce_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name##tname, type, type, op<type>)
#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)
// clang-format off
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on

View File

@ -1,11 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Small column reduce kernel
///////////////////////////////////////////////////////////////////////////////
@ -52,23 +46,6 @@ template <typename T, typename U, typename Op>
out[out_idx] = total_val;
}
#define instantiate_col_reduce_small(name, itype, otype, op) \
template [[host_name("col_reduce_small_" #name)]] [[kernel]] void \
col_reduce_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
const constant size_t& non_col_reductions [[buffer(8)]], \
const constant int* non_col_shapes [[buffer(9)]], \
const constant size_t* non_col_strides [[buffer(10)]], \
const constant int& non_col_ndim [[buffer(11)]], \
uint tid [[thread_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Column reduce helper
///////////////////////////////////////////////////////////////////////////////
@ -186,64 +163,3 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
}
}
}
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("col_reduce_general_" #name)]] [[kernel]] void \
col_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template \
[[host_name("col_reduce_general_no_atomics_" #name)]] [[kernel]] void \
col_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
// clang-format off
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general(name ##tname, type, type, op<type>) // clang-format on
// clang-format off
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>) // clang-format on
// clang-format off
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) // clang-format on

View File

@ -0,0 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
template <typename T, typename Op>
[[kernel]] void init_reduce(
device T* out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {
out[tid] = Op::init;
}

View File

@ -1,74 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
// clang-format off
#define instantiate_reduce_helper_floats(inst_f, name, op) \
inst_f(name, float16, half, op) inst_f(name, float32, float, op) \
inst_f(name, bfloat16, bfloat16_t, op)
#define instantiate_reduce_helper_uints(inst_f, name, op) \
inst_f(name, uint8, uint8_t, op) inst_f(name, uint16, uint16_t, op) \
inst_f(name, uint32, uint32_t, op)
#define instantiate_reduce_helper_ints(inst_f, name, op) \
inst_f(name, int8, int8_t, op) inst_f(name, int16, int16_t, op) \
inst_f(name, int32, int32_t, op)
#define instantiate_reduce_helper_64b(inst_f, name, op) \
inst_f(name, int64, int64_t, op) inst_f(name, uint64, uint64_t, op)
#define instantiate_reduce_helper_types(inst_f, name, op) \
instantiate_reduce_helper_floats(inst_f, name, op) \
instantiate_reduce_helper_uints(inst_f, name, op) \
instantiate_reduce_helper_ints(inst_f, name, op)
#define instantiate_reduce_ops(inst_f, type_f) \
type_f(inst_f, sum, Sum) type_f(inst_f, prod, Prod) \
type_f(inst_f, min_, Min) type_f(inst_f, max_, Max)
// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
inst_f, name, tname, itype, otype, op) \
inst_f(name##tname, itype, otype, op)
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, bool_, bool, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint8, uint8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint16, uint16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint32, uint32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int8, int8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int16, int16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int32, int32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int64, int64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, float16, half, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
float32, \
float, \
otype, \
op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
bfloat16, \
bfloat16_t, \
otype, \
op)
// clang-format on

View File

@ -1,11 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Small row reductions
///////////////////////////////////////////////////////////////////////////////
@ -123,33 +117,6 @@ template <typename T, typename U, typename Op>
}
}
#define instantiate_row_reduce_small(name, itype, otype, op) \
template [[host_name("row_reduce_general_small_" #name)]] [[kernel]] void \
row_reduce_general_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint lid [[thread_position_in_grid]]); \
template [[host_name("row_reduce_general_med_" #name)]] [[kernel]] void \
row_reduce_general_med<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// Large row reductions
///////////////////////////////////////////////////////////////////////////////
@ -318,61 +285,3 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
}
}
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) template \
[[host_name("row_reduce_general_" #name)]] [[kernel]] void \
row_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) template \
[[host_name("row_reduce_general_no_atomics_" #name)]] [[kernel]] void \
row_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
instantiate_row_reduce_general(name##tname, type, type, op<type>)
#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)
// clang-format off
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or)
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on

View File

@ -1,14 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
static constant constexpr const uint8_t simd_size = 32;

View File

@ -0,0 +1,440 @@
// Copyright © 2023-2024 Apple Inc.
template <typename U>
struct CumSum {
static constexpr constant U init = static_cast<U>(0);
template <typename T>
U operator()(U a, T b) {
return a + b;
}
U simd_scan(U x) {
return simd_prefix_inclusive_sum(x);
}
U simd_exclusive_scan(U x) {
return simd_prefix_exclusive_sum(x);
}
};
template <typename U>
struct CumProd {
static constexpr constant U init = static_cast<U>(1.0f);
template <typename T>
U operator()(U a, T b) {
return a * b;
}
U simd_scan(U x) {
return simd_prefix_inclusive_product(x);
}
U simd_exclusive_scan(U x) {
return simd_prefix_exclusive_product(x);
}
};
template <>
struct CumProd<bool> {
static constexpr constant bool init = true;
template <typename T>
bool operator()(bool a, T b) {
return a & static_cast<bool>(b);
}
bool simd_scan(bool x) {
for (int i = 1; i <= 16; i *= 2) {
bool other = simd_shuffle_up(x, i);
x &= other;
}
return x;
}
bool simd_exclusive_scan(bool x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename U>
struct CumMax {
static constexpr constant U init = Limits<U>::min;
template <typename T>
U operator()(U a, T b) {
return (a >= b) ? a : b;
}
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_up(x, i);
x = (x >= other) ? x : other;
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename U>
struct CumMin {
static constexpr constant U init = Limits<U>::max;
template <typename T>
U operator()(U a, T b) {
return (a <= b) ? a : b;
}
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_up(x, i);
x = (x <= other) ? x : other;
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T* input) {
if (reverse) {
for (int i = 0; i < N_READS; i++) {
values[N_READS - i - 1] = input[i];
}
} else {
for (int i = 0; i < N_READS; i++) {
values[i] = input[i];
}
}
}
template <typename T, typename U, int N_READS, bool reverse>
inline void load_safe(
U values[N_READS],
const device T* input,
int start,
int total,
U init) {
if (reverse) {
for (int i = 0; i < N_READS; i++) {
values[N_READS - i - 1] =
(start + N_READS - i - 1 < total) ? input[i] : init;
}
} else {
for (int i = 0; i < N_READS; i++) {
values[i] = (start + i < total) ? input[i] : init;
}
}
}
template <typename U, int N_READS, bool reverse>
inline void write_unsafe(U values[N_READS], device U* out) {
if (reverse) {
for (int i = 0; i < N_READS; i++) {
out[i] = values[N_READS - i - 1];
}
} else {
for (int i = 0; i < N_READS; i++) {
out[i] = values[i];
}
}
}
template <typename U, int N_READS, bool reverse>
inline void write_safe(U values[N_READS], device U* out, int start, int total) {
if (reverse) {
for (int i = 0; i < N_READS; i++) {
if (start + N_READS - i - 1 < total) {
out[i] = values[N_READS - i - 1];
}
}
} else {
for (int i = 0; i < N_READS; i++) {
if (start + i < total) {
out[i] = values[i];
}
}
}
}
template <
typename T,
typename U,
typename Op,
int N_READS,
bool inclusive,
bool reverse>
[[kernel]] void contiguous_scan(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
// Position the pointers
in += (gid / lsize) * axis_size;
out += (gid / lsize) * axis_size;
// Compute the number of simd_groups
uint simd_groups = lsize / simd_size;
// Allocate memory
U prefix = Op::init;
U values[N_READS];
threadgroup U simdgroup_sums[32];
// Loop over the reduced axis in blocks of size ceildiv(axis_size,
// N_READS*lsize)
// Read block
// Compute inclusive scan of the block
// Compute inclusive scan per thread
// Compute exclusive scan of thread sums in simdgroup
// Write simdgroup sums in SM
// Compute exclusive scan of simdgroup sums
// Compute the output by scanning prefix, prev_simdgroup, prev_thread,
// value
// Write block
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
// Compute the block offset
uint offset = r * lsize * N_READS + lid * N_READS;
// Read the values
if (reverse) {
if ((offset + N_READS) < axis_size) {
load_unsafe<T, U, N_READS, reverse>(
values, in + axis_size - offset - N_READS);
} else {
load_safe<T, U, N_READS, reverse>(
values,
in + axis_size - offset - N_READS,
offset,
axis_size,
Op::init);
}
} else {
if ((offset + N_READS) < axis_size) {
load_unsafe<T, U, N_READS, reverse>(values, in + offset);
} else {
load_safe<T, U, N_READS, reverse>(
values, in + offset, offset, axis_size, Op::init);
}
}
// Compute an inclusive scan per thread
for (int i = 1; i < N_READS; i++) {
values[i] = op(values[i], values[i - 1]);
}
// Compute exclusive scan of thread sums
U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);
// Write simdgroup_sums to SM
if (simd_lane_id == simd_size - 1) {
simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Compute exclusive scan of simdgroup_sums
if (simd_group_id == 0) {
U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
simdgroup_sums[simd_lane_id] = prev_simdgroup;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Compute the output
for (int i = 0; i < N_READS; i++) {
values[i] = op(values[i], prefix);
values[i] = op(values[i], simdgroup_sums[simd_group_id]);
values[i] = op(values[i], prev_thread);
}
// Write the values
if (reverse) {
if (inclusive) {
if ((offset + N_READS) < axis_size) {
write_unsafe<U, N_READS, reverse>(
values, out + axis_size - offset - N_READS);
} else {
write_safe<U, N_READS, reverse>(
values, out + axis_size - offset - N_READS, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
out[axis_size - 1] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
write_unsafe<U, N_READS, reverse>(
values, out + axis_size - offset - 1 - N_READS);
} else {
write_safe<U, N_READS, reverse>(
values,
out + axis_size - offset - 1 - N_READS,
offset + 1,
axis_size);
}
}
} else {
if (inclusive) {
if ((offset + N_READS) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + offset);
} else {
write_safe<U, N_READS, reverse>(
values, out + offset, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
out[0] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
} else {
write_safe<U, N_READS, reverse>(
values, out + offset + 1, offset + 1, axis_size);
}
}
}
// Share the prefix
if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
simdgroup_sums[0] = values[N_READS - 1];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
prefix = simdgroup_sums[0];
}
}
template <
typename T,
typename U,
typename Op,
int N_READS,
bool inclusive,
bool reverse>
[[kernel]] void strided_scan(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]) {
Op op;
// Allocate memory
threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
U values[N_READS];
U prefix[N_READS];
for (int i = 0; i < N_READS; i++) {
prefix[i] = Op::init;
}
// Compute offsets
int offset = gid.y * axis_size * stride;
int global_index_x = gid.x * lsize.y * N_READS;
for (uint j = 0; j < axis_size; j += simd_size) {
// Calculate the indices for the current thread
uint index_y = j + lid.y;
uint check_index_y = index_y;
uint index_x = global_index_x + lid.x * N_READS;
if (reverse) {
index_y = axis_size - 1 - index_y;
}
// Read in SM
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
for (int i = 0; i < N_READS; i++) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
in[offset + index_y * stride + index_x + i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
in[offset + index_y * stride + index_x + i];
} else {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
Op::init;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read strided into registers
for (int i = 0; i < N_READS; i++) {
values[i] =
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
}
// Do we need the following barrier? Shouldn't all simd threads execute
// simultaneously?
simdgroup_barrier(mem_flags::mem_threadgroup);
// Perform the scan
for (int i = 0; i < N_READS; i++) {
values[i] = op.simd_scan(values[i]);
values[i] = op(values[i], prefix[i]);
prefix[i] = simd_shuffle(values[i], simd_size - 1);
}
// Write to SM
for (int i = 0; i < N_READS; i++) {
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] =
values[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write to device memory
if (!inclusive) {
if (check_index_y == 0) {
if ((index_x + N_READS) < stride) {
for (int i = 0; i < N_READS; i++) {
out[offset + index_y * stride + index_x + i] = Op::init;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] = Op::init;
}
}
}
}
if (reverse) {
index_y -= 1;
check_index_y += 1;
} else {
index_y += 1;
check_index_y += 1;
}
}
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
for (int i = 0; i < N_READS; i++) {
out[offset + index_y * stride + index_x + i] =
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] =
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
}
}
}
}
}

View File

@ -1,455 +1,19 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <metal_math>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
// clang-format off
using namespace metal;
template <typename U>
struct CumSum {
static constexpr constant U init = static_cast<U>(0);
template <typename T>
U operator()(U a, T b) {
return a + b;
}
U simd_scan(U x) {
return simd_prefix_inclusive_sum(x);
}
U simd_exclusive_scan(U x) {
return simd_prefix_exclusive_sum(x);
}
};
template <typename U>
struct CumProd {
static constexpr constant U init = static_cast<U>(1.0f);
template <typename T>
U operator()(U a, T b) {
return a * b;
}
U simd_scan(U x) {
return simd_prefix_inclusive_product(x);
}
U simd_exclusive_scan(U x) {
return simd_prefix_exclusive_product(x);
}
};
template <>
struct CumProd<bool> {
static constexpr constant bool init = true;
template <typename T>
bool operator()(bool a, T b) {
return a & static_cast<bool>(b);
}
bool simd_scan(bool x) {
for (int i = 1; i <= 16; i *= 2) {
bool other = simd_shuffle_up(x, i);
x &= other;
}
return x;
}
bool simd_exclusive_scan(bool x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename U>
struct CumMax {
static constexpr constant U init = Limits<U>::min;
template <typename T>
U operator()(U a, T b) {
return (a >= b) ? a : b;
}
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_up(x, i);
x = (x >= other) ? x : other;
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename U>
struct CumMin {
static constexpr constant U init = Limits<U>::max;
template <typename T>
U operator()(U a, T b) {
return (a <= b) ? a : b;
}
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_up(x, i);
x = (x <= other) ? x : other;
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T* input) {
if (reverse) {
for (int i = 0; i < N_READS; i++) {
values[N_READS - i - 1] = input[i];
}
} else {
for (int i = 0; i < N_READS; i++) {
values[i] = input[i];
}
}
}
template <typename T, typename U, int N_READS, bool reverse>
inline void load_safe(
U values[N_READS],
const device T* input,
int start,
int total,
U init) {
if (reverse) {
for (int i = 0; i < N_READS; i++) {
values[N_READS - i - 1] =
(start + N_READS - i - 1 < total) ? input[i] : init;
}
} else {
for (int i = 0; i < N_READS; i++) {
values[i] = (start + i < total) ? input[i] : init;
}
}
}
template <typename U, int N_READS, bool reverse>
inline void write_unsafe(U values[N_READS], device U* out) {
if (reverse) {
for (int i = 0; i < N_READS; i++) {
out[i] = values[N_READS - i - 1];
}
} else {
for (int i = 0; i < N_READS; i++) {
out[i] = values[i];
}
}
}
template <typename U, int N_READS, bool reverse>
inline void write_safe(U values[N_READS], device U* out, int start, int total) {
if (reverse) {
for (int i = 0; i < N_READS; i++) {
if (start + N_READS - i - 1 < total) {
out[i] = values[N_READS - i - 1];
}
}
} else {
for (int i = 0; i < N_READS; i++) {
if (start + i < total) {
out[i] = values[i];
}
}
}
}
template <
typename T,
typename U,
typename Op,
int N_READS,
bool inclusive,
bool reverse>
[[kernel]] void contiguous_scan(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
// Position the pointers
in += (gid / lsize) * axis_size;
out += (gid / lsize) * axis_size;
// Compute the number of simd_groups
uint simd_groups = lsize / simd_size;
// Allocate memory
U prefix = Op::init;
U values[N_READS];
threadgroup U simdgroup_sums[32];
// Loop over the reduced axis in blocks of size ceildiv(axis_size,
// N_READS*lsize)
// Read block
// Compute inclusive scan of the block
// Compute inclusive scan per thread
// Compute exclusive scan of thread sums in simdgroup
// Write simdgroup sums in SM
// Compute exclusive scan of simdgroup sums
// Compute the output by scanning prefix, prev_simdgroup, prev_thread,
// value
// Write block
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
// Compute the block offset
uint offset = r * lsize * N_READS + lid * N_READS;
// Read the values
if (reverse) {
if ((offset + N_READS) < axis_size) {
load_unsafe<T, U, N_READS, reverse>(
values, in + axis_size - offset - N_READS);
} else {
load_safe<T, U, N_READS, reverse>(
values,
in + axis_size - offset - N_READS,
offset,
axis_size,
Op::init);
}
} else {
if ((offset + N_READS) < axis_size) {
load_unsafe<T, U, N_READS, reverse>(values, in + offset);
} else {
load_safe<T, U, N_READS, reverse>(
values, in + offset, offset, axis_size, Op::init);
}
}
// Compute an inclusive scan per thread
for (int i = 1; i < N_READS; i++) {
values[i] = op(values[i], values[i - 1]);
}
// Compute exclusive scan of thread sums
U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);
// Write simdgroup_sums to SM
if (simd_lane_id == simd_size - 1) {
simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Compute exclusive scan of simdgroup_sums
if (simd_group_id == 0) {
U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
simdgroup_sums[simd_lane_id] = prev_simdgroup;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Compute the output
for (int i = 0; i < N_READS; i++) {
values[i] = op(values[i], prefix);
values[i] = op(values[i], simdgroup_sums[simd_group_id]);
values[i] = op(values[i], prev_thread);
}
// Write the values
if (reverse) {
if (inclusive) {
if ((offset + N_READS) < axis_size) {
write_unsafe<U, N_READS, reverse>(
values, out + axis_size - offset - N_READS);
} else {
write_safe<U, N_READS, reverse>(
values, out + axis_size - offset - N_READS, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
out[axis_size - 1] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
write_unsafe<U, N_READS, reverse>(
values, out + axis_size - offset - 1 - N_READS);
} else {
write_safe<U, N_READS, reverse>(
values,
out + axis_size - offset - 1 - N_READS,
offset + 1,
axis_size);
}
}
} else {
if (inclusive) {
if ((offset + N_READS) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + offset);
} else {
write_safe<U, N_READS, reverse>(
values, out + offset, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
out[0] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
} else {
write_safe<U, N_READS, reverse>(
values, out + offset + 1, offset + 1, axis_size);
}
}
}
// Share the prefix
if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
simdgroup_sums[0] = values[N_READS - 1];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
prefix = simdgroup_sums[0];
}
}
template <
typename T,
typename U,
typename Op,
int N_READS,
bool inclusive,
bool reverse>
[[kernel]] void strided_scan(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]) {
Op op;
// Allocate memory
threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
U values[N_READS];
U prefix[N_READS];
for (int i = 0; i < N_READS; i++) {
prefix[i] = Op::init;
}
// Compute offsets
int offset = gid.y * axis_size * stride;
int global_index_x = gid.x * lsize.y * N_READS;
for (uint j = 0; j < axis_size; j += simd_size) {
// Calculate the indices for the current thread
uint index_y = j + lid.y;
uint check_index_y = index_y;
uint index_x = global_index_x + lid.x * N_READS;
if (reverse) {
index_y = axis_size - 1 - index_y;
}
// Read in SM
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
for (int i = 0; i < N_READS; i++) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
in[offset + index_y * stride + index_x + i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
in[offset + index_y * stride + index_x + i];
} else {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
Op::init;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read strided into registers
for (int i = 0; i < N_READS; i++) {
values[i] =
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
}
// Do we need the following barrier? Shouldn't all simd threads execute
// simultaneously?
simdgroup_barrier(mem_flags::mem_threadgroup);
// Perform the scan
for (int i = 0; i < N_READS; i++) {
values[i] = op.simd_scan(values[i]);
values[i] = op(values[i], prefix[i]);
prefix[i] = simd_shuffle(values[i], simd_size - 1);
}
// Write to SM
for (int i = 0; i < N_READS; i++) {
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] =
values[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write to device memory
if (!inclusive) {
if (check_index_y == 0) {
if ((index_x + N_READS) < stride) {
for (int i = 0; i < N_READS; i++) {
out[offset + index_y * stride + index_x + i] = Op::init;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] = Op::init;
}
}
}
}
if (reverse) {
index_y -= 1;
check_index_y += 1;
} else {
index_y += 1;
check_index_y += 1;
}
}
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
for (int i = 0; i < N_READS; i++) {
out[offset + index_y * stride + index_x + i] =
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] =
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
}
}
}
}
}
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/scan.h"
#define instantiate_contiguous_scan( \
name, itype, otype, op, inclusive, reverse, nreads) \
template [[host_name("contiguous_scan_" #name)]] [[kernel]] void \
template [[host_name("contig_scan_" #name)]] [[kernel]] void \
contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
@ -474,7 +38,6 @@ template <
uint2 lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]]);
// clang-format off
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
@ -483,9 +46,8 @@ template <
instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) // clang-format on
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)
// clang-format off
instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
@ -537,4 +99,4 @@ instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMi
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // clang-format on
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // clang-format on

View File

@ -0,0 +1,190 @@
// Copyright © 2023-2024 Apple Inc.
template <typename T>
inline T softmax_exp(T x) {
// Softmax doesn't need high precision exponential cause x is gonna be in
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
return fast::exp(x);
}
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_single_row(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
int lid = _lid;
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
AccT ld[N_READS];
in += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
ld[i] = AccT(in[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
: Limits<AccT>::finite_min;
}
}
if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<AccT>::finite_min;
local_normalizer[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Get the max
AccT maxval = Limits<AccT>::finite_min;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < ld[i]) ? ld[i] : maxval;
}
maxval = simd_max(maxval);
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
maxval = simd_max(local_max[simd_lane_id]);
if (simd_lane_id == 0) {
local_max[0] = maxval;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = local_max[0];
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
AccT normalizer = 0;
for (int i = 0; i < N_READS; i++) {
AccT exp_x = softmax_exp(ld[i] - maxval);
ld[i] = exp_x;
normalizer += exp_x;
}
normalizer = simd_sum(normalizer);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
local_normalizer[0] = normalizer;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = 1 / local_normalizer[0];
// Normalize and write to the output
out += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
out[i] = T(ld[i] * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
out[i] = T(ld[i] * normalizer);
}
}
}
}
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_looped(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
in += gid * axis_size;
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
// Get the max and the normalizer in one go
AccT prevmax;
AccT maxval = Limits<AccT>::finite_min;
AccT normalizer = 0;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
AccT vals[N_READS];
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
vals[i] = AccT(in[offset + i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
: Limits<AccT>::finite_min;
}
}
prevmax = maxval;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < vals[i]) ? vals[i] : maxval;
}
normalizer *= softmax_exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer += softmax_exp(vals[i] - maxval);
}
}
// Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS *
// lsize) parts. We need to combine them.
// 1. We start by finding the max across simd groups
// 2. We then change the partial normalizers to account for a possible
// change in max
// 3. We sum all normalizers
prevmax = maxval;
maxval = simd_max(maxval);
normalizer *= softmax_exp(prevmax - maxval);
normalizer = simd_sum(normalizer);
// Now the normalizer and max value is correct for each simdgroup. We write
// them shared memory and combine them.
prevmax = maxval;
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = simd_max(local_max[simd_lane_id]);
normalizer *= softmax_exp(prevmax - maxval);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = simd_sum(local_normalizer[simd_lane_id]);
normalizer = 1 / normalizer;
// Finally given the normalizer and max value we can directly write the
// softmax output
out += gid * axis_size;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if (offset + i < axis_size) {
out[offset + i] =
T(softmax_exp(in[offset + i] - maxval) * normalizer);
}
}
}
}
}

View File

@ -1,205 +1,18 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <metal_common>
#include <metal_simdgroup>
using namespace metal;
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
template <typename T>
inline T softmax_exp(T x) {
// Softmax doesn't need high precision exponential cause x is gonna be in
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
return fast::exp(x);
}
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_single_row(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
int lid = _lid;
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
AccT ld[N_READS];
in += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
ld[i] = AccT(in[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
: Limits<AccT>::finite_min;
}
}
if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<AccT>::finite_min;
local_normalizer[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Get the max
AccT maxval = Limits<AccT>::finite_min;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < ld[i]) ? ld[i] : maxval;
}
maxval = simd_max(maxval);
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
maxval = simd_max(local_max[simd_lane_id]);
if (simd_lane_id == 0) {
local_max[0] = maxval;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = local_max[0];
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
AccT normalizer = 0;
for (int i = 0; i < N_READS; i++) {
AccT exp_x = softmax_exp(ld[i] - maxval);
ld[i] = exp_x;
normalizer += exp_x;
}
normalizer = simd_sum(normalizer);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
local_normalizer[0] = normalizer;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = 1 / local_normalizer[0];
// Normalize and write to the output
out += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
out[i] = T(ld[i] * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
out[i] = T(ld[i] * normalizer);
}
}
}
}
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_looped(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
in += gid * axis_size;
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
// Get the max and the normalizer in one go
AccT prevmax;
AccT maxval = Limits<AccT>::finite_min;
AccT normalizer = 0;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
AccT vals[N_READS];
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
vals[i] = AccT(in[offset + i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
: Limits<AccT>::finite_min;
}
}
prevmax = maxval;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < vals[i]) ? vals[i] : maxval;
}
normalizer *= softmax_exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer += softmax_exp(vals[i] - maxval);
}
}
// Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS *
// lsize) parts. We need to combine them.
// 1. We start by finding the max across simd groups
// 2. We then change the partial normalizers to account for a possible
// change in max
// 3. We sum all normalizers
prevmax = maxval;
maxval = simd_max(maxval);
normalizer *= softmax_exp(prevmax - maxval);
normalizer = simd_sum(normalizer);
// Now the normalizer and max value is correct for each simdgroup. We write
// them shared memory and combine them.
prevmax = maxval;
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = simd_max(local_max[simd_lane_id]);
normalizer *= softmax_exp(prevmax - maxval);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = simd_sum(local_normalizer[simd_lane_id]);
normalizer = 1 / normalizer;
// Finally given the normalizer and max value we can directly write the
// softmax output
out += gid * axis_size;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if (offset + i < axis_size) {
out[offset + i] =
T(softmax_exp(in[offset + i] - maxval) * normalizer);
}
}
}
}
}
#include "mlx/backend/metal/kernels/softmax.h"
#define instantiate_softmax(name, itype) \
template [[host_name("softmax_" #name)]] [[kernel]] void \
template [[host_name("block_softmax_" #name)]] [[kernel]] void \
softmax_single_row<itype>( \
const device itype* in, \
device itype* out, \
@ -208,7 +21,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
template [[host_name("looped_softmax_" #name)]] [[kernel]] void \
softmax_looped<itype>( \
const device itype* in, \
device itype* out, \
@ -220,7 +33,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax_precise(name, itype) \
template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
template [[host_name("block_softmax_precise_" #name)]] [[kernel]] void \
softmax_single_row<itype, float>( \
const device itype* in, \
device itype* out, \
@ -229,7 +42,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("softmax_looped_precise_" #name)]] [[kernel]] void \
template [[host_name("looped_softmax_precise_" #name)]] [[kernel]] void \
softmax_looped<itype, float>( \
const device itype* in, \
device itype* out, \
@ -240,7 +53,6 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
// clang-format off
instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)

View File

@ -0,0 +1,674 @@
// Copyright © 2023-2024 Apple Inc.
#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")
using namespace metal;
// Based on GPU merge sort algorithm at
// https://github.com/NVIDIA/cccl/tree/main/cub/cub
///////////////////////////////////////////////////////////////////////////////
// Thread-level sort
///////////////////////////////////////////////////////////////////////////////
template <typename T>
METAL_FUNC void thread_swap(thread T& a, thread T& b) {
T w = a;
a = b;
b = w;
}
template <typename T>
struct LessThan {
static constexpr constant T init = Limits<T>::max;
METAL_FUNC bool operator()(T a, T b) {
return a < b;
}
};
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short N_PER_THREAD,
typename CompareOp>
struct ThreadSort {
static METAL_FUNC void sort(
thread val_t (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) {
CompareOp op;
MLX_MTL_LOOP_UNROLL
for (short i = 0; i < N_PER_THREAD; ++i) {
MLX_MTL_LOOP_UNROLL
for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
if (op(vals[j + 1], vals[j])) {
thread_swap(vals[j + 1], vals[j]);
thread_swap(idxs[j + 1], idxs[j]);
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// Threadgroup-level sort
///////////////////////////////////////////////////////////////////////////////
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp>
struct BlockMergeSort {
using thread_sort_t =
ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
static METAL_FUNC int merge_partition(
const threadgroup val_t* As,
const threadgroup val_t* Bs,
short A_sz,
short B_sz,
short sort_md) {
CompareOp op;
short A_st = max(0, sort_md - B_sz);
short A_ed = min(sort_md, A_sz);
while (A_st < A_ed) {
short md = A_st + (A_ed - A_st) / 2;
auto a = As[md];
auto b = Bs[sort_md - 1 - md];
if (op(b, a)) {
A_ed = md;
} else {
A_st = md + 1;
}
}
return A_ed;
}
static METAL_FUNC void merge_step(
const threadgroup val_t* As,
const threadgroup val_t* Bs,
const threadgroup idx_t* As_idx,
const threadgroup idx_t* Bs_idx,
short A_sz,
short B_sz,
thread val_t (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) {
CompareOp op;
short a_idx = 0;
short b_idx = 0;
for (int i = 0; i < N_PER_THREAD; ++i) {
auto a = As[a_idx];
auto b = Bs[b_idx];
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
vals[i] = pred ? b : a;
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
b_idx += short(pred);
a_idx += short(!pred);
}
}
static METAL_FUNC void sort(
threadgroup val_t* tgp_vals [[threadgroup(0)]],
threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
int size_sorted_axis,
uint3 lid [[thread_position_in_threadgroup]]) {
// Get thread location
int idx = lid.x * N_PER_THREAD;
// Load from shared memory
thread val_t thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD];
for (int i = 0; i < N_PER_THREAD; ++i) {
thread_vals[i] = tgp_vals[idx + i];
if (ARG_SORT) {
thread_idxs[i] = tgp_idxs[idx + i];
}
}
// Per thread sort
if (idx < size_sorted_axis) {
thread_sort_t::sort(thread_vals, thread_idxs);
}
// Do merges using threadgroup memory
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;
merge_threads *= 2) {
// Update threadgroup memory
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; ++i) {
tgp_vals[idx + i] = thread_vals[i];
if (ARG_SORT) {
tgp_idxs[idx + i] = thread_idxs[i];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Find location in merge step
int merge_group = lid.x / merge_threads;
int merge_lane = lid.x % merge_threads;
int sort_sz = N_PER_THREAD * merge_threads;
int sort_st = N_PER_THREAD * merge_threads * merge_group;
// As = tgp_vals[A_st:A_ed] is sorted
// Bs = tgp_vals[B_st:B_ed] is sorted
int A_st = sort_st;
int A_ed = sort_st + sort_sz / 2;
int B_st = sort_st + sort_sz / 2;
int B_ed = sort_st + sort_sz;
const threadgroup val_t* As = tgp_vals + A_st;
const threadgroup val_t* Bs = tgp_vals + B_st;
int A_sz = A_ed - A_st;
int B_sz = B_ed - B_st;
// Find a partition of merge elements
// Ci = merge(As[partition:], Bs[sort_md - partition:])
// of size N_PER_THREAD for each merge lane i
// C = [Ci] is sorted
int sort_md = N_PER_THREAD * merge_lane;
int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);
As += partition;
Bs += sort_md - partition;
A_sz -= partition;
B_sz -= sort_md - partition;
const threadgroup idx_t* As_idx =
ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
const threadgroup idx_t* Bs_idx =
ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
// Merge starting at the partition and store results in thread registers
merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
}
// Write out to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; ++i) {
tgp_vals[idx + i] = thread_vals[i];
if (ARG_SORT) {
tgp_idxs[idx + i] = thread_idxs[i];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// Kernel sort
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<T>>
struct KernelMergeSort {
using val_t = T;
using idx_t = uint;
using block_merge_sort_t = BlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort(
const device T* inp,
device U* out,
const constant int& size_sorted_axis,
const constant int& stride_sorted_axis,
const constant int& stride_segment_axis,
threadgroup val_t* tgp_vals,
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
inp += tid.y * stride_segment_axis;
out += tid.y * stride_segment_axis;
// Copy into threadgroup memory
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
: val_t(CompareOp::init);
if (ARG_SORT) {
tgp_idxs[i] = i;
}
}
// Sort elements within the block
threadgroup_barrier(mem_flags::mem_threadgroup);
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
if (ARG_SORT) {
out[i * stride_sorted_axis] = tgp_idxs[i];
} else {
out[i * stride_sorted_axis] = tgp_vals[i];
}
}
}
};
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort(
const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel =
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
stride_segment_axis,
tgp_vals,
tgp_idxs,
tid,
lid);
} else {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
stride_segment_axis,
tgp_vals,
nullptr,
tid,
lid);
}
}
constant constexpr const int zero_helper = 0;
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc(
const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel =
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
inp += block_idx;
out += block_idx;
if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
zero_helper,
tgp_vals,
tgp_idxs,
tid,
lid);
} else {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
zero_helper,
tgp_vals,
nullptr,
tid,
lid);
}
}
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
struct KernelMultiBlockMergeSort {
using block_merge_sort_t = BlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort(
const device val_t* inp,
device val_t* out_vals,
device idx_t* out_idxs,
const constant int& size_sorted_axis,
const constant int& stride_sorted_axis,
threadgroup val_t* tgp_vals,
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
int base_idx = tid.x * N_PER_BLOCK;
// Copy into threadgroup memory
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
: val_t(CompareOp::init);
tgp_idxs[i] = idx;
}
// Sort elements within the block
threadgroup_barrier(mem_flags::mem_threadgroup);
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
if (idx < size_sorted_axis) {
out_vals[idx] = tgp_vals[i];
out_idxs[idx] = tgp_idxs[i];
}
}
}
static METAL_FUNC int merge_partition(
const device val_t* As,
const device val_t* Bs,
int A_sz,
int B_sz,
int sort_md) {
CompareOp op;
int A_st = max(0, sort_md - B_sz);
int A_ed = min(sort_md, A_sz);
while (A_st < A_ed) {
int md = A_st + (A_ed - A_st) / 2;
auto a = As[md];
auto b = Bs[sort_md - 1 - md];
if (op(b, a)) {
A_ed = md;
} else {
A_st = md + 1;
}
}
return A_ed;
}
};
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort(
const device val_t* inp [[buffer(0)]],
device val_t* out_vals [[buffer(1)]],
device idx_t* out_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
inp += block_idx;
out_vals += tid.y * size_sorted_axis;
out_idxs += tid.y * size_sorted_axis;
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out_vals,
out_idxs,
size_sorted_axis,
stride_sorted_axis,
tgp_vals,
tgp_idxs,
tid,
lid);
}
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_partition(
device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals [[buffer(1)]],
const device idx_t* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
block_partitions += tid.y * tgp_dims.x;
dev_vals += tid.y * size_sorted_axis;
dev_idxs += tid.y * size_sorted_axis;
// Find location in merge step
int merge_group = lid.x / merge_tiles;
int merge_lane = lid.x % merge_tiles;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
int A_st = min(size_sorted_axis, sort_st);
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
int B_st = A_ed;
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
int partition = sort_kernel::merge_partition(
dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
block_partitions[lid.x] = A_st + partition;
}
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_merge(
const device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals_in [[buffer(1)]],
const device idx_t* dev_idxs_in [[buffer(2)]],
device val_t* dev_vals_out [[buffer(3)]],
device idx_t* dev_idxs_out [[buffer(4)]],
const constant int& size_sorted_axis [[buffer(5)]],
const constant int& merge_tiles [[buffer(6)]],
const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
using block_sort_t = typename sort_kernel::block_merge_sort_t;
block_partitions += tid.y * (num_tiles + 1);
dev_vals_in += tid.y * size_sorted_axis;
dev_idxs_in += tid.y * size_sorted_axis;
dev_vals_out += tid.y * size_sorted_axis;
dev_idxs_out += tid.y * size_sorted_axis;
int block_idx = tid.x;
int merge_group = block_idx / merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;
int A_st = block_partitions[block_idx + 0];
int A_ed = block_partitions[block_idx + 1];
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
int B_ed = min(
size_sorted_axis,
2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
if ((block_idx % merge_tiles) == merge_tiles - 1) {
A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
B_ed = min(size_sorted_axis, sort_st + sort_sz);
}
int A_sz = A_ed - A_st;
int B_sz = B_ed - B_st;
// Load from global memory
thread val_t thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD];
for (int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x;
if (idx < (A_sz + B_sz)) {
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
: dev_vals_in[B_st + idx - A_sz];
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
: dev_idxs_in[B_st + idx - A_sz];
} else {
thread_vals[i] = CompareOp::init;
thread_idxs[i] = 0;
}
}
// Write to shared memory
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x;
tgp_vals[idx] = thread_vals[i];
tgp_idxs[idx] = thread_idxs[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Merge
int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));
int A_st_local = block_sort_t::merge_partition(
tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
int A_ed_local = A_sz;
int B_st_local = sort_md_local - A_st_local;
int B_ed_local = B_sz;
int A_sz_local = A_ed_local - A_st_local;
int B_sz_local = B_ed_local - B_st_local;
// Do merge
block_sort_t::merge_step(
tgp_vals + A_st_local,
tgp_vals + A_ed_local + B_st_local,
tgp_idxs + A_st_local,
tgp_idxs + A_ed_local + B_st_local,
A_sz_local,
B_sz_local,
thread_vals,
thread_idxs);
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; ++i) {
int idx = lid.x * N_PER_THREAD;
tgp_vals[idx + i] = thread_vals[i];
tgp_idxs[idx + i] = thread_idxs[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
if (idx < size_sorted_axis) {
dev_vals_out[idx] = tgp_vals[i];
dev_idxs_out[idx] = tgp_idxs[i];
}
}
}

View File

@ -1,392 +1,16 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")
using namespace metal;
// Based on GPU merge sort algorithm at
// https://github.com/NVIDIA/cccl/tree/main/cub/cub
///////////////////////////////////////////////////////////////////////////////
// Thread-level sort
///////////////////////////////////////////////////////////////////////////////
template <typename T>
METAL_FUNC void thread_swap(thread T& a, thread T& b) {
T w = a;
a = b;
b = w;
}
template <typename T>
struct LessThan {
static constexpr constant T init = Limits<T>::max;
METAL_FUNC bool operator()(T a, T b) {
return a < b;
}
};
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short N_PER_THREAD,
typename CompareOp>
struct ThreadSort {
static METAL_FUNC void sort(
thread val_t (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) {
CompareOp op;
MLX_MTL_LOOP_UNROLL
for (short i = 0; i < N_PER_THREAD; ++i) {
MLX_MTL_LOOP_UNROLL
for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
if (op(vals[j + 1], vals[j])) {
thread_swap(vals[j + 1], vals[j]);
thread_swap(idxs[j + 1], idxs[j]);
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// Threadgroup-level sort
///////////////////////////////////////////////////////////////////////////////
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp>
struct BlockMergeSort {
using thread_sort_t =
ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
static METAL_FUNC int merge_partition(
const threadgroup val_t* As,
const threadgroup val_t* Bs,
short A_sz,
short B_sz,
short sort_md) {
CompareOp op;
short A_st = max(0, sort_md - B_sz);
short A_ed = min(sort_md, A_sz);
while (A_st < A_ed) {
short md = A_st + (A_ed - A_st) / 2;
auto a = As[md];
auto b = Bs[sort_md - 1 - md];
if (op(b, a)) {
A_ed = md;
} else {
A_st = md + 1;
}
}
return A_ed;
}
static METAL_FUNC void merge_step(
const threadgroup val_t* As,
const threadgroup val_t* Bs,
const threadgroup idx_t* As_idx,
const threadgroup idx_t* Bs_idx,
short A_sz,
short B_sz,
thread val_t (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) {
CompareOp op;
short a_idx = 0;
short b_idx = 0;
for (int i = 0; i < N_PER_THREAD; ++i) {
auto a = As[a_idx];
auto b = Bs[b_idx];
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
vals[i] = pred ? b : a;
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
b_idx += short(pred);
a_idx += short(!pred);
}
}
static METAL_FUNC void sort(
threadgroup val_t* tgp_vals [[threadgroup(0)]],
threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
int size_sorted_axis,
uint3 lid [[thread_position_in_threadgroup]]) {
// Get thread location
int idx = lid.x * N_PER_THREAD;
// Load from shared memory
thread val_t thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD];
for (int i = 0; i < N_PER_THREAD; ++i) {
thread_vals[i] = tgp_vals[idx + i];
if (ARG_SORT) {
thread_idxs[i] = tgp_idxs[idx + i];
}
}
// Per thread sort
if (idx < size_sorted_axis) {
thread_sort_t::sort(thread_vals, thread_idxs);
}
// Do merges using threadgroup memory
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;
merge_threads *= 2) {
// Update threadgroup memory
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; ++i) {
tgp_vals[idx + i] = thread_vals[i];
if (ARG_SORT) {
tgp_idxs[idx + i] = thread_idxs[i];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Find location in merge step
int merge_group = lid.x / merge_threads;
int merge_lane = lid.x % merge_threads;
int sort_sz = N_PER_THREAD * merge_threads;
int sort_st = N_PER_THREAD * merge_threads * merge_group;
// As = tgp_vals[A_st:A_ed] is sorted
// Bs = tgp_vals[B_st:B_ed] is sorted
int A_st = sort_st;
int A_ed = sort_st + sort_sz / 2;
int B_st = sort_st + sort_sz / 2;
int B_ed = sort_st + sort_sz;
const threadgroup val_t* As = tgp_vals + A_st;
const threadgroup val_t* Bs = tgp_vals + B_st;
int A_sz = A_ed - A_st;
int B_sz = B_ed - B_st;
// Find a partition of merge elements
// Ci = merge(As[partition:], Bs[sort_md - partition:])
// of size N_PER_THREAD for each merge lane i
// C = [Ci] is sorted
int sort_md = N_PER_THREAD * merge_lane;
int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);
As += partition;
Bs += sort_md - partition;
A_sz -= partition;
B_sz -= sort_md - partition;
const threadgroup idx_t* As_idx =
ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
const threadgroup idx_t* Bs_idx =
ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
// Merge starting at the partition and store results in thread registers
merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
}
// Write out to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; ++i) {
tgp_vals[idx + i] = thread_vals[i];
if (ARG_SORT) {
tgp_idxs[idx + i] = thread_idxs[i];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// Kernel sort
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<T>>
struct KernelMergeSort {
using val_t = T;
using idx_t = uint;
using block_merge_sort_t = BlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort(
const device T* inp,
device U* out,
const constant int& size_sorted_axis,
const constant int& stride_sorted_axis,
const constant int& stride_segment_axis,
threadgroup val_t* tgp_vals,
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
inp += tid.y * stride_segment_axis;
out += tid.y * stride_segment_axis;
// Copy into threadgroup memory
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
: val_t(CompareOp::init);
if (ARG_SORT) {
tgp_idxs[i] = i;
}
}
// Sort elements within the block
threadgroup_barrier(mem_flags::mem_threadgroup);
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
if (ARG_SORT) {
out[i * stride_sorted_axis] = tgp_idxs[i];
} else {
out[i * stride_sorted_axis] = tgp_vals[i];
}
}
}
};
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort(
const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel =
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
stride_segment_axis,
tgp_vals,
tgp_idxs,
tid,
lid);
} else {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
stride_segment_axis,
tgp_vals,
nullptr,
tid,
lid);
}
}
constant constexpr const int zero_helper = 0;
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc(
const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel =
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
inp += block_idx;
out += block_idx;
if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
zero_helper,
tgp_vals,
tgp_idxs,
tid,
lid);
} else {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
zero_helper,
tgp_vals,
nullptr,
tid,
lid);
}
}
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#include "mlx/backend/metal/kernels/sort.h"
#define instantiate_block_sort( \
name, itname, itype, otname, otype, arg_sort, bn, tn) \
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn \
template [[host_name("c" #name "_" #itname "_" #otname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
block_sort<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
@ -396,8 +20,8 @@ template <
const constant int& stride_segment_axis [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
"_nc")]] [[kernel]] void \
template [[host_name("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
)]] [[kernel]] void \
block_sort_nc<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
@ -411,22 +35,20 @@ template <
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort( \
arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
arg_block_sort, itname, itype, uint32, uint32_t, true, bn, tn)
#define instantiate_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort( \
block_merge_sort, itname, itype, itname, itype, false, bn, tn)
_block_sort, itname, itype, itname, itype, false, bn, tn)
// clang-format off
#define instantiate_block_sort_tn(itname, itype, bn) \
instantiate_block_sort_base(itname, itype, bn, 8) \
instantiate_arg_block_sort_base(itname, itype, bn, 8) // clang-format on
instantiate_arg_block_sort_base(itname, itype, bn, 8)
// clang-format off
#define instantiate_block_sort_bn(itname, itype) \
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256) \
instantiate_block_sort_tn(itname, itype, 512)
instantiate_block_sort_tn(itname, itype, 512)
instantiate_block_sort_bn(uint8, uint8_t)
instantiate_block_sort_bn(uint16, uint16_t)
@ -436,321 +58,18 @@ instantiate_block_sort_bn(int16, int16_t)
instantiate_block_sort_bn(int32, int32_t)
instantiate_block_sort_bn(float16, half)
instantiate_block_sort_bn(float32, float)
instantiate_block_sort_bn(bfloat16, bfloat16_t) // clang-format on
// clang-format off
instantiate_block_sort_bn(bfloat16, bfloat16_t)
#define instantiate_block_sort_long(itname, itype) \
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256)
instantiate_block_sort_long(uint64, uint64_t)
instantiate_block_sort_long(int64, int64_t) // clang-format on
///////////////////////////////////////////////////////////////////////////////
// Multi block merge sort
///////////////////////////////////////////////////////////////////////////////
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
struct KernelMultiBlockMergeSort {
using block_merge_sort_t = BlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort(
const device val_t* inp,
device val_t* out_vals,
device idx_t* out_idxs,
const constant int& size_sorted_axis,
const constant int& stride_sorted_axis,
threadgroup val_t* tgp_vals,
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
int base_idx = tid.x * N_PER_BLOCK;
// Copy into threadgroup memory
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
: val_t(CompareOp::init);
tgp_idxs[i] = idx;
}
// Sort elements within the block
threadgroup_barrier(mem_flags::mem_threadgroup);
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
if (idx < size_sorted_axis) {
out_vals[idx] = tgp_vals[i];
out_idxs[idx] = tgp_idxs[i];
}
}
}
static METAL_FUNC int merge_partition(
const device val_t* As,
const device val_t* Bs,
int A_sz,
int B_sz,
int sort_md) {
CompareOp op;
int A_st = max(0, sort_md - B_sz);
int A_ed = min(sort_md, A_sz);
while (A_st < A_ed) {
int md = A_st + (A_ed - A_st) / 2;
auto a = As[md];
auto b = Bs[sort_md - 1 - md];
if (op(b, a)) {
A_ed = md;
} else {
A_st = md + 1;
}
}
return A_ed;
}
};
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort(
const device val_t* inp [[buffer(0)]],
device val_t* out_vals [[buffer(1)]],
device idx_t* out_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
inp += block_idx;
out_vals += tid.y * size_sorted_axis;
out_idxs += tid.y * size_sorted_axis;
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out_vals,
out_idxs,
size_sorted_axis,
stride_sorted_axis,
tgp_vals,
tgp_idxs,
tid,
lid);
}
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_partition(
device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals [[buffer(1)]],
const device idx_t* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
block_partitions += tid.y * tgp_dims.x;
dev_vals += tid.y * size_sorted_axis;
dev_idxs += tid.y * size_sorted_axis;
// Find location in merge step
int merge_group = lid.x / merge_tiles;
int merge_lane = lid.x % merge_tiles;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
int A_st = min(size_sorted_axis, sort_st);
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
int B_st = A_ed;
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
int partition = sort_kernel::merge_partition(
dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
block_partitions[lid.x] = A_st + partition;
}
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_merge(
const device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals_in [[buffer(1)]],
const device idx_t* dev_idxs_in [[buffer(2)]],
device val_t* dev_vals_out [[buffer(3)]],
device idx_t* dev_idxs_out [[buffer(4)]],
const constant int& size_sorted_axis [[buffer(5)]],
const constant int& merge_tiles [[buffer(6)]],
const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
using block_sort_t = typename sort_kernel::block_merge_sort_t;
block_partitions += tid.y * (num_tiles + 1);
dev_vals_in += tid.y * size_sorted_axis;
dev_idxs_in += tid.y * size_sorted_axis;
dev_vals_out += tid.y * size_sorted_axis;
dev_idxs_out += tid.y * size_sorted_axis;
int block_idx = tid.x;
int merge_group = block_idx / merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;
int A_st = block_partitions[block_idx + 0];
int A_ed = block_partitions[block_idx + 1];
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
int B_ed = min(
size_sorted_axis,
2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
if ((block_idx % merge_tiles) == merge_tiles - 1) {
A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
B_ed = min(size_sorted_axis, sort_st + sort_sz);
}
int A_sz = A_ed - A_st;
int B_sz = B_ed - B_st;
// Load from global memory
thread val_t thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD];
for (int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x;
if (idx < (A_sz + B_sz)) {
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
: dev_vals_in[B_st + idx - A_sz];
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
: dev_idxs_in[B_st + idx - A_sz];
} else {
thread_vals[i] = CompareOp::init;
thread_idxs[i] = 0;
}
}
// Write to shared memory
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x;
tgp_vals[idx] = thread_vals[i];
tgp_idxs[idx] = thread_idxs[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Merge
int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));
int A_st_local = block_sort_t::merge_partition(
tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
int A_ed_local = A_sz;
int B_st_local = sort_md_local - A_st_local;
int B_ed_local = B_sz;
int A_sz_local = A_ed_local - A_st_local;
int B_sz_local = B_ed_local - B_st_local;
// Do merge
block_sort_t::merge_step(
tgp_vals + A_st_local,
tgp_vals + A_ed_local + B_st_local,
tgp_idxs + A_st_local,
tgp_idxs + A_ed_local + B_st_local,
A_sz_local,
B_sz_local,
thread_vals,
thread_idxs);
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; ++i) {
int idx = lid.x * N_PER_THREAD;
tgp_vals[idx + i] = thread_vals[i];
tgp_idxs[idx + i] = thread_idxs[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
if (idx < size_sorted_axis) {
dev_vals_out[idx] = tgp_vals[i];
dev_idxs_out[idx] = tgp_idxs[i];
}
}
}
instantiate_block_sort_long(int64, int64_t)
#define instantiate_multi_block_sort( \
vtname, vtype, itname, itype, arg_sort, bn, tn) \
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn \
template [[host_name("sort_mbsort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
const device vtype* inp [[buffer(0)]], \
@ -763,7 +82,7 @@ mb_block_merge(
const device size_t* nc_strides [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn \
template [[host_name("partition_mbsort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
device itype * block_partitions [[buffer(0)]], \
@ -774,7 +93,7 @@ mb_block_merge(
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_dims [[threads_per_threadgroup]]); \
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn \
template [[host_name("merge_mbsort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
const device itype* block_partitions [[buffer(0)]], \
@ -788,7 +107,6 @@ mb_block_merge(
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_multi_block_sort_base(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)
@ -800,11 +118,10 @@ instantiate_multi_block_sort_base(int16, int16_t)
instantiate_multi_block_sort_base(int32, int32_t)
instantiate_multi_block_sort_base(float16, half)
instantiate_multi_block_sort_base(float32, float)
instantiate_multi_block_sort_base(bfloat16, bfloat16_t) // clang-format on
instantiate_multi_block_sort_base(bfloat16, bfloat16_t)
// clang-format off
#define instantiate_multi_block_sort_long(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)
instantiate_multi_block_sort_long(uint64, uint64_t)
instantiate_multi_block_sort_long(int64, int64_t) // clang-format on
instantiate_multi_block_sort_long(int64, int64_t) // clang-format on

View File

@ -5,6 +5,7 @@
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/complex.h"
#include "mlx/backend/metal/kernels/defines.h"
typedef half float16_t;

View File

@ -5,6 +5,13 @@
namespace mlx::core {
MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
@ -15,31 +22,84 @@ MTL::ComputePipelineState* get_unary_kernel(
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
const array&,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
const array&,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_copy_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
const array&,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_softmax_kernel(
metal::Device& d,
const std::string& kernel_name,
bool,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,
bool,
bool,
const array&,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
const array&,
int,
int) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_mb_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
const array&,
int,
int) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
const array&) {
return d.get_kernel(kernel_name);
}

View File

@ -6,6 +6,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@ -27,7 +28,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
}
auto& s = stream();
auto& d = metal::device(s.device);
auto kernel = d.get_kernel("arange" + type_to_name(out));
auto kernel = get_arange_kernel(d, "arange" + type_to_name(out), out);
size_t nthreads = out.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size group_dims = MTL::Size(

View File

@ -6,6 +6,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/utils.h"
@ -40,9 +41,12 @@ void all_reduce_dispatch(
const Stream& s) {
Dtype out_dtype = out.dtype();
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel("all_reduce_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("all_reduce_" + op_name + type_to_name(in));
std::string kernel_name = "all";
if (is_out_64b_int) {
kernel_name += "NoAtomics";
}
kernel_name += "_reduce_" + op_name + type_to_name(in);
auto kernel = get_reduce_kernel(d, kernel_name, in, out);
compute_encoder->setComputePipelineState(kernel);
@ -158,18 +162,20 @@ void row_reduce_general_dispatch(
bool is_med = non_row_reductions * reduction_size <= 256;
is_out_64b_int &= !is_small && !is_med;
std::string small_desc = "_";
if (is_small) {
small_desc = "_small_";
std::string small_desc;
if (is_out_64b_int) {
small_desc = "NoAtomics";
} else if (is_small) {
small_desc = "Small";
} else if (is_med) {
small_desc = "_med_";
small_desc = "Med";
} else {
small_desc = "";
}
kname << "rowGeneral" << small_desc << "_reduce_" << op_name
<< type_to_name(in);
small_desc = is_out_64b_int ? "_no_atomics_" : small_desc;
kname << "row_reduce_general" << small_desc << op_name << type_to_name(in);
auto kernel = d.get_kernel(kname.str());
auto kernel = get_reduce_kernel(d, kname.str(), in, out);
compute_encoder->setComputePipelineState(kernel);
// Get dispatch grid dims
@ -335,8 +341,8 @@ void strided_reduce_general_dispatch(
// Specialize for small dims
if (reduction_size * non_col_reductions < 16) {
// Select kernel
auto kernel =
d.get_kernel("col_reduce_small_" + op_name + type_to_name(in));
auto kernel = get_reduce_kernel(
d, "colSmall_reduce_" + op_name + type_to_name(in), in, out);
compute_encoder->setComputePipelineState(kernel);
// Select block dims
@ -373,10 +379,12 @@ void strided_reduce_general_dispatch(
// Select kernel
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel(
"col_reduce_general_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
std::string kernel_name = "colGeneral";
if (is_out_64b_int) {
kernel_name += "NoAtomics";
}
kernel_name += "_reduce_" + op_name + type_to_name(in);
auto kernel = get_reduce_kernel(d, kernel_name, in, out);
compute_encoder->setComputePipelineState(kernel);
@ -490,9 +498,11 @@ void strided_reduce_general_dispatch(
}
ndim = new_shape.size();
auto row_reduce_kernel = d.get_kernel(
"row_reduce_general_no_atomics_" + op_name +
type_to_name(intermediate));
std::string kernel_name =
"rowGeneralNoAtomics_reduce_" + op_name + type_to_name(intermediate);
auto row_reduce_kernel =
get_reduce_kernel(d, kernel_name, intermediate, out);
compute_encoder->setComputePipelineState(row_reduce_kernel);
compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1);
@ -575,7 +585,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
auto& compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel("i" + op_name + type_to_name(out));
auto kernel = get_reduce_init_kernel(
d, "i_reduce_" + op_name + type_to_name(out), out);
size_t nthreads = out.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();

View File

@ -5,6 +5,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
@ -28,30 +29,33 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
in = arr_copy;
}
std::ostringstream kname;
if (in.strides()[axis_] == 1) {
kname << "contiguous_scan_";
if (reverse_) {
kname << "reverse_";
}
kname << ((inclusive_) ? "inclusive_" : "exclusive_");
switch (reduce_type_) {
case Scan::Sum:
kname << "sum_";
break;
case Scan::Prod:
kname << "prod_";
break;
case Scan::Max:
kname << "max_";
break;
case Scan::Min:
kname << "min_";
break;
}
kname << type_to_name(in) << "_" << type_to_name(out);
bool contiguous = in.strides()[axis_] == 1;
auto kernel = d.get_kernel(kname.str());
std::ostringstream kname;
kname << (contiguous ? "contig_" : "strided_");
kname << "scan_";
if (reverse_) {
kname << "reverse_";
}
kname << ((inclusive_) ? "inclusive_" : "exclusive_");
switch (reduce_type_) {
case Scan::Sum:
kname << "sum_";
break;
case Scan::Prod:
kname << "prod_";
break;
case Scan::Max:
kname << "max_";
break;
case Scan::Min:
kname << "min_";
break;
}
kname << type_to_name(in) << "_" << type_to_name(out);
auto kernel = get_scan_kernel(d, kname.str(), reverse_, inclusive_, in, out);
if (contiguous) {
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
@ -79,28 +83,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
kname << "strided_scan_";
if (reverse_) {
kname << "reverse_";
}
kname << ((inclusive_) ? "inclusive_" : "exclusive_");
switch (reduce_type_) {
case Scan::Sum:
kname << "sum_";
break;
case Scan::Prod:
kname << "prod_";
break;
case Scan::Max:
kname << "max_";
break;
case Scan::Min:
kname << "min_";
break;
}
kname << type_to_name(in) << "_" << type_to_name(out);
auto kernel = d.get_kernel(kname.str());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);

View File

@ -1,15 +1,18 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (!issubdtype(out.dtype(), floating)) {
@ -52,18 +55,17 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
const int simd_size = 32;
const int n_reads = SOFTMAX_N_READS;
const int looped_limit = SOFTMAX_LOOPED_LIMIT;
std::string op_name = "softmax_";
if (axis_size > looped_limit) {
op_name += "looped_";
}
std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
kernel_name += "softmax_";
if (in.dtype() != float32 && precise_) {
op_name += "precise_";
kernel_name += "precise_";
}
op_name += type_to_name(out);
kernel_name += type_to_name(out);
auto kernel = get_softmax_kernel(d, kernel_name, precise_, out);
auto& compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;

View File

@ -4,6 +4,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
@ -11,7 +12,6 @@ namespace mlx::core {
namespace {
template <bool ARGSORT>
void single_block_sort(
const Stream& s,
metal::Device& d,
@ -19,7 +19,8 @@ void single_block_sort(
array& out,
int axis,
int bn,
int tn) {
int tn,
bool argsort) {
// Prepare shapes
int n_rows = in.size() / in.shape(axis);
@ -46,19 +47,17 @@ void single_block_sort(
// Prepare kernel name
std::ostringstream kname;
if (ARGSORT) {
kname << "arg_";
kname << (contiguous_write ? "c" : "nc");
if (argsort) {
kname << "arg";
}
kname << "block_merge_sort_" << type_to_name(in) << "_" << type_to_name(out)
<< "_bn" << bn << "_tn" << tn;
if (!contiguous_write) {
kname << "_nc";
}
kname << "_block_sort_" << type_to_name(in) << "_" << type_to_name(out)
<< "_bn" << bn << "_tn" << tn;
auto kernel = get_sort_kernel(d, kname.str(), in, out, bn, tn);
// Prepare command encoder
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Set inputs
@ -81,7 +80,6 @@ void single_block_sort(
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
template <bool ARGSORT>
void multi_block_sort(
const Stream& s,
metal::Device& d,
@ -90,7 +88,8 @@ void multi_block_sort(
int axis,
int bn,
int tn,
int n_blocks) {
int n_blocks,
bool argsort) {
// Prepare shapes
int n_rows = in.size() / in.shape(axis);
@ -136,10 +135,10 @@ void multi_block_sort(
// Do blockwise sort
{
std::ostringstream kname;
kname << "mb_block_sort_" << type_to_name(dev_vals_0) << "_"
kname << "sort_mbsort_" << type_to_name(dev_vals_0) << "_"
<< type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;
auto kernel = d.get_kernel(kname.str());
auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
@ -175,10 +174,11 @@ void multi_block_sort(
// Do partition
{
std::ostringstream kname;
kname << "mb_block_partition_" << type_to_name(dev_vals_in) << "_"
kname << "partition_mbsort_" << type_to_name(dev_vals_in) << "_"
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
auto kernel = d.get_kernel(kname.str());
auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_output_array(block_partitions, 0);
@ -196,10 +196,11 @@ void multi_block_sort(
// Do merge
{
std::ostringstream kname;
kname << "mb_block_merge_" << type_to_name(dev_vals_in) << "_"
kname << "merge_mbsort_" << type_to_name(dev_vals_in) << "_"
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
auto kernel = d.get_kernel(kname.str());
auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(block_partitions, 0);
@ -219,7 +220,7 @@ void multi_block_sort(
}
// Copy outputs with appropriate strides
array strided_out_arr = ARGSORT ? dev_idxs_out : dev_vals_out;
array strided_out_arr = argsort ? dev_idxs_out : dev_vals_out;
if (axis == strided_out_arr.ndim() - 1) {
copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s);
@ -252,13 +253,13 @@ void multi_block_sort(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
template <bool ARGSORT>
void gpu_merge_sort(
const Stream& s,
metal::Device& d,
const array& in,
array& out,
int axis_) {
int axis_,
bool argsort) {
// Get size info
int axis = axis_ < 0 ? axis_ + in.ndim() : axis_;
int size_sorted_axis = in.shape(axis);
@ -284,9 +285,9 @@ void gpu_merge_sort(
int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block;
if (n_blocks > 1) {
return multi_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn, n_blocks);
return multi_block_sort(s, d, in, out, axis, bn, tn, n_blocks, argsort);
} else {
return single_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn);
return single_block_sort(s, d, in, out, axis, bn, tn, argsort);
}
}
@ -301,7 +302,7 @@ void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
auto& in = inputs[0];
gpu_merge_sort<true>(s, d, in, out, axis_);
gpu_merge_sort(s, d, in, out, axis_, true);
}
void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -313,7 +314,7 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
auto& in = inputs[0];
gpu_merge_sort<false>(s, d, in, out, axis_);
gpu_merge_sort(s, d, in, out, axis_, false);
}
void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -326,7 +327,7 @@ void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
auto& in = inputs[0];
gpu_merge_sort<true>(s, d, in, out, axis_);
gpu_merge_sort(s, d, in, out, axis_, true);
}
void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -339,7 +340,7 @@ void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
auto& in = inputs[0];
gpu_merge_sort<false>(s, d, in, out, axis_);
gpu_merge_sort(s, d, in, out, axis_, false);
}
} // namespace mlx::core

View File

@ -1565,8 +1565,9 @@ class Reduce : public UnaryPrimitive {
switch (reduce_type_) {
case And:
os << "And";
break;
case Or:
os << "And";
os << "Or";
break;
case Sum:
os << "Sum";
@ -1581,7 +1582,6 @@ class Reduce : public UnaryPrimitive {
os << "Max";
break;
}
os << " Reduce";
}
bool is_equivalent(const Primitive& other) const override;
@ -1647,7 +1647,6 @@ class Scan : public UnaryPrimitive {
os << "Max";
break;
}
os << " Reduce";
}
bool is_equivalent(const Primitive& other) const override;