mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
jagrit's commit files
This commit is contained in:
257
mlx/backend/metal/device.cpp
Normal file
257
mlx/backend/metal/device.cpp
Normal file
@@ -0,0 +1,257 @@
|
||||
#include <dlfcn.h>
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
#include <sstream>
|
||||
|
||||
#define NS_PRIVATE_IMPLEMENTATION
|
||||
#define CA_PRIVATE_IMPLEMENTATION
|
||||
#define MTL_PRIVATE_IMPLEMENTATION
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/mps/gemm.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
static Device metal_device_;
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO nicer way to set this or possibly expose as an environment variable
|
||||
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
||||
|
||||
static constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
|
||||
auto load_device() {
|
||||
MTL::Device* device = MTL::CreateSystemDefaultDevice();
|
||||
if (!device) {
|
||||
throw std::runtime_error("Failed to load device");
|
||||
}
|
||||
return device;
|
||||
}
|
||||
|
||||
std::pair<MTL::Library*, NS::Error*> load_library_from_path(
|
||||
MTL::Device* device,
|
||||
const char* path) {
|
||||
auto library = NS::String::string(path, NS::UTF8StringEncoding);
|
||||
NS::Error* error;
|
||||
auto lib = device->newLibrary(library, &error);
|
||||
|
||||
return std::make_pair(lib, error);
|
||||
}
|
||||
|
||||
MTL::Library* load_library(
|
||||
MTL::Device* device,
|
||||
const std::string& lib_name = "mlx",
|
||||
const char* lib_path = default_mtllib_path) {
|
||||
// Firstly, search for the metallib in the same path as this binary
|
||||
std::string first_path = get_colocated_mtllib_path(lib_name);
|
||||
if (first_path.size() != 0) {
|
||||
auto [lib, error] = load_library_from_path(device, first_path.c_str());
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
}
|
||||
|
||||
// Couldn't find it so let's load it from default_mtllib_path
|
||||
{
|
||||
auto [lib, error] = load_library_from_path(device, lib_path);
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << error->localizedDescription()->utf8String() << "\n"
|
||||
<< "Failed to load device library from <" << lib_path << ">"
|
||||
<< " or <" << first_path << ">.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Device::Device()
|
||||
: pool_(NS::AutoreleasePool::alloc()->init()),
|
||||
device_(load_device()),
|
||||
library_map_({{"mlx", load_library(device_)}}) {}
|
||||
|
||||
Device::~Device() {
|
||||
for (auto& q : queue_map_) {
|
||||
q.second->release();
|
||||
}
|
||||
for (auto& k : kernel_map_) {
|
||||
k.second->release();
|
||||
}
|
||||
for (auto& l : library_map_) {
|
||||
l.second->release();
|
||||
}
|
||||
device_->release();
|
||||
pool_->release();
|
||||
}
|
||||
|
||||
void Device::new_queue(int index) {
|
||||
// Multiple threads can ask the device for queues
|
||||
// We lock this as a critical section for safety
|
||||
const std::lock_guard<std::mutex> lock(mtx_);
|
||||
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
|
||||
if (!q) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Failed to make new command queue.");
|
||||
}
|
||||
queue_map_.insert({index, q});
|
||||
}
|
||||
|
||||
int Device::get_command_buffer_ops(int index) {
|
||||
auto bit = buffer_map_.find(index);
|
||||
return bit->second.first;
|
||||
}
|
||||
|
||||
void Device::increment_command_buffer_ops(int index) {
|
||||
auto bit = buffer_map_.find(index);
|
||||
bit->second.first++;
|
||||
}
|
||||
|
||||
MTL::CommandBuffer* Device::get_command_buffer(int index) {
|
||||
auto bit = buffer_map_.find(index);
|
||||
return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
|
||||
}
|
||||
|
||||
MTL::CommandBuffer* Device::new_command_buffer(int index) {
|
||||
auto qit = queue_map_.find(index);
|
||||
if (qit == queue_map_.end()) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Attempting to get command buffer for invalid queue.");
|
||||
}
|
||||
|
||||
auto cb = qit->second->commandBufferWithUnretainedReferences();
|
||||
|
||||
if (!cb) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Unable to create new command buffer");
|
||||
}
|
||||
|
||||
// Increment ref count so the buffer is not garbage collected
|
||||
cb->retain();
|
||||
|
||||
return buffer_map_.insert({index, {0, cb}}).first->second.second;
|
||||
}
|
||||
|
||||
void Device::commit_command_buffer(int index) {
|
||||
auto bit = buffer_map_.find(index);
|
||||
bit->second.second->commit();
|
||||
bit->second.second->release();
|
||||
buffer_map_.erase(bit);
|
||||
}
|
||||
|
||||
void Device::end_encoding(int index) {
|
||||
auto eit = encoder_map_.find(index);
|
||||
if (eit != encoder_map_.end()) {
|
||||
eit->second->endEncoding();
|
||||
eit->second->release();
|
||||
encoder_map_.erase(eit);
|
||||
}
|
||||
}
|
||||
|
||||
MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
|
||||
auto eit = encoder_map_.find(index);
|
||||
if (eit == encoder_map_.end()) {
|
||||
auto cb = get_command_buffer(index);
|
||||
auto compute_encoder = cb->computeCommandEncoder();
|
||||
// Increment ref count so the buffer is not garbage collected
|
||||
compute_encoder->retain();
|
||||
eit = encoder_map_.insert({index, compute_encoder}).first;
|
||||
}
|
||||
return eit->second;
|
||||
}
|
||||
|
||||
MTL::ArgumentEncoder* Device::argument_encoder(
|
||||
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const {
|
||||
// NB array here is already autoreleased but the returned argument
|
||||
// encoder is owned by the caller and must be released/autoreleased
|
||||
NS::Array* arg_desc_arr = NS::Array::array(
|
||||
reinterpret_cast<NS::Object* const*>(arg_descs.data()), arg_descs.size());
|
||||
return device_->newArgumentEncoder(arg_desc_arr);
|
||||
}
|
||||
|
||||
void Device::register_library(
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path) {
|
||||
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
||||
auto new_lib = load_library(device_, lib_name, lib_path.c_str());
|
||||
library_map_.insert({lib_name, new_lib});
|
||||
}
|
||||
}
|
||||
|
||||
void Device::register_library(
|
||||
const std::string& lib_name,
|
||||
const std::function<std::string(const std::string&)>& lib_path_func) {
|
||||
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
||||
std::string new_lib_path = lib_path_func(lib_name);
|
||||
auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());
|
||||
library_map_.insert({lib_name, new_lib});
|
||||
}
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& name,
|
||||
const std::string& lib_name /* = "mlx" */) {
|
||||
// Look for cached kernel
|
||||
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Prepare new kernel
|
||||
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib;
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
mtl_lib = it->second;
|
||||
} else { // Look for metallib alongside library
|
||||
register_library(lib_name);
|
||||
mtl_lib = library_map_[lib_name];
|
||||
}
|
||||
|
||||
// Pull kernel from library
|
||||
auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
|
||||
auto mtl_function = mtl_lib->newFunction(ns_name);
|
||||
|
||||
// Compile kernel to compute pipeline
|
||||
NS::Error* error = nullptr;
|
||||
MTL::ComputePipelineState* kernel;
|
||||
if (mtl_function) {
|
||||
kernel = device_->newComputePipelineState(mtl_function, &error);
|
||||
mtl_function->release();
|
||||
}
|
||||
if (!mtl_function || !kernel) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load kernel " << name << "\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Add kernel to cache
|
||||
kernel_map_.insert({name, kernel});
|
||||
return kernel;
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device) {
|
||||
return metal_device_;
|
||||
}
|
||||
|
||||
NS::AutoreleasePool*& thread_autorelease_pool() {
|
||||
static thread_local NS::AutoreleasePool* p =
|
||||
NS::AutoreleasePool::alloc()->init();
|
||||
return p;
|
||||
}
|
||||
|
||||
void new_stream(Stream stream) {
|
||||
thread_autorelease_pool();
|
||||
if (stream.device == mlx::core::Device::gpu) {
|
||||
device(stream.device).new_queue(stream.index);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::metal
|
10
mlx/backend/metal/fft.cpp
Normal file
10
mlx/backend/metal/fft.cpp
Normal file
@@ -0,0 +1,10 @@
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
throw std::runtime_error("[FFT] NYI for Metal backend.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
83
mlx/backend/metal/kernels/CMakeLists.txt
Normal file
83
mlx/backend/metal/kernels/CMakeLists.txt
Normal file
@@ -0,0 +1,83 @@
|
||||
set(
|
||||
HEADERS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
|
||||
)
|
||||
|
||||
set(
|
||||
KERNELS
|
||||
"arange"
|
||||
"arg_reduce"
|
||||
"binary"
|
||||
"conv"
|
||||
"copy"
|
||||
"gemm"
|
||||
"gemv"
|
||||
"random"
|
||||
"reduce"
|
||||
"scan"
|
||||
"softmax"
|
||||
"sort"
|
||||
"unary"
|
||||
"indexing"
|
||||
)
|
||||
|
||||
function(build_kernel KERNEL)
|
||||
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
|
||||
set(HEADERS_PADDED ${HEADERS})
|
||||
if(${KERNEL} STREQUAL "gemm")
|
||||
set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/gemm.h)
|
||||
endif()
|
||||
if(${KERNEL} STREQUAL "conv")
|
||||
set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/conv.h)
|
||||
endif()
|
||||
add_custom_command(
|
||||
COMMAND xcrun -sdk macosx metal -Wall -Wextra
|
||||
-fno-fast-math
|
||||
-c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR}
|
||||
-o ${KERNEL}.air
|
||||
DEPENDS ${SRCFILE} ${HEADERS_PADDED}
|
||||
OUTPUT ${KERNEL}.air
|
||||
COMMENT "Building ${KERNEL}.air"
|
||||
VERBATIM
|
||||
)
|
||||
endfunction(build_kernel)
|
||||
|
||||
foreach(KERNEL ${KERNELS})
|
||||
build_kernel(${KERNEL})
|
||||
set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
|
||||
endforeach()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
||||
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib
|
||||
DEPENDS ${KERNEL_AIR}
|
||||
COMMENT "Building mlx.metallib"
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
add_custom_target(
|
||||
mlx-metallib
|
||||
DEPENDS
|
||||
${MLX_METAL_PATH}/mlx.metallib
|
||||
)
|
||||
|
||||
add_dependencies(
|
||||
mlx
|
||||
mlx-metallib
|
||||
)
|
||||
|
||||
# Install metallib
|
||||
include(GNUInstallDirs)
|
||||
|
||||
install(
|
||||
FILES ${MLX_METAL_PATH}/mlx.metallib
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
COMPONENT metallib
|
||||
)
|
17
mlx/backend/metal/kernels/conv_params.h
Normal file
17
mlx/backend/metal/kernels/conv_params.h
Normal file
@@ -0,0 +1,17 @@
|
||||
#pragma once
|
||||
|
||||
template <int NDIM>
|
||||
struct MLXConvParams {
|
||||
const int N; // Batch size
|
||||
const int C; // In channels
|
||||
const int O; // Out channels
|
||||
const int iS[NDIM]; // Input spatial dim
|
||||
const int wS[NDIM]; // Weight spatial dim
|
||||
const int oS[NDIM]; // Output spatial dim
|
||||
const int str[NDIM]; // Kernel strides
|
||||
const int pad[NDIM]; // Input padding
|
||||
const int dil[NDIM]; // Kernel dilation
|
||||
const size_t in_strides[NDIM + 2]; // In strides
|
||||
const size_t wt_strides[NDIM + 2]; // Wt strides
|
||||
const size_t out_strides[NDIM + 2]; // Out strides
|
||||
};
|
253
mlx/backend/metal/kernels/indexing.metal
Normal file
253
mlx/backend/metal/kernels/indexing.metal
Normal file
@@ -0,0 +1,253 @@
|
||||
#include <metal_atomic>
|
||||
#include <metal_texture>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/reduce.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Gather kernel
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename IdxT, int NIDX>
|
||||
struct Indices {
|
||||
const array<device IdxT*, NIDX> buffers [[id(0)]];
|
||||
device int* shapes [[id(NIDX + 1)]];
|
||||
device size_t* strides [[id(NIDX + 2)]];
|
||||
const int ndim [[id(NIDX + 3)]];
|
||||
};
|
||||
|
||||
template <typename IdxT>
|
||||
inline size_t offset_neg_idx(IdxT idx, size_t size) {
|
||||
return (idx < 0) ? idx + size : idx;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline size_t offset_neg_idx(bool idx, size_t) {
|
||||
return idx;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline size_t offset_neg_idx(uint32_t idx, size_t) {
|
||||
return idx;
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT, int NIDX>
|
||||
[[kernel]] void gather(
|
||||
const device T *src [[buffer(0)]],
|
||||
const device Indices<IdxT, NIDX>& indices [[buffer(1)]],
|
||||
device T *out [[buffer(2)]],
|
||||
const device int *src_shape [[buffer(3)]],
|
||||
const device size_t *src_strides [[buffer(4)]],
|
||||
const device size_t& src_ndim [[buffer(5)]],
|
||||
const device int *slice_sizes [[buffer(6)]],
|
||||
const device size_t& slice_size [[buffer(7)]],
|
||||
const device int *axes [[buffer(8)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
|
||||
auto ind_idx = gid / slice_size;
|
||||
auto ind_offset = gid % slice_size;
|
||||
|
||||
size_t src_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
auto idx_loc = elem_to_loc(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
indices.ndim);
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(
|
||||
indices.buffers[i][idx_loc], src_shape[ax]);
|
||||
src_idx += idx_val * src_strides[ax];
|
||||
}
|
||||
|
||||
auto src_offset = elem_to_loc(
|
||||
ind_offset, slice_sizes, src_strides, src_ndim);
|
||||
out[gid] = src[src_idx + src_offset];
|
||||
}
|
||||
|
||||
#define instantiate_gather4(name, src_type, ind_type, nindex) \
|
||||
template [[host_name("gather" name "_" #nindex)]] \
|
||||
[[kernel]] void gather( \
|
||||
const device src_type *src [[buffer(0)]], \
|
||||
const device Indices<ind_type, nindex>& indices [[buffer(1)]], \
|
||||
device src_type *out [[buffer(2)]], \
|
||||
const device int *src_shape [[buffer(3)]], \
|
||||
const device size_t *src_strides [[buffer(4)]], \
|
||||
const device size_t& src_ndim [[buffer(5)]], \
|
||||
const device int *slice_sizes [[buffer(6)]], \
|
||||
const device size_t& slice_size [[buffer(7)]], \
|
||||
const device int* axes [[buffer(8)]], \
|
||||
uint gid [[thread_position_in_grid]]);
|
||||
|
||||
// Special for case NIDX=0
|
||||
instantiate_gather4("bool_", bool, bool, 0)
|
||||
instantiate_gather4("uint8", uint8_t, bool, 0)
|
||||
instantiate_gather4("uint16", uint16_t, bool, 0)
|
||||
instantiate_gather4("uint32", uint32_t, bool, 0)
|
||||
instantiate_gather4("uint64", uint64_t, bool, 0)
|
||||
instantiate_gather4("int8", int8_t, bool, 0)
|
||||
instantiate_gather4("int16", int16_t, bool, 0)
|
||||
instantiate_gather4("int32", int32_t, bool, 0)
|
||||
instantiate_gather4("int64", int64_t, bool, 0)
|
||||
instantiate_gather4("float16", half, bool, 0)
|
||||
instantiate_gather4("float32", float, bool, 0)
|
||||
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
|
||||
|
||||
#define instantiate_gather3(name, src_type, ind_type) \
|
||||
instantiate_gather4(name, src_type, ind_type, 1) \
|
||||
instantiate_gather4(name, src_type, ind_type, 2) \
|
||||
instantiate_gather4(name, src_type, ind_type, 3) \
|
||||
instantiate_gather4(name, src_type, ind_type, 4) \
|
||||
instantiate_gather4(name, src_type, ind_type, 5) \
|
||||
instantiate_gather4(name, src_type, ind_type, 6) \
|
||||
instantiate_gather4(name, src_type, ind_type, 7) \
|
||||
instantiate_gather4(name, src_type, ind_type, 8) \
|
||||
instantiate_gather4(name, src_type, ind_type, 9) \
|
||||
instantiate_gather4(name, src_type, ind_type, 10)
|
||||
|
||||
#define instantiate_gather(name, src_type) \
|
||||
instantiate_gather3(#name "bool_", src_type, bool) \
|
||||
instantiate_gather3(#name "uint8", src_type, uint8_t) \
|
||||
instantiate_gather3(#name "uint16", src_type, uint16_t) \
|
||||
instantiate_gather3(#name "uint32", src_type, uint32_t) \
|
||||
instantiate_gather3(#name "uint64", src_type, uint64_t) \
|
||||
instantiate_gather3(#name "int8", src_type, int8_t) \
|
||||
instantiate_gather3(#name "int16", src_type, int16_t) \
|
||||
instantiate_gather3(#name "int32", src_type, int32_t) \
|
||||
instantiate_gather3(#name "int64", src_type, int64_t)
|
||||
|
||||
instantiate_gather(bool_, bool)
|
||||
instantiate_gather(uint8, uint8_t)
|
||||
instantiate_gather(uint16, uint16_t)
|
||||
instantiate_gather(uint32, uint32_t)
|
||||
instantiate_gather(uint64, uint64_t)
|
||||
instantiate_gather(int8, int8_t)
|
||||
instantiate_gather(int16, int16_t)
|
||||
instantiate_gather(int32, int32_t)
|
||||
instantiate_gather(int64, int64_t)
|
||||
instantiate_gather(float16, half)
|
||||
instantiate_gather(float32, float)
|
||||
instantiate_gather(bfloat16, bfloat16_t)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Scatter kernel
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
[[kernel]] void scatter(
|
||||
const device Indices<IdxT, NIDX>& indices [[buffer(0)]],
|
||||
const device T *updates [[buffer(1)]],
|
||||
device mlx_atomic<T> *out [[buffer(2)]],
|
||||
const device int *upd_shape [[buffer(3)]],
|
||||
const device size_t *upd_strides [[buffer(4)]],
|
||||
const device size_t& upd_ndim [[buffer(5)]],
|
||||
const device size_t& upd_size [[buffer(6)]],
|
||||
const device int *out_shape [[buffer(7)]],
|
||||
const device size_t *out_strides [[buffer(8)]],
|
||||
const device size_t& out_ndim [[buffer(9)]],
|
||||
const device int* axes [[buffer(10)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
|
||||
Op op;
|
||||
auto ind_idx = gid / upd_size;
|
||||
auto ind_offset = gid % upd_size;
|
||||
|
||||
size_t out_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
auto idx_loc = elem_to_loc(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
indices.ndim);
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(
|
||||
indices.buffers[i][idx_loc], out_shape[ax]);
|
||||
out_idx += idx_val * out_strides[ax];
|
||||
}
|
||||
|
||||
auto out_offset = elem_to_loc(
|
||||
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||
auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
|
||||
|
||||
op.atomic_update(out + out_idx + out_offset, updates[upd_idx]);
|
||||
}
|
||||
|
||||
#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \
|
||||
template [[host_name("scatter" name "_" #nindex)]] \
|
||||
[[kernel]] void scatter<type, ind_type, op_type, nindex>( \
|
||||
const device Indices<ind_type, nindex>& indices [[buffer(0)]], \
|
||||
const device type *updates [[buffer(1)]], \
|
||||
device mlx_atomic<type> *out [[buffer(2)]], \
|
||||
const device int *upd_shape [[buffer(3)]], \
|
||||
const device size_t *upd_strides [[buffer(4)]], \
|
||||
const device size_t& upd_ndim [[buffer(5)]], \
|
||||
const device size_t& upd_size [[buffer(6)]], \
|
||||
const device int *out_shape [[buffer(7)]], \
|
||||
const device size_t *out_strides [[buffer(8)]], \
|
||||
const device size_t& out_ndim [[buffer(9)]], \
|
||||
const device int* axes [[buffer(10)]], \
|
||||
uint gid [[thread_position_in_grid]]);
|
||||
|
||||
// Special case NINDEX=0
|
||||
#define instantiate_scatter_nd0(name, type) \
|
||||
instantiate_scatter4(#name "none", type, bool, None, 0) \
|
||||
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
|
||||
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
|
||||
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
|
||||
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
|
||||
|
||||
#define instantiate_scatter3(name, type, ind_type, op_type) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 1) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 2) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 3) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 4) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 5) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 6) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 7) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 8) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 9) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 10)
|
||||
|
||||
#define instantiate_scatter2(name, type, ind_type) \
|
||||
instantiate_scatter3(name "_none", type, ind_type, None) \
|
||||
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
|
||||
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
|
||||
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
|
||||
instantiate_scatter3(name "_min", type, ind_type, Min<type>)
|
||||
|
||||
#define instantiate_scatter(name, type) \
|
||||
instantiate_scatter2(#name "bool_", type, bool) \
|
||||
instantiate_scatter2(#name "uint8", type, uint8_t) \
|
||||
instantiate_scatter2(#name "uint16", type, uint16_t) \
|
||||
instantiate_scatter2(#name "uint32", type, uint32_t) \
|
||||
instantiate_scatter2(#name "uint64", type, uint64_t) \
|
||||
instantiate_scatter2(#name "int8", type, int8_t) \
|
||||
instantiate_scatter2(#name "int16", type, int16_t) \
|
||||
instantiate_scatter2(#name "int32", type, int32_t) \
|
||||
instantiate_scatter2(#name "int64", type, int64_t)
|
||||
|
||||
// TODO uint64 and int64 unsupported
|
||||
instantiate_scatter_nd0(bool_, bool)
|
||||
instantiate_scatter_nd0(uint8, uint8_t)
|
||||
instantiate_scatter_nd0(uint16, uint16_t)
|
||||
instantiate_scatter_nd0(uint32, uint32_t)
|
||||
instantiate_scatter_nd0(int8, int8_t)
|
||||
instantiate_scatter_nd0(int16, int16_t)
|
||||
instantiate_scatter_nd0(int32, int32_t)
|
||||
instantiate_scatter_nd0(float16, half)
|
||||
instantiate_scatter_nd0(float32, float)
|
||||
instantiate_scatter_nd0(bfloat16, bfloat16_t)
|
||||
|
||||
instantiate_scatter(bool_, bool)
|
||||
instantiate_scatter(uint8, uint8_t)
|
||||
instantiate_scatter(uint16, uint16_t)
|
||||
instantiate_scatter(uint32, uint32_t)
|
||||
instantiate_scatter(int8, int8_t)
|
||||
instantiate_scatter(int16, int16_t)
|
||||
instantiate_scatter(int32, int32_t)
|
||||
instantiate_scatter(float16, half)
|
||||
instantiate_scatter(float32, float)
|
||||
instantiate_scatter(bfloat16, bfloat16_t)
|
174
mlx/backend/metal/kernels/reduce.h
Normal file
174
mlx/backend/metal/kernels/reduce.h
Normal file
@@ -0,0 +1,174 @@
|
||||
#pragma once
|
||||
|
||||
#include <metal_atomic>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/atomic.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
union bool4_or_uint {
|
||||
bool4 b;
|
||||
unsigned int i;
|
||||
};
|
||||
|
||||
struct None {
|
||||
template <typename T>
|
||||
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
|
||||
mlx_atomic_store_explicit(out, val, offset);
|
||||
}
|
||||
};
|
||||
|
||||
struct And {
|
||||
bool simd_reduce(bool val) {
|
||||
return simd_all(val);
|
||||
};
|
||||
|
||||
static constexpr constant bool init = true;
|
||||
|
||||
void atomic_update(
|
||||
device mlx_atomic<unsigned int>* out,
|
||||
bool val,
|
||||
int elem_idx,
|
||||
int offset = 0) {
|
||||
if (!val) {
|
||||
bool4_or_uint update;
|
||||
update.b = {true, true, true, true};
|
||||
update.b[elem_idx] = false;
|
||||
mlx_atomic_fetch_and_explicit(out, update.i, offset);
|
||||
}
|
||||
}
|
||||
|
||||
void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
|
||||
if (!val) {
|
||||
mlx_atomic_store_explicit(out, val, offset);
|
||||
}
|
||||
}
|
||||
|
||||
// Non atomic update
|
||||
void update(device bool* out, bool val) {
|
||||
*out &= val;
|
||||
}
|
||||
|
||||
// Operator
|
||||
bool operator()(bool a, bool b) {
|
||||
return a && b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Or {
|
||||
bool simd_reduce(bool val) {
|
||||
return simd_any(val);
|
||||
};
|
||||
|
||||
static constexpr constant bool init = false;
|
||||
|
||||
void atomic_update(
|
||||
device mlx_atomic<unsigned int>* out,
|
||||
bool val,
|
||||
int elem_idx,
|
||||
int offset = 0) {
|
||||
if (val) {
|
||||
bool4_or_uint update;
|
||||
update.b = {false, false, false, false};
|
||||
update.b[elem_idx] = true;
|
||||
mlx_atomic_fetch_or_explicit(out, update.i, offset);
|
||||
}
|
||||
}
|
||||
|
||||
void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
|
||||
if (val) {
|
||||
mlx_atomic_store_explicit(out, val, offset);
|
||||
}
|
||||
}
|
||||
|
||||
// Non atomic update
|
||||
void update(device bool* out, bool val) {
|
||||
*out |= val;
|
||||
}
|
||||
|
||||
// Operator
|
||||
bool operator()(bool a, bool b) {
|
||||
return a || b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename U>
|
||||
struct Sum {
|
||||
template <typename T>
|
||||
T simd_reduce(T val) {
|
||||
return simd_sum(val);
|
||||
};
|
||||
|
||||
static constexpr constant U init = U(0);
|
||||
|
||||
template <typename T>
|
||||
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
|
||||
mlx_atomic_fetch_add_explicit(out, val, offset);
|
||||
}
|
||||
|
||||
// Operator
|
||||
U operator()(U a, U b) {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename U>
|
||||
struct Prod {
|
||||
template <typename T>
|
||||
T simd_reduce(T val) {
|
||||
return simd_product(val);
|
||||
};
|
||||
|
||||
static constexpr constant U init = U(1);
|
||||
|
||||
template <typename T>
|
||||
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
|
||||
mlx_atomic_fetch_mul_explicit(out, val, offset);
|
||||
}
|
||||
|
||||
// Operator
|
||||
U operator()(U a, U b) {
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename U>
|
||||
struct Min {
|
||||
template <typename T>
|
||||
T simd_reduce(T val) {
|
||||
return simd_min(val);
|
||||
};
|
||||
|
||||
static constexpr constant U init = Limits<U>::max;
|
||||
|
||||
template <typename T>
|
||||
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
|
||||
mlx_atomic_fetch_min_explicit(out, val, offset);
|
||||
}
|
||||
|
||||
// Operator
|
||||
U operator()(U a, U b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename U>
|
||||
struct Max {
|
||||
template <typename T>
|
||||
T simd_reduce(T val) {
|
||||
return simd_max(val);
|
||||
};
|
||||
|
||||
static constexpr constant U init = Limits<U>::min;
|
||||
|
||||
template <typename T>
|
||||
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
|
||||
mlx_atomic_fetch_max_explicit(out, val, offset);
|
||||
}
|
||||
|
||||
// Operator
|
||||
U operator()(U a, U b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
};
|
492
mlx/backend/metal/kernels/scan.metal
Normal file
492
mlx/backend/metal/kernels/scan.metal
Normal file
@@ -0,0 +1,492 @@
|
||||
#include <metal_math>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
template <typename U>
|
||||
struct CumSum {
|
||||
static constexpr constant U init = static_cast<U>(0);
|
||||
|
||||
template <typename T>
|
||||
U operator()(U a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
U simd_scan(U x) {
|
||||
return simd_prefix_inclusive_sum(x);
|
||||
}
|
||||
|
||||
U simd_exclusive_scan(U x) {
|
||||
return simd_prefix_exclusive_sum(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename U>
|
||||
struct CumProd {
|
||||
static constexpr constant U init = static_cast<U>(1.0f);
|
||||
|
||||
template <typename T>
|
||||
U operator()(U a, T b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
U simd_scan(U x) {
|
||||
return simd_prefix_inclusive_product(x);
|
||||
}
|
||||
|
||||
U simd_exclusive_scan(U x) {
|
||||
return simd_prefix_exclusive_product(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CumProd<bool> {
|
||||
static constexpr constant bool init = true;
|
||||
|
||||
template <typename T>
|
||||
bool operator()(bool a, T b) {
|
||||
return a & static_cast<bool>(b);
|
||||
}
|
||||
|
||||
bool simd_scan(bool x) {
|
||||
for (int i=1; i<=16; i*=2) {
|
||||
bool other = simd_shuffle_up(x, i);
|
||||
x &= other;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
bool simd_exclusive_scan(bool x) {
|
||||
x = simd_scan(x);
|
||||
return simd_shuffle_and_fill_up(x, init, 1);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename U>
|
||||
struct CumMax {
|
||||
static constexpr constant U init = Limits<U>::min;
|
||||
|
||||
template <typename T>
|
||||
U operator()(U a, T b) {
|
||||
return (a >= b) ? a : b;
|
||||
}
|
||||
|
||||
U simd_scan(U x) {
|
||||
for (int i=1; i<=16; i*=2) {
|
||||
U other = simd_shuffle_up(x, i);
|
||||
x = (x >= other) ? x : other;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
U simd_exclusive_scan(U x) {
|
||||
x = simd_scan(x);
|
||||
return simd_shuffle_and_fill_up(x, init, 1);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename U>
|
||||
struct CumMin {
|
||||
static constexpr constant U init = Limits<U>::max;
|
||||
|
||||
template <typename T>
|
||||
U operator()(U a, T b) {
|
||||
return (a <= b) ? a : b;
|
||||
}
|
||||
|
||||
U simd_scan(U x) {
|
||||
for (int i=1; i<=16; i*=2) {
|
||||
U other = simd_shuffle_up(x, i);
|
||||
x = (x <= other) ? x : other;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
U simd_exclusive_scan(U x) {
|
||||
x = simd_scan(x);
|
||||
return simd_shuffle_and_fill_up(x, init, 1);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, int N_READS, bool reverse>
|
||||
inline void load_unsafe(U values[N_READS], const device T * input) {
|
||||
if (reverse) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
values[N_READS-i-1] = input[i];
|
||||
}
|
||||
} else {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
values[i] = input[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, int N_READS, bool reverse>
|
||||
inline void load_safe(U values[N_READS], const device T * input, int start, int total, U init) {
|
||||
if (reverse) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
values[N_READS-i-1] = (start + N_READS - i - 1 < total) ? input[i] : init;
|
||||
}
|
||||
} else {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
values[i] = (start + i < total) ? input[i] : init;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int N_READS, bool reverse>
|
||||
inline void write_unsafe(U values[N_READS], device U * out) {
|
||||
if (reverse) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
out[i] = values[N_READS-i-1];
|
||||
}
|
||||
} else {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
out[i] = values[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int N_READS, bool reverse>
|
||||
inline void write_safe(U values[N_READS], device U * out, int start, int total) {
|
||||
if (reverse) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
if (start + N_READS - i - 1 < total) {
|
||||
out[i] = values[N_READS-i-1];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
if (start + i < total) {
|
||||
out[i] = values[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
|
||||
[[kernel]] void contiguous_scan(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t & axis_size [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_size [[threads_per_simdgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
Op op;
|
||||
|
||||
// Position the pointers
|
||||
in += (gid / lsize) * axis_size;
|
||||
out += (gid / lsize) * axis_size;
|
||||
|
||||
// Compute the number of simd_groups
|
||||
uint simd_groups = lsize / simd_size;
|
||||
|
||||
// Allocate memory
|
||||
U prefix = Op::init;
|
||||
U values[N_READS];
|
||||
threadgroup U simdgroup_sums[32];
|
||||
|
||||
// Loop over the reduced axis in blocks of size ceildiv(axis_size, N_READS*lsize)
|
||||
// Read block
|
||||
// Compute inclusive scan of the block
|
||||
// Compute inclusive scan per thread
|
||||
// Compute exclusive scan of thread sums in simdgroup
|
||||
// Write simdgroup sums in SM
|
||||
// Compute exclusive scan of simdgroup sums
|
||||
// Compute the output by scanning prefix, prev_simdgroup, prev_thread, value
|
||||
// Write block
|
||||
|
||||
for (uint r = 0; r < ceildiv(axis_size, N_READS*lsize); r++) {
|
||||
// Compute the block offset
|
||||
uint offset = r*lsize*N_READS + lid*N_READS;
|
||||
|
||||
// Read the values
|
||||
if (reverse) {
|
||||
if ((offset + N_READS) < axis_size) {
|
||||
load_unsafe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS);
|
||||
} else {
|
||||
load_safe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS, offset, axis_size, Op::init);
|
||||
}
|
||||
} else {
|
||||
if ((offset + N_READS) < axis_size) {
|
||||
load_unsafe<T, U, N_READS, reverse>(values, in + offset);
|
||||
} else {
|
||||
load_safe<T, U, N_READS, reverse>(values, in + offset, offset, axis_size, Op::init);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute an inclusive scan per thread
|
||||
for (int i=1; i<N_READS; i++) {
|
||||
values[i] = op(values[i], values[i-1]);
|
||||
}
|
||||
|
||||
// Compute exclusive scan of thread sums
|
||||
U prev_thread = op.simd_exclusive_scan(values[N_READS-1]);
|
||||
|
||||
// Write simdgroup_sums to SM
|
||||
if (simd_lane_id == simd_size - 1) {
|
||||
simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Compute exclusive scan of simdgroup_sums
|
||||
if (simd_group_id == 0) {
|
||||
U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
|
||||
simdgroup_sums[simd_lane_id] = prev_simdgroup;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Compute the output
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
values[i] = op(values[i], prefix);
|
||||
values[i] = op(values[i], simdgroup_sums[simd_group_id]);
|
||||
values[i] = op(values[i], prev_thread);
|
||||
}
|
||||
|
||||
// Write the values
|
||||
if (reverse) {
|
||||
if (inclusive) {
|
||||
if ((offset + N_READS) < axis_size) {
|
||||
write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS);
|
||||
} else {
|
||||
write_safe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS, offset, axis_size);
|
||||
}
|
||||
} else {
|
||||
if (lid == 0 && offset == 0) {
|
||||
out[axis_size-1] = Op::init;
|
||||
}
|
||||
if ((offset + N_READS + 1) < axis_size) {
|
||||
write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS);
|
||||
} else {
|
||||
write_safe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS, offset + 1, axis_size);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (inclusive) {
|
||||
if ((offset + N_READS) < axis_size) {
|
||||
write_unsafe<U, N_READS, reverse>(values, out + offset);
|
||||
} else {
|
||||
write_safe<U, N_READS, reverse>(values, out + offset, offset, axis_size);
|
||||
}
|
||||
} else {
|
||||
if (lid == 0 && offset == 0) {
|
||||
out[0] = Op::init;
|
||||
}
|
||||
if ((offset + N_READS + 1) < axis_size) {
|
||||
write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
|
||||
} else {
|
||||
write_safe<U, N_READS, reverse>(values, out + offset + 1, offset + 1, axis_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Share the prefix
|
||||
if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
|
||||
simdgroup_sums[0] = values[N_READS-1];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
prefix = simdgroup_sums[0];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
|
||||
[[kernel]] void strided_scan(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t & axis_size [[buffer(2)]],
|
||||
const constant size_t & stride [[buffer(3)]],
|
||||
uint2 gid [[threadgroup_position_in_grid]],
|
||||
uint2 lid [[thread_position_in_threadgroup]],
|
||||
uint2 lsize [[threads_per_threadgroup]],
|
||||
uint simd_size [[threads_per_simdgroup]]) {
|
||||
Op op;
|
||||
|
||||
// Allocate memory
|
||||
threadgroup U read_buffer[N_READS*32*32 + N_READS*32];
|
||||
U values[N_READS];
|
||||
U prefix[N_READS];
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
prefix[i] = Op::init;
|
||||
}
|
||||
|
||||
// Compute offsets
|
||||
int offset = gid.y * axis_size * stride;
|
||||
int global_index_x = gid.x * lsize.y * N_READS;
|
||||
|
||||
for (uint j=0; j<axis_size; j+=simd_size) {
|
||||
// Calculate the indices for the current thread
|
||||
uint index_y = j + lid.y;
|
||||
uint check_index_y = index_y;
|
||||
uint index_x = global_index_x + lid.x * N_READS;
|
||||
if (reverse) {
|
||||
index_y = axis_size - 1 - index_y;
|
||||
}
|
||||
|
||||
// Read in SM
|
||||
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
|
||||
}
|
||||
} else {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
if (check_index_y < axis_size && (index_x + i) < stride) {
|
||||
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
|
||||
} else {
|
||||
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = Op::init;
|
||||
}
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Read strided into registers
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
values[i] = read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
|
||||
}
|
||||
// Do we need the following barrier? Shouldn't all simd threads execute simultaneously?
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Perform the scan
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
values[i] = op.simd_scan(values[i]);
|
||||
values[i] = op(values[i], prefix[i]);
|
||||
prefix[i] = simd_shuffle(values[i], simd_size-1);
|
||||
}
|
||||
|
||||
// Write to SM
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] = values[i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write to device memory
|
||||
if (!inclusive) {
|
||||
if (check_index_y == 0) {
|
||||
if ((index_x + N_READS) < stride) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
out[offset + index_y * stride + index_x + i] = Op::init;
|
||||
}
|
||||
} else {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
if ((index_x + i) < stride) {
|
||||
out[offset + index_y * stride + index_x + i] = Op::init;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (reverse) {
|
||||
index_y -= 1;
|
||||
check_index_y += 1;
|
||||
} else {
|
||||
index_y += 1;
|
||||
check_index_y += 1;
|
||||
}
|
||||
}
|
||||
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
|
||||
}
|
||||
} else {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
if (check_index_y < axis_size && (index_x + i) < stride) {
|
||||
out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_contiguous_scan(name, itype, otype, op, inclusive, reverse, nreads) \
|
||||
template [[host_name("contiguous_scan_" #name)]] \
|
||||
[[kernel]] void contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t & axis_size [[buffer(2)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_size [[threads_per_simdgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_strided_scan(name, itype, otype, op, inclusive, reverse, nreads) \
|
||||
template [[host_name("strided_scan_" #name)]] \
|
||||
[[kernel]] void strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant size_t & axis_size [[buffer(2)]], \
|
||||
const constant size_t & stride [[buffer(3)]], \
|
||||
uint2 gid [[thread_position_in_grid]], \
|
||||
uint2 lid [[thread_position_in_threadgroup]], \
|
||||
uint2 lsize [[threads_per_threadgroup]], \
|
||||
uint simd_size [[threads_per_simdgroup]]);
|
||||
|
||||
|
||||
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
|
||||
instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
|
||||
instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
|
||||
instantiate_contiguous_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
|
||||
instantiate_contiguous_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) \
|
||||
instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
|
||||
instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
|
||||
instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
|
||||
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)
|
||||
|
||||
instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
|
||||
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
|
||||
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
|
||||
instantiate_scan_helper(sum_uint32_uint32, uint32_t, uint32_t, CumSum, 4)
|
||||
//instantiate_scan_helper(sum_uint64_uint64, uint64_t, uint64_t, CumSum, 2)
|
||||
instantiate_scan_helper(sum_int8_int8, int8_t, int8_t, CumSum, 4)
|
||||
instantiate_scan_helper(sum_int16_int16, int16_t, int16_t, CumSum, 4)
|
||||
instantiate_scan_helper(sum_int32_int32, int32_t, int32_t, CumSum, 4)
|
||||
//instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
|
||||
instantiate_scan_helper(sum_float16_float16, half, half, CumSum, 4)
|
||||
instantiate_scan_helper(sum_float32_float32, float, float, CumSum, 4)
|
||||
//instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
|
||||
//instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum)
|
||||
//instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
|
||||
instantiate_scan_helper(prod_uint8_uint8, uint8_t, uint8_t, CumProd, 4)
|
||||
instantiate_scan_helper(prod_uint16_uint16, uint16_t, uint16_t, CumProd, 4)
|
||||
instantiate_scan_helper(prod_uint32_uint32, uint32_t, uint32_t, CumProd, 4)
|
||||
//instantiate_scan_helper(prod_uint64_uint64, uint64_t, uint64_t, CumProd, 2)
|
||||
instantiate_scan_helper(prod_int8_int8, int8_t, int8_t, CumProd, 4)
|
||||
instantiate_scan_helper(prod_int16_int16, int16_t, int16_t, CumProd, 4)
|
||||
instantiate_scan_helper(prod_int32_int32, int32_t, int32_t, CumProd, 4)
|
||||
//instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
|
||||
instantiate_scan_helper(prod_float16_float16, half, half, CumProd, 4)
|
||||
instantiate_scan_helper(prod_float32_float32, float, float, CumProd, 4)
|
||||
//instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
|
||||
//instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd)
|
||||
//instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
|
||||
instantiate_scan_helper(max_uint8_uint8, uint8_t, uint8_t, CumMax, 4)
|
||||
instantiate_scan_helper(max_uint16_uint16, uint16_t, uint16_t, CumMax, 4)
|
||||
instantiate_scan_helper(max_uint32_uint32, uint32_t, uint32_t, CumMax, 4)
|
||||
//instantiate_scan_helper(max_uint64_uint64, uint64_t, uint64_t, CumMax, 2)
|
||||
instantiate_scan_helper(max_int8_int8, int8_t, int8_t, CumMax, 4)
|
||||
instantiate_scan_helper(max_int16_int16, int16_t, int16_t, CumMax, 4)
|
||||
instantiate_scan_helper(max_int32_int32, int32_t, int32_t, CumMax, 4)
|
||||
//instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
|
||||
instantiate_scan_helper(max_float16_float16, half, half, CumMax, 4)
|
||||
instantiate_scan_helper(max_float32_float32, float, float, CumMax, 4)
|
||||
//instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
|
||||
//instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax)
|
||||
//instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
|
||||
instantiate_scan_helper(min_uint8_uint8, uint8_t, uint8_t, CumMin, 4)
|
||||
instantiate_scan_helper(min_uint16_uint16, uint16_t, uint16_t, CumMin, 4)
|
||||
instantiate_scan_helper(min_uint32_uint32, uint32_t, uint32_t, CumMin, 4)
|
||||
//instantiate_scan_helper(min_uint64_uint64, uint64_t, uint64_t, CumMin, 2)
|
||||
instantiate_scan_helper(min_int8_int8, int8_t, int8_t, CumMin, 4)
|
||||
instantiate_scan_helper(min_int16_int16, int16_t, int16_t, CumMin, 4)
|
||||
instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMin, 4)
|
||||
//instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
|
||||
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
|
||||
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
|
||||
//instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
|
||||
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin)
|
88
mlx/backend/metal/metal.cpp
Normal file
88
mlx/backend/metal/metal.cpp
Normal file
@@ -0,0 +1,88 @@
|
||||
#include <cstdlib>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
int max_ops_per_buffer() {
|
||||
auto get_val = []() {
|
||||
if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
|
||||
return atoi(buff_str);
|
||||
} else {
|
||||
return 10;
|
||||
}
|
||||
};
|
||||
static int max_ops_per_buffer_ = get_val();
|
||||
return max_ops_per_buffer_;
|
||||
}
|
||||
|
||||
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
|
||||
|
||||
MTL::CommandBuffer* increment_command_buffer(Stream s) {
|
||||
auto& d = metal::device(s.device);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
if (command_buffer == nullptr ||
|
||||
d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
|
||||
if (command_buffer != nullptr) {
|
||||
d.end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[s](MTL::CommandBuffer*) { scheduler::notify_task_completion(s); });
|
||||
d.commit_command_buffer(s.index);
|
||||
}
|
||||
command_buffer = d.new_command_buffer(s.index);
|
||||
}
|
||||
d.increment_command_buffer_ops(s.index);
|
||||
return command_buffer;
|
||||
}
|
||||
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p,
|
||||
bool retain_graph) {
|
||||
auto task =
|
||||
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
||||
for (auto& d : deps) {
|
||||
d.wait();
|
||||
}
|
||||
auto s = arr.primitive().stream();
|
||||
auto command_buffer = increment_command_buffer(s);
|
||||
arr.primitive().eval_gpu(arr.inputs(), arr);
|
||||
if (p) {
|
||||
metal::device(s.device).end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[retain_graph, s, arr, p = std::move(p)](
|
||||
MTL::CommandBuffer*) mutable {
|
||||
if (!retain_graph) {
|
||||
arr.detach();
|
||||
}
|
||||
p->set_value();
|
||||
// Signal this thread to clear the pool on a synchroniztion.
|
||||
scheduler::enqueue(s, []() {
|
||||
thread_autorelease_pool()->release();
|
||||
thread_autorelease_pool() =
|
||||
NS::AutoreleasePool::alloc()->init();
|
||||
});
|
||||
scheduler::notify_task_completion(s);
|
||||
});
|
||||
metal::device(s.device).commit_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
|
||||
if (!retain_graph) {
|
||||
arr.detach();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
return task;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::metal
|
28
mlx/backend/metal/metal.h
Normal file
28
mlx/backend/metal/metal.h
Normal file
@@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
constexpr bool is_available() {
|
||||
#ifdef _METAL_
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
void new_stream(Stream stream);
|
||||
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p,
|
||||
bool retain_graph);
|
||||
|
||||
} // namespace mlx::core::metal
|
82
mlx/backend/metal/softmax.cpp
Normal file
82
mlx/backend/metal/softmax.cpp
Normal file
@@ -0,0 +1,82 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
throw std::runtime_error(
|
||||
"[softmax] Does not support non-floating point types.");
|
||||
}
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
if (x.strides()[x.ndim() - 1] == 1) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
const array& in = check_input(inputs[0]);
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
|
||||
int axis_size = in.shape().back();
|
||||
int n_rows = in.data_size() / axis_size;
|
||||
|
||||
const int simd_size = 32;
|
||||
const int n_reads = SOFTMAX_N_READS;
|
||||
const int looped_limit = SOFTMAX_LOOPED_LIMIT;
|
||||
std::string op_name = "softmax_";
|
||||
if (axis_size > looped_limit) {
|
||||
op_name += "looped_";
|
||||
}
|
||||
op_name += type_to_name(out);
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name);
|
||||
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
|
||||
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
|
||||
size_t threadgroup_size = simd_size * simds_needed;
|
||||
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
} else {
|
||||
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
}
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(&axis_size, sizeof(int), 2);
|
||||
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 0);
|
||||
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
336
mlx/backend/metal/sort.cpp
Normal file
336
mlx/backend/metal/sort.cpp
Normal file
@@ -0,0 +1,336 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <bool ARGSORT>
|
||||
void single_block_sort(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
array& out,
|
||||
int axis,
|
||||
int bn,
|
||||
int tn) {
|
||||
// Prepare shapes
|
||||
int n_rows = in.size() / in.shape(axis);
|
||||
|
||||
std::vector<size_t> nc_str = in.strides();
|
||||
nc_str.erase(nc_str.begin() + axis);
|
||||
|
||||
std::vector<int> nc_shape = in.shape();
|
||||
nc_shape.erase(nc_shape.begin() + axis);
|
||||
|
||||
int nc_dim = nc_shape.size();
|
||||
|
||||
int size_sorted_axis = in.shape(axis);
|
||||
int stride_sorted_axis = in.strides()[axis];
|
||||
int stride_segment_axis = *std::min_element(nc_str.begin(), nc_str.end());
|
||||
|
||||
// Check if remaining strides are contiguous
|
||||
bool contiguous_write = true;
|
||||
if (axis != in.ndim() - 1 && axis != 0) {
|
||||
for (int i = 0; i < nc_str.size() - 1; ++i) {
|
||||
size_t expected = nc_str[i + 1] * nc_str[i + 1];
|
||||
contiguous_write &= (nc_str[i] == expected);
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
if (ARGSORT) {
|
||||
kname << "arg_";
|
||||
}
|
||||
kname << "block_merge_sort_" << type_to_name(in) << "_" << type_to_name(out)
|
||||
<< "_bn" << bn << "_tn" << tn;
|
||||
|
||||
if (!contiguous_write) {
|
||||
kname << "_nc";
|
||||
}

// Prepare command encoder
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);

// Set inputs
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);

if (contiguous_write) {
compute_encoder->setBytes(&stride_segment_axis, sizeof(int), 4);
} else {
compute_encoder->setBytes(&nc_dim, sizeof(int), 4);
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 5);
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 6);
}

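// Launch one bn-thread threadgroup per row of the input.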
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);

compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}

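// Multi-block path: when a row is larger than one bn * tn tile, it is split
// into n_blocks tiles that are sorted independently, then merged over
// log2(n_blocks) passes, ping-ponging between two pairs of scratch buffers.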
template <bool ARGSORT>
void multi_block_sort(
const Stream& s,
metal::Device& d,
const array& in,
array& out,
int axis,
int bn,
int tn,
int n_blocks) {
// Prepare shapes
int n_rows = in.size() / in.shape(axis);

std::vector<size_t> nc_str = in.strides();
nc_str.erase(nc_str.begin() + axis);

std::vector<int> nc_shape = in.shape();
nc_shape.erase(nc_shape.begin() + axis);

int nc_dim = nc_shape.size();

int size_sorted_axis = in.shape(axis);
int stride_sorted_axis = in.strides()[axis];

// Make temporary copies
array dev_vals_0({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});
array dev_vals_1({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});

array dev_idxs_0({n_rows, size_sorted_axis}, uint32, nullptr, {});
array dev_idxs_1({n_rows, size_sorted_axis}, uint32, nullptr, {});

array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {});

// Do allocations
dev_vals_0.set_data(allocator::malloc_or_wait(dev_vals_0.nbytes()));
dev_vals_1.set_data(allocator::malloc_or_wait(dev_vals_1.nbytes()));
dev_idxs_0.set_data(allocator::malloc_or_wait(dev_idxs_0.nbytes()));
dev_idxs_1.set_data(allocator::malloc_or_wait(dev_idxs_1.nbytes()));
block_partitions.set_data(
allocator::malloc_or_wait(block_partitions.nbytes()));

std::vector<array> copies = {
dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
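// These temporaries are captured by the command buffer's completion handler
// below so their buffers stay alive until the GPU work has finished.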

// Prepare command encoder
auto compute_encoder = d.get_command_encoder(s.index);

// Do blockwise sort
{
std::ostringstream kname;
kname << "mb_block_sort_" << type_to_name(dev_vals_0) << "_"
<< type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;

auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);

set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, dev_vals_0, 1);
set_array_buffer(compute_encoder, dev_idxs_0, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 7);

MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);

compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}

// Do merges
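// Each pass doubles merge_tiles: sorted runs of merge_tiles / 2 blocks are
// merged into runs of merge_tiles blocks, alternating (ping-ponging) between
// the two value/index buffer pairs.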
bool ping = false;
array dev_vals_in = dev_vals_0;
array dev_idxs_in = dev_idxs_0;
array dev_vals_out = dev_vals_1;
array dev_idxs_out = dev_idxs_1;
for (int merge_tiles = 2; merge_tiles <= n_blocks; merge_tiles *= 2) {
dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
dev_vals_out = ping ? dev_vals_0 : dev_vals_1;
dev_idxs_out = ping ? dev_idxs_0 : dev_idxs_1;
ping = !ping;

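// First compute partition points (n_blocks + 1 per row) for this pass so that
// every merge block works on a bounded slice of the two runs it merges.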
// Do partition
{
std::ostringstream kname;
kname << "mb_block_partiton_" << type_to_name(dev_vals_in) << "_"
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;

auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);

set_array_buffer(compute_encoder, block_partitions, 0);
set_array_buffer(compute_encoder, dev_vals_in, 1);
set_array_buffer(compute_encoder, dev_idxs_in, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);

MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);

compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}

// Do merge
{
std::ostringstream kname;
kname << "mb_block_merge_" << type_to_name(dev_vals_in) << "_"
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;

auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);

set_array_buffer(compute_encoder, block_partitions, 0);
set_array_buffer(compute_encoder, dev_vals_in, 1);
set_array_buffer(compute_encoder, dev_idxs_in, 2);
set_array_buffer(compute_encoder, dev_vals_out, 3);
set_array_buffer(compute_encoder, dev_idxs_out, 4);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 6);
compute_encoder->setBytes(&n_blocks, sizeof(int), 7);

MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);

compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
}

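// The final result lives in dev_vals_out / dev_idxs_out with shape
// (n_rows, size_sorted_axis); copy it back into `out` with the strides
// implied by the original axis order.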
// Copy outputs with appropriate strides
array strided_out_arr = ARGSORT ? dev_idxs_out : dev_vals_out;

if (axis == strided_out_arr.ndim() - 1) {
copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s);
} else {
std::vector<int> strided_out_shape = strided_out_arr.shape();
std::vector<size_t> strided_out_str = strided_out_arr.strides();

int out_axis_shape = strided_out_shape[axis];
int out_axis_str = strided_out_str[axis];

strided_out_shape.erase(strided_out_shape.begin() + axis);
strided_out_str.erase(strided_out_str.begin() + axis);

strided_out_shape.push_back(out_axis_shape);
strided_out_str.push_back(out_axis_str);

array strided_out_slice(strided_out_shape, out.dtype(), nullptr, {});
strided_out_slice.copy_shared_buffer(
strided_out_arr,
strided_out_str,
strided_out_arr.flags(),
strided_out_arr.size(),
0);

copy_gpu_inplace(strided_out_slice, out, CopyType::General, s);
}

// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}

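// Dispatcher: picks the block size (bn) and per-thread elements (tn) from the
// size of the sorted axis, then routes to the single- or multi-block sort.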
template <bool ARGSORT>
void gpu_merge_sort(
const Stream& s,
metal::Device& d,
const array& in,
array& out,
int axis_) {
// Get size info
int axis = axis_ < 0 ? axis_ + in.ndim() : axis_;
int size_sorted_axis = in.shape(axis);

// Get kernel size
int tn = 8;
int bn = 128;
int potential_bn = (size_sorted_axis + tn - 1) / tn;

if (potential_bn > 256) {
bn = 512;
} else if (potential_bn > 128) {
bn = 256;
} else {
bn = 128;
}

// Use a smaller block for wide (> 4 byte) element types
if (bn == 512 && size_of(in.dtype()) > 4) {
bn = 256;
}

int n_per_block = bn * tn;
int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block;

if (n_blocks > 1) {
return multi_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn, n_blocks);
} else {
return single_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn);
}
}

} // namespace

void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);

out.set_data(allocator::malloc_or_wait(out.nbytes()));

auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];

gpu_merge_sort<true>(s, d, in, out, axis_);
}

void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);

out.set_data(allocator::malloc_or_wait(out.nbytes()));

auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];

gpu_merge_sort<false>(s, d, in, out, axis_);
}

void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
// We direct arg partition to sort for now
assert(inputs.size() == 1);

out.set_data(allocator::malloc_or_wait(out.nbytes()));

auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];

gpu_merge_sort<true>(s, d, in, out, axis_);
}

void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
// We direct partition to sort for now
assert(inputs.size() == 1);

out.set_data(allocator::malloc_or_wait(out.nbytes()));

auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];

gpu_merge_sort<false>(s, d, in, out, axis_);
}

} // namespace mlx::core