jagrit's commit files

Jagrit Digani
2023-11-29 10:52:08 -08:00
parent d1f86272a2
commit e6306cfee9
74 changed files with 15964 additions and 2 deletions


@@ -0,0 +1,257 @@
#include <dlfcn.h>
#include <cstdlib>
#include <filesystem>
#include <sstream>
#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/mps/gemm.h"
namespace fs = std::filesystem;
namespace mlx::core::metal {
static Device metal_device_;
namespace {
// TODO nicer way to set this or possibly expose as an environment variable
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
static constexpr const char* default_mtllib_path = METAL_PATH;
auto load_device() {
MTL::Device* device = MTL::CreateSystemDefaultDevice();
if (!device) {
throw std::runtime_error("Failed to load device");
}
return device;
}
std::pair<MTL::Library*, NS::Error*> load_library_from_path(
MTL::Device* device,
const char* path) {
auto library = NS::String::string(path, NS::UTF8StringEncoding);
NS::Error* error = nullptr;
auto lib = device->newLibrary(library, &error);
return std::make_pair(lib, error);
}
MTL::Library* load_library(
MTL::Device* device,
const std::string& lib_name = "mlx",
const char* lib_path = default_mtllib_path) {
// Firstly, search for the metallib in the same path as this binary
std::string first_path = get_colocated_mtllib_path(lib_name);
if (first_path.size() != 0) {
auto [lib, error] = load_library_from_path(device, first_path.c_str());
if (lib) {
return lib;
}
}
// Couldn't find it so let's load it from default_mtllib_path
{
auto [lib, error] = load_library_from_path(device, lib_path);
if (!lib) {
std::ostringstream msg;
msg << error->localizedDescription()->utf8String() << "\n"
<< "Failed to load device library from <" << lib_path << ">"
<< " or <" << first_path << ">.";
throw std::runtime_error(msg.str());
}
return lib;
}
}
} // namespace
Device::Device()
: pool_(NS::AutoreleasePool::alloc()->init()),
device_(load_device()),
library_map_({{"mlx", load_library(device_)}}) {}
Device::~Device() {
for (auto& q : queue_map_) {
q.second->release();
}
for (auto& k : kernel_map_) {
k.second->release();
}
for (auto& l : library_map_) {
l.second->release();
}
device_->release();
pool_->release();
}
void Device::new_queue(int index) {
// Multiple threads can ask the device for queues
// We lock this as a critical section for safety
const std::lock_guard<std::mutex> lock(mtx_);
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
if (!q) {
throw std::runtime_error(
"[metal::Device] Failed to make new command queue.");
}
queue_map_.insert({index, q});
}
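// buffer_map_ maps a stream index to {number of ops encoded so far, active command buffer}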
int Device::get_command_buffer_ops(int index) {
auto bit = buffer_map_.find(index);
return bit->second.first;
}
void Device::increment_command_buffer_ops(int index) {
auto bit = buffer_map_.find(index);
bit->second.first++;
}
MTL::CommandBuffer* Device::get_command_buffer(int index) {
auto bit = buffer_map_.find(index);
return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
}
MTL::CommandBuffer* Device::new_command_buffer(int index) {
auto qit = queue_map_.find(index);
if (qit == queue_map_.end()) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
}
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
cb->retain();
return buffer_map_.insert({index, {0, cb}}).first->second.second;
}
void Device::commit_command_buffer(int index) {
auto bit = buffer_map_.find(index);
bit->second.second->commit();
bit->second.second->release();
buffer_map_.erase(bit);
}
void Device::end_encoding(int index) {
auto eit = encoder_map_.find(index);
if (eit != encoder_map_.end()) {
eit->second->endEncoding();
eit->second->release();
encoder_map_.erase(eit);
}
}
MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
auto eit = encoder_map_.find(index);
if (eit == encoder_map_.end()) {
auto cb = get_command_buffer(index);
auto compute_encoder = cb->computeCommandEncoder();
// Increment ref count so the encoder is not garbage collected
compute_encoder->retain();
eit = encoder_map_.insert({index, compute_encoder}).first;
}
return eit->second;
}
MTL::ArgumentEncoder* Device::argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const {
// NB array here is already autoreleased but the returned argument
// encoder is owned by the caller and must be released/autoreleased
NS::Array* arg_desc_arr = NS::Array::array(
reinterpret_cast<NS::Object* const*>(arg_descs.data()), arg_descs.size());
return device_->newArgumentEncoder(arg_desc_arr);
}
void Device::register_library(
const std::string& lib_name,
const std::string& lib_path) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
auto new_lib = load_library(device_, lib_name, lib_path.c_str());
library_map_.insert({lib_name, new_lib});
}
}
void Device::register_library(
const std::string& lib_name,
const std::function<std::string(const std::string&)>& lib_path_func) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
std::string new_lib_path = lib_path_func(lib_name);
auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());
library_map_.insert({lib_name, new_lib});
}
}
MTL::ComputePipelineState* Device::get_kernel(
const std::string& name,
const std::string& lib_name /* = "mlx" */) {
// Look for cached kernel
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
return it->second;
}
// Prepare new kernel
// Search for cached metal lib
MTL::Library* mtl_lib;
if (auto it = library_map_.find(name); it != library_map_.end()) {
mtl_lib = it->second;
} else { // Look for metallib alongside library
register_library(lib_name);
mtl_lib = library_map_[lib_name];
}
// Pull kernel from library
auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
auto mtl_function = mtl_lib->newFunction(ns_name);
// Compile kernel to compute pipeline
NS::Error* error = nullptr;
MTL::ComputePipelineState* kernel;
if (mtl_function) {
kernel = device_->newComputePipelineState(mtl_function, &error);
mtl_function->release();
}
if (!mtl_function || !kernel) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load kernel " << name << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
// Add kernel to cache
kernel_map_.insert({name, kernel});
return kernel;
}
Device& device(mlx::core::Device) {
return metal_device_;
}
NS::AutoreleasePool*& thread_autorelease_pool() {
static thread_local NS::AutoreleasePool* p =
NS::AutoreleasePool::alloc()->init();
return p;
}
void new_stream(Stream stream) {
thread_autorelease_pool();
if (stream.device == mlx::core::Device::gpu) {
device(stream.device).new_queue(stream.index);
}
}
} // namespace mlx::core::metal

mlx/backend/metal/fft.cpp Normal file

@@ -0,0 +1,10 @@
#include "mlx/primitives.h"
namespace mlx::core {
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
throw std::runtime_error("[FFT] NYI for Metal backend.");
}
} // namespace mlx::core


@@ -0,0 +1,83 @@
set(
HEADERS
${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
)
set(
KERNELS
"arange"
"arg_reduce"
"binary"
"conv"
"copy"
"gemm"
"gemv"
"random"
"reduce"
"scan"
"softmax"
"sort"
"unary"
"indexing"
)
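# Compile one kernel's .metal source into a .air object using the macOS Metal toolchain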
function(build_kernel KERNEL)
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
set(HEADERS_PADDED ${HEADERS})
if(${KERNEL} STREQUAL "gemm")
set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/gemm.h)
endif()
if(${KERNEL} STREQUAL "conv")
set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/conv.h)
endif()
add_custom_command(
COMMAND xcrun -sdk macosx metal -Wall -Wextra
-fno-fast-math
-c ${SRCFILE}
-I${PROJECT_SOURCE_DIR}
-o ${KERNEL}.air
DEPENDS ${SRCFILE} ${HEADERS_PADDED}
OUTPUT ${KERNEL}.air
COMMENT "Building ${KERNEL}.air"
VERBATIM
)
endfunction(build_kernel)
foreach(KERNEL ${KERNELS})
build_kernel(${KERNEL})
set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
endforeach()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib
DEPENDS ${KERNEL_AIR}
COMMENT "Building mlx.metallib"
VERBATIM
)
add_custom_target(
mlx-metallib
DEPENDS
${MLX_METAL_PATH}/mlx.metallib
)
add_dependencies(
mlx
mlx-metallib
)
# Install metallib
include(GNUInstallDirs)
install(
FILES ${MLX_METAL_PATH}/mlx.metallib
DESTINATION ${CMAKE_INSTALL_LIBDIR}
COMPONENT metallib
)


@@ -0,0 +1,17 @@
#pragma once
template <int NDIM>
struct MLXConvParams {
const int N; // Batch size
const int C; // In channels
const int O; // Out channels
const int iS[NDIM]; // Input spatial dim
const int wS[NDIM]; // Weight spatial dim
const int oS[NDIM]; // Output spatial dim
const int str[NDIM]; // Kernel strides
const int pad[NDIM]; // Input padding
const int dil[NDIM]; // Kernel dilation
const size_t in_strides[NDIM + 2]; // In strides
const size_t wt_strides[NDIM + 2]; // Wt strides
const size_t out_strides[NDIM + 2]; // Out strides
};


@@ -0,0 +1,253 @@
#include <metal_atomic>
#include <metal_texture>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////
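// Argument buffer holding NIDX index arrays plus their shapes and strides,
// concatenated per array (ndim entries each)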
template <typename IdxT, int NIDX>
struct Indices {
const array<device IdxT*, NIDX> buffers [[id(0)]];
device int* shapes [[id(NIDX + 1)]];
device size_t* strides [[id(NIDX + 2)]];
const int ndim [[id(NIDX + 3)]];
};
template <typename IdxT>
inline size_t offset_neg_idx(IdxT idx, size_t size) {
return (idx < 0) ? idx + size : idx;
}
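// bool and uint32 indices are never negative, so the specializations below skip the adjustment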
template <>
inline size_t offset_neg_idx(bool idx, size_t) {
return idx;
}
template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx;
}
template <typename T, typename IdxT, int NIDX>
[[kernel]] void gather(
const device T *src [[buffer(0)]],
const device Indices<IdxT, NIDX>& indices [[buffer(1)]],
device T *out [[buffer(2)]],
const device int *src_shape [[buffer(3)]],
const device size_t *src_strides [[buffer(4)]],
const device size_t& src_ndim [[buffer(5)]],
const device int *slice_sizes [[buffer(6)]],
const device size_t& slice_size [[buffer(7)]],
const device int *axes [[buffer(8)]],
uint gid [[thread_position_in_grid]]) {
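// One thread per output element: gid / slice_size selects the index entry,
// gid % slice_size the position inside the gathered slice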
auto ind_idx = gid / slice_size;
auto ind_offset = gid % slice_size;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(
ind_offset, slice_sizes, src_strides, src_ndim);
out[gid] = src[src_idx + src_offset];
}
#define instantiate_gather4(name, src_type, ind_type, nindex) \
template [[host_name("gather" name "_" #nindex)]] \
[[kernel]] void gather( \
const device src_type *src [[buffer(0)]], \
const device Indices<ind_type, nindex>& indices [[buffer(1)]], \
device src_type *out [[buffer(2)]], \
const device int *src_shape [[buffer(3)]], \
const device size_t *src_strides [[buffer(4)]], \
const device size_t& src_ndim [[buffer(5)]], \
const device int *slice_sizes [[buffer(6)]], \
const device size_t& slice_size [[buffer(7)]], \
const device int* axes [[buffer(8)]], \
uint gid [[thread_position_in_grid]]);
// Special case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
#define instantiate_gather3(name, src_type, ind_type) \
instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \
instantiate_gather4(name, src_type, ind_type, 3) \
instantiate_gather4(name, src_type, ind_type, 4) \
instantiate_gather4(name, src_type, ind_type, 5) \
instantiate_gather4(name, src_type, ind_type, 6) \
instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10)
#define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \
instantiate_gather3(#name "uint16", src_type, uint16_t) \
instantiate_gather3(#name "uint32", src_type, uint32_t) \
instantiate_gather3(#name "uint64", src_type, uint64_t) \
instantiate_gather3(#name "int8", src_type, int8_t) \
instantiate_gather3(#name "int16", src_type, int16_t) \
instantiate_gather3(#name "int32", src_type, int32_t) \
instantiate_gather3(#name "int64", src_type, int64_t)
instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t)
/////////////////////////////////////////////////////////////////////
// Scatter kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, typename Op, int NIDX>
[[kernel]] void scatter(
const device Indices<IdxT, NIDX>& indices [[buffer(0)]],
const device T *updates [[buffer(1)]],
device mlx_atomic<T> *out [[buffer(2)]],
const device int *upd_shape [[buffer(3)]],
const device size_t *upd_strides [[buffer(4)]],
const device size_t& upd_ndim [[buffer(5)]],
const device size_t& upd_size [[buffer(6)]],
const device int *out_shape [[buffer(7)]],
const device size_t *out_strides [[buffer(8)]],
const device size_t& out_ndim [[buffer(9)]],
const device int* axes [[buffer(10)]],
uint gid [[thread_position_in_grid]]) {
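// One thread per update element: gid / upd_size selects the index entry and
// gid % upd_size the position inside the update slice; updates are applied to
// out through the op's atomic_update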
Op op;
auto ind_idx = gid / upd_size;
auto ind_offset = gid % upd_size;
size_t out_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out + out_idx + out_offset, updates[upd_idx]);
}
#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \
template [[host_name("scatter" name "_" #nindex)]] \
[[kernel]] void scatter<type, ind_type, op_type, nindex>( \
const device Indices<ind_type, nindex>& indices [[buffer(0)]], \
const device type *updates [[buffer(1)]], \
device mlx_atomic<type> *out [[buffer(2)]], \
const device int *upd_shape [[buffer(3)]], \
const device size_t *upd_strides [[buffer(4)]], \
const device size_t& upd_ndim [[buffer(5)]], \
const device size_t& upd_size [[buffer(6)]], \
const device int *out_shape [[buffer(7)]], \
const device size_t *out_strides [[buffer(8)]], \
const device size_t& out_ndim [[buffer(9)]], \
const device int* axes [[buffer(10)]], \
uint gid [[thread_position_in_grid]]);
// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
instantiate_scatter4(#name "none", type, bool, None, 0) \
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
#define instantiate_scatter3(name, type, ind_type, op_type) \
instantiate_scatter4(name, type, ind_type, op_type, 1) \
instantiate_scatter4(name, type, ind_type, op_type, 2) \
instantiate_scatter4(name, type, ind_type, op_type, 3) \
instantiate_scatter4(name, type, ind_type, op_type, 4) \
instantiate_scatter4(name, type, ind_type, op_type, 5) \
instantiate_scatter4(name, type, ind_type, op_type, 6) \
instantiate_scatter4(name, type, ind_type, op_type, 7) \
instantiate_scatter4(name, type, ind_type, op_type, 8) \
instantiate_scatter4(name, type, ind_type, op_type, 9) \
instantiate_scatter4(name, type, ind_type, op_type, 10)
#define instantiate_scatter2(name, type, ind_type) \
instantiate_scatter3(name "_none", type, ind_type, None) \
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
instantiate_scatter3(name "_min", type, ind_type, Min<type>)
#define instantiate_scatter(name, type) \
instantiate_scatter2(#name "bool_", type, bool) \
instantiate_scatter2(#name "uint8", type, uint8_t) \
instantiate_scatter2(#name "uint16", type, uint16_t) \
instantiate_scatter2(#name "uint32", type, uint32_t) \
instantiate_scatter2(#name "uint64", type, uint64_t) \
instantiate_scatter2(#name "int8", type, int8_t) \
instantiate_scatter2(#name "int16", type, int16_t) \
instantiate_scatter2(#name "int32", type, int32_t) \
instantiate_scatter2(#name "int64", type, int64_t)
// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
instantiate_scatter_nd0(uint16, uint16_t)
instantiate_scatter_nd0(uint32, uint32_t)
instantiate_scatter_nd0(int8, int8_t)
instantiate_scatter_nd0(int16, int16_t)
instantiate_scatter_nd0(int32, int32_t)
instantiate_scatter_nd0(float16, half)
instantiate_scatter_nd0(float32, float)
instantiate_scatter_nd0(bfloat16, bfloat16_t)
instantiate_scatter(bool_, bool)
instantiate_scatter(uint8, uint8_t)
instantiate_scatter(uint16, uint16_t)
instantiate_scatter(uint32, uint32_t)
instantiate_scatter(int8, int8_t)
instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t)


@@ -0,0 +1,174 @@
#pragma once
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
union bool4_or_uint {
bool4 b;
unsigned int i;
};
struct None {
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
mlx_atomic_store_explicit(out, val, offset);
}
};
struct And {
bool simd_reduce(bool val) {
return simd_all(val);
};
static constexpr constant bool init = true;
void atomic_update(
device mlx_atomic<unsigned int>* out,
bool val,
int elem_idx,
int offset = 0) {
if (!val) {
bool4_or_uint update;
update.b = {true, true, true, true};
update.b[elem_idx] = false;
mlx_atomic_fetch_and_explicit(out, update.i, offset);
}
}
void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
if (!val) {
mlx_atomic_store_explicit(out, val, offset);
}
}
// Non atomic update
void update(device bool* out, bool val) {
*out &= val;
}
// Operator
bool operator()(bool a, bool b) {
return a && b;
}
};
struct Or {
bool simd_reduce(bool val) {
return simd_any(val);
};
static constexpr constant bool init = false;
void atomic_update(
device mlx_atomic<unsigned int>* out,
bool val,
int elem_idx,
int offset = 0) {
if (val) {
bool4_or_uint update;
update.b = {false, false, false, false};
update.b[elem_idx] = true;
mlx_atomic_fetch_or_explicit(out, update.i, offset);
}
}
void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
if (val) {
mlx_atomic_store_explicit(out, val, offset);
}
}
// Non atomic update
void update(device bool* out, bool val) {
*out |= val;
}
// Operator
bool operator()(bool a, bool b) {
return a || b;
}
};
template <typename U>
struct Sum {
template <typename T>
T simd_reduce(T val) {
return simd_sum(val);
};
static constexpr constant U init = U(0);
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
mlx_atomic_fetch_add_explicit(out, val, offset);
}
// Operator
U operator()(U a, U b) {
return a + b;
}
};
template <typename U>
struct Prod {
template <typename T>
T simd_reduce(T val) {
return simd_product(val);
};
static constexpr constant U init = U(1);
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
mlx_atomic_fetch_mul_explicit(out, val, offset);
}
// Operator
U operator()(U a, U b) {
return a * b;
}
};
template <typename U>
struct Min {
template <typename T>
T simd_reduce(T val) {
return simd_min(val);
};
static constexpr constant U init = Limits<U>::max;
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
mlx_atomic_fetch_min_explicit(out, val, offset);
}
// Operator
U operator()(U a, U b) {
return a < b ? a : b;
}
};
template <typename U>
struct Max {
template <typename T>
T simd_reduce(T val) {
return simd_max(val);
};
static constexpr constant U init = Limits<U>::min;
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
mlx_atomic_fetch_max_explicit(out, val, offset);
}
// Operator
U operator()(U a, U b) {
return a > b ? a : b;
}
};


@@ -0,0 +1,492 @@
#include <metal_math>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
template <typename U>
struct CumSum {
static constexpr constant U init = static_cast<U>(0);
template <typename T>
U operator()(U a, T b) {
return a + b;
}
U simd_scan(U x) {
return simd_prefix_inclusive_sum(x);
}
U simd_exclusive_scan(U x) {
return simd_prefix_exclusive_sum(x);
}
};
template <typename U>
struct CumProd {
static constexpr constant U init = static_cast<U>(1.0f);
template <typename T>
U operator()(U a, T b) {
return a * b;
}
U simd_scan(U x) {
return simd_prefix_inclusive_product(x);
}
U simd_exclusive_scan(U x) {
return simd_prefix_exclusive_product(x);
}
};
template <>
struct CumProd<bool> {
static constexpr constant bool init = true;
template <typename T>
bool operator()(bool a, T b) {
return a & static_cast<bool>(b);
}
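// Inclusive scan via shuffle-up over the 32 simd lanes; CumMax/CumMin below use the same pattern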
bool simd_scan(bool x) {
for (int i=1; i<=16; i*=2) {
bool other = simd_shuffle_up(x, i);
x &= other;
}
return x;
}
bool simd_exclusive_scan(bool x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename U>
struct CumMax {
static constexpr constant U init = Limits<U>::min;
template <typename T>
U operator()(U a, T b) {
return (a >= b) ? a : b;
}
U simd_scan(U x) {
for (int i=1; i<=16; i*=2) {
U other = simd_shuffle_up(x, i);
x = (x >= other) ? x : other;
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename U>
struct CumMin {
static constexpr constant U init = Limits<U>::max;
template <typename T>
U operator()(U a, T b) {
return (a <= b) ? a : b;
}
U simd_scan(U x) {
for (int i=1; i<=16; i*=2) {
U other = simd_shuffle_up(x, i);
x = (x <= other) ? x : other;
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T * input) {
if (reverse) {
for (int i=0; i<N_READS; i++) {
values[N_READS-i-1] = input[i];
}
} else {
for (int i=0; i<N_READS; i++) {
values[i] = input[i];
}
}
}
template <typename T, typename U, int N_READS, bool reverse>
inline void load_safe(U values[N_READS], const device T * input, int start, int total, U init) {
if (reverse) {
for (int i=0; i<N_READS; i++) {
values[N_READS-i-1] = (start + N_READS - i - 1 < total) ? input[i] : init;
}
} else {
for (int i=0; i<N_READS; i++) {
values[i] = (start + i < total) ? input[i] : init;
}
}
}
template <typename U, int N_READS, bool reverse>
inline void write_unsafe(U values[N_READS], device U * out) {
if (reverse) {
for (int i=0; i<N_READS; i++) {
out[i] = values[N_READS-i-1];
}
} else {
for (int i=0; i<N_READS; i++) {
out[i] = values[i];
}
}
}
template <typename U, int N_READS, bool reverse>
inline void write_safe(U values[N_READS], device U * out, int start, int total) {
if (reverse) {
for (int i=0; i<N_READS; i++) {
if (start + N_READS - i - 1 < total) {
out[i] = values[N_READS-i-1];
}
}
} else {
for (int i=0; i<N_READS; i++) {
if (start + i < total) {
out[i] = values[i];
}
}
}
}
template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
[[kernel]] void contiguous_scan(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t & axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
// Position the pointers
in += (gid / lsize) * axis_size;
out += (gid / lsize) * axis_size;
// Compute the number of simd_groups
uint simd_groups = lsize / simd_size;
// Allocate memory
U prefix = Op::init;
U values[N_READS];
threadgroup U simdgroup_sums[32];
// Loop over the scanned axis in ceildiv(axis_size, N_READS*lsize) blocks of N_READS*lsize elements
// Read block
// Compute inclusive scan of the block
// Compute inclusive scan per thread
// Compute exclusive scan of thread sums in simdgroup
// Write simdgroup sums in SM
// Compute exclusive scan of simdgroup sums
// Compute the output by scanning prefix, prev_simdgroup, prev_thread, value
// Write block
for (uint r = 0; r < ceildiv(axis_size, N_READS*lsize); r++) {
// Compute the block offset
uint offset = r*lsize*N_READS + lid*N_READS;
// Read the values
if (reverse) {
if ((offset + N_READS) < axis_size) {
load_unsafe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS);
} else {
load_safe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS, offset, axis_size, Op::init);
}
} else {
if ((offset + N_READS) < axis_size) {
load_unsafe<T, U, N_READS, reverse>(values, in + offset);
} else {
load_safe<T, U, N_READS, reverse>(values, in + offset, offset, axis_size, Op::init);
}
}
// Compute an inclusive scan per thread
for (int i=1; i<N_READS; i++) {
values[i] = op(values[i], values[i-1]);
}
// Compute exclusive scan of thread sums
U prev_thread = op.simd_exclusive_scan(values[N_READS-1]);
// Write simdgroup_sums to SM
if (simd_lane_id == simd_size - 1) {
simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Compute exclusive scan of simdgroup_sums
if (simd_group_id == 0) {
U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
simdgroup_sums[simd_lane_id] = prev_simdgroup;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Compute the output
for (int i=0; i<N_READS; i++) {
values[i] = op(values[i], prefix);
values[i] = op(values[i], simdgroup_sums[simd_group_id]);
values[i] = op(values[i], prev_thread);
}
// Write the values
if (reverse) {
if (inclusive) {
if ((offset + N_READS) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS);
} else {
write_safe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
out[axis_size-1] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS);
} else {
write_safe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS, offset + 1, axis_size);
}
}
} else {
if (inclusive) {
if ((offset + N_READS) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + offset);
} else {
write_safe<U, N_READS, reverse>(values, out + offset, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
out[0] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
} else {
write_safe<U, N_READS, reverse>(values, out + offset + 1, offset + 1, axis_size);
}
}
}
// Share the prefix
if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
simdgroup_sums[0] = values[N_READS-1];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
prefix = simdgroup_sums[0];
}
}
template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
[[kernel]] void strided_scan(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t & axis_size [[buffer(2)]],
const constant size_t & stride [[buffer(3)]],
uint2 gid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]) {
Op op;
// Allocate memory
threadgroup U read_buffer[N_READS*32*32 + N_READS*32];
U values[N_READS];
U prefix[N_READS];
for (int i=0; i<N_READS; i++) {
prefix[i] = Op::init;
}
// Compute offsets
int offset = gid.y * axis_size * stride;
int global_index_x = gid.x * lsize.y * N_READS;
for (uint j=0; j<axis_size; j+=simd_size) {
// Calculate the indices for the current thread
uint index_y = j + lid.y;
uint check_index_y = index_y;
uint index_x = global_index_x + lid.x * N_READS;
if (reverse) {
index_y = axis_size - 1 - index_y;
}
// Read in SM
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
for (int i=0; i<N_READS; i++) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
}
} else {
for (int i=0; i<N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
} else {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = Op::init;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read strided into registers
for (int i=0; i<N_READS; i++) {
values[i] = read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
}
// Do we need the following barrier? Shouldn't all simd threads execute simultaneously?
simdgroup_barrier(mem_flags::mem_threadgroup);
// Perform the scan
for (int i=0; i<N_READS; i++) {
values[i] = op.simd_scan(values[i]);
values[i] = op(values[i], prefix[i]);
prefix[i] = simd_shuffle(values[i], simd_size-1);
}
// Write to SM
for (int i=0; i<N_READS; i++) {
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] = values[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write to device memory
if (!inclusive) {
if (check_index_y == 0) {
if ((index_x + N_READS) < stride) {
for (int i=0; i<N_READS; i++) {
out[offset + index_y * stride + index_x + i] = Op::init;
}
} else {
for (int i=0; i<N_READS; i++) {
if ((index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] = Op::init;
}
}
}
}
if (reverse) {
index_y -= 1;
check_index_y += 1;
} else {
index_y += 1;
check_index_y += 1;
}
}
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
for (int i=0; i<N_READS; i++) {
out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
}
} else {
for (int i=0; i<N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
}
}
}
}
}
#define instantiate_contiguous_scan(name, itype, otype, op, inclusive, reverse, nreads) \
template [[host_name("contiguous_scan_" #name)]] \
[[kernel]] void contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t & axis_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_strided_scan(name, itype, otype, op, inclusive, reverse, nreads) \
template [[host_name("strided_scan_" #name)]] \
[[kernel]] void strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t & axis_size [[buffer(2)]], \
const constant size_t & stride [[buffer(3)]], \
uint2 gid [[thread_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]]);
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
instantiate_contiguous_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
instantiate_contiguous_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) \
instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)
instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
instantiate_scan_helper(sum_uint32_uint32, uint32_t, uint32_t, CumSum, 4)
//instantiate_scan_helper(sum_uint64_uint64, uint64_t, uint64_t, CumSum, 2)
instantiate_scan_helper(sum_int8_int8, int8_t, int8_t, CumSum, 4)
instantiate_scan_helper(sum_int16_int16, int16_t, int16_t, CumSum, 4)
instantiate_scan_helper(sum_int32_int32, int32_t, int32_t, CumSum, 4)
//instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
instantiate_scan_helper(sum_float16_float16, half, half, CumSum, 4)
instantiate_scan_helper(sum_float32_float32, float, float, CumSum, 4)
//instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
//instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum)
//instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
instantiate_scan_helper(prod_uint8_uint8, uint8_t, uint8_t, CumProd, 4)
instantiate_scan_helper(prod_uint16_uint16, uint16_t, uint16_t, CumProd, 4)
instantiate_scan_helper(prod_uint32_uint32, uint32_t, uint32_t, CumProd, 4)
//instantiate_scan_helper(prod_uint64_uint64, uint64_t, uint64_t, CumProd, 2)
instantiate_scan_helper(prod_int8_int8, int8_t, int8_t, CumProd, 4)
instantiate_scan_helper(prod_int16_int16, int16_t, int16_t, CumProd, 4)
instantiate_scan_helper(prod_int32_int32, int32_t, int32_t, CumProd, 4)
//instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
instantiate_scan_helper(prod_float16_float16, half, half, CumProd, 4)
instantiate_scan_helper(prod_float32_float32, float, float, CumProd, 4)
//instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
//instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd)
//instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
instantiate_scan_helper(max_uint8_uint8, uint8_t, uint8_t, CumMax, 4)
instantiate_scan_helper(max_uint16_uint16, uint16_t, uint16_t, CumMax, 4)
instantiate_scan_helper(max_uint32_uint32, uint32_t, uint32_t, CumMax, 4)
//instantiate_scan_helper(max_uint64_uint64, uint64_t, uint64_t, CumMax, 2)
instantiate_scan_helper(max_int8_int8, int8_t, int8_t, CumMax, 4)
instantiate_scan_helper(max_int16_int16, int16_t, int16_t, CumMax, 4)
instantiate_scan_helper(max_int32_int32, int32_t, int32_t, CumMax, 4)
//instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
instantiate_scan_helper(max_float16_float16, half, half, CumMax, 4)
instantiate_scan_helper(max_float32_float32, float, float, CumMax, 4)
//instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
//instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax)
//instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
instantiate_scan_helper(min_uint8_uint8, uint8_t, uint8_t, CumMin, 4)
instantiate_scan_helper(min_uint16_uint16, uint16_t, uint16_t, CumMin, 4)
instantiate_scan_helper(min_uint32_uint32, uint32_t, uint32_t, CumMin, 4)
//instantiate_scan_helper(min_uint64_uint64, uint64_t, uint64_t, CumMin, 2)
instantiate_scan_helper(min_int8_int8, int8_t, int8_t, CumMin, 4)
instantiate_scan_helper(min_int16_int16, int16_t, int16_t, CumMin, 4)
instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMin, 4)
//instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
//instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin)


@@ -0,0 +1,88 @@
#include <cstdlib>
#include <future>
#include <memory>
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
namespace mlx::core::metal {
int max_ops_per_buffer() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
return atoi(buff_str);
} else {
return 10;
}
};
static int max_ops_per_buffer_ = get_val();
return max_ops_per_buffer_;
}
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
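// Return the stream's active command buffer and count this op against it; once
// MAX_OPS_PER_BUFFER ops have been encoded, the buffer is committed and a fresh one is started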
MTL::CommandBuffer* increment_command_buffer(Stream s) {
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
if (command_buffer == nullptr ||
d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
if (command_buffer != nullptr) {
d.end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s](MTL::CommandBuffer*) { scheduler::notify_task_completion(s); });
d.commit_command_buffer(s.index);
}
command_buffer = d.new_command_buffer(s.index);
}
d.increment_command_buffer_ops(s.index);
return command_buffer;
}
std::function<void()> make_task(
array& arr,
std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p,
bool retain_graph) {
auto task =
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
for (auto& d : deps) {
d.wait();
}
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
arr.primitive().eval_gpu(arr.inputs(), arr);
if (p) {
metal::device(s.device).end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[retain_graph, s, arr, p = std::move(p)](
MTL::CommandBuffer*) mutable {
if (!retain_graph) {
arr.detach();
}
p->set_value();
// Signal this thread to clear the pool on a synchronization.
scheduler::enqueue(s, []() {
thread_autorelease_pool()->release();
thread_autorelease_pool() =
NS::AutoreleasePool::alloc()->init();
});
scheduler::notify_task_completion(s);
});
metal::device(s.device).commit_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
if (!retain_graph) {
arr.detach();
}
});
}
};
return task;
}
} // namespace mlx::core::metal

mlx/backend/metal/metal.h Normal file

@@ -0,0 +1,28 @@
#pragma once
#include <future>
#include <memory>
#include <vector>
#include "mlx/array.h"
#include "mlx/stream.h"
namespace mlx::core::metal {
constexpr bool is_available() {
#ifdef _METAL_
return true;
#else
return false;
#endif
}
void new_stream(Stream stream);
std::function<void()> make_task(
array& arr,
std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p,
bool retain_graph);
} // namespace mlx::core::metal


@@ -0,0 +1,82 @@
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (!is_floating_point(out.dtype())) {
throw std::runtime_error(
"[softmax] Does not support non-floating point types.");
}
auto& s = stream();
auto& d = metal::device(s.device);
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) {
if (x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
copies.push_back(x_copy);
return x_copy;
}
};
const array& in = check_input(inputs[0]);
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
const int simd_size = 32;
const int n_reads = SOFTMAX_N_READS;
const int looped_limit = SOFTMAX_LOOPED_LIMIT;
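// Rows longer than SOFTMAX_LOOPED_LIMIT use the looped kernel variant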
std::string op_name = "softmax_";
if (axis_size > looped_limit) {
op_name += "looped_";
}
op_name += type_to_name(out);
auto compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
} else {
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&axis_size, sizeof(int), 2);
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 0);
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core

mlx/backend/metal/sort.cpp Normal file

@@ -0,0 +1,336 @@
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <bool ARGSORT>
void single_block_sort(
const Stream& s,
metal::Device& d,
const array& in,
array& out,
int axis,
int bn,
int tn) {
// Prepare shapes
int n_rows = in.size() / in.shape(axis);
std::vector<size_t> nc_str = in.strides();
nc_str.erase(nc_str.begin() + axis);
std::vector<int> nc_shape = in.shape();
nc_shape.erase(nc_shape.begin() + axis);
int nc_dim = nc_shape.size();
int size_sorted_axis = in.shape(axis);
int stride_sorted_axis = in.strides()[axis];
int stride_segment_axis = *std::min_element(nc_str.begin(), nc_str.end());
// Check if remaining strides are contiguous
bool contiguous_write = true;
if (axis != in.ndim() - 1 && axis != 0) {
for (int i = 0; i < nc_str.size() - 1; ++i) {
size_t expected = nc_str[i + 1] * nc_shape[i + 1];
contiguous_write &= (nc_str[i] == expected);
}
}
// Prepare kernel name
std::ostringstream kname;
if (ARGSORT) {
kname << "arg_";
}
kname << "block_merge_sort_" << type_to_name(in) << "_" << type_to_name(out)
<< "_bn" << bn << "_tn" << tn;
if (!contiguous_write) {
kname << "_nc";
}
// Prepare command encoder
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Set inputs
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);
if (contiguous_write) {
compute_encoder->setBytes(&stride_segment_axis, sizeof(int), 4);
} else {
compute_encoder->setBytes(&nc_dim, sizeof(int), 4);
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 5);
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 6);
}
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
template <bool ARGSORT>
void multi_block_sort(
const Stream& s,
metal::Device& d,
const array& in,
array& out,
int axis,
int bn,
int tn,
int n_blocks) {
// Prepare shapes
int n_rows = in.size() / in.shape(axis);
std::vector<size_t> nc_str = in.strides();
nc_str.erase(nc_str.begin() + axis);
std::vector<int> nc_shape = in.shape();
nc_shape.erase(nc_shape.begin() + axis);
int nc_dim = nc_shape.size();
int size_sorted_axis = in.shape(axis);
int stride_sorted_axis = in.strides()[axis];
// Make temporary copies
array dev_vals_0({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});
array dev_vals_1({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});
array dev_idxs_0({n_rows, size_sorted_axis}, uint32, nullptr, {});
array dev_idxs_1({n_rows, size_sorted_axis}, uint32, nullptr, {});
array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {});
// Do allocations
dev_vals_0.set_data(allocator::malloc_or_wait(dev_vals_0.nbytes()));
dev_vals_1.set_data(allocator::malloc_or_wait(dev_vals_1.nbytes()));
dev_idxs_0.set_data(allocator::malloc_or_wait(dev_idxs_0.nbytes()));
dev_idxs_1.set_data(allocator::malloc_or_wait(dev_idxs_1.nbytes()));
block_partitions.set_data(
allocator::malloc_or_wait(block_partitions.nbytes()));
std::vector<array> copies = {
dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
// Prepare command encoder
auto compute_encoder = d.get_command_encoder(s.index);
// Do blockwise sort
{
std::ostringstream kname;
kname << "mb_block_sort_" << type_to_name(dev_vals_0) << "_"
<< type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, dev_vals_0, 1);
set_array_buffer(compute_encoder, dev_idxs_0, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 7);
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
// Do merges
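// Each round doubles merge_tiles, ping-ponging values/indices between the _0 and _1 buffers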
bool ping = false;
array dev_vals_in = dev_vals_0;
array dev_idxs_in = dev_idxs_0;
array dev_vals_out = dev_vals_1;
array dev_idxs_out = dev_idxs_1;
for (int merge_tiles = 2; merge_tiles <= n_blocks; merge_tiles *= 2) {
dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
dev_vals_out = ping ? dev_vals_0 : dev_vals_1;
dev_idxs_out = ping ? dev_idxs_0 : dev_idxs_1;
ping = !ping;
// Do partition
{
std::ostringstream kname;
kname << "mb_block_partiton_" << type_to_name(dev_vals_in) << "_"
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, block_partitions, 0);
set_array_buffer(compute_encoder, dev_vals_in, 1);
set_array_buffer(compute_encoder, dev_idxs_in, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
// Do merge
{
std::ostringstream kname;
kname << "mb_block_merge_" << type_to_name(dev_vals_in) << "_"
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, block_partitions, 0);
set_array_buffer(compute_encoder, dev_vals_in, 1);
set_array_buffer(compute_encoder, dev_idxs_in, 2);
set_array_buffer(compute_encoder, dev_vals_out, 3);
set_array_buffer(compute_encoder, dev_idxs_out, 4);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 6);
compute_encoder->setBytes(&n_blocks, sizeof(int), 7);
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
}
// Copy outputs with appropriate strides
array strided_out_arr = ARGSORT ? dev_idxs_out : dev_vals_out;
if (axis == strided_out_arr.ndim() - 1) {
copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s);
} else {
std::vector<int> strided_out_shape = strided_out_arr.shape();
std::vector<size_t> strided_out_str = strided_out_arr.strides();
int out_axis_shape = strided_out_shape[axis];
int out_axis_str = strided_out_str[axis];
strided_out_shape.erase(strided_out_shape.begin() + axis);
strided_out_str.erase(strided_out_str.begin() + axis);
strided_out_shape.push_back(out_axis_shape);
strided_out_str.push_back(out_axis_str);
array strided_out_slice(strided_out_shape, out.dtype(), nullptr, {});
strided_out_slice.copy_shared_buffer(
strided_out_arr,
strided_out_str,
strided_out_arr.flags(),
strided_out_arr.size(),
0);
copy_gpu_inplace(strided_out_slice, out, CopyType::General, s);
}
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
template <bool ARGSORT>
void gpu_merge_sort(
const Stream& s,
metal::Device& d,
const array& in,
array& out,
int axis_) {
// Get size info
int axis = axis_ < 0 ? axis_ + in.ndim() : axis_;
int size_sorted_axis = in.shape(axis);
// Get kernel size
int tn = 8;
int bn = 128;
int potential_bn = (size_sorted_axis + tn - 1) / tn;
if (potential_bn > 256) {
bn = 512;
} else if (potential_bn > 128) {
bn = 256;
} else {
bn = 128;
}
if (bn == 512 && size_of(in.dtype()) > 4) {
bn = 256;
}
int n_per_block = bn * tn;
int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block;
if (n_blocks > 1) {
return multi_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn, n_blocks);
} else {
return single_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn);
}
}
} // namespace
void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
gpu_merge_sort<true>(s, d, in, out, axis_);
}
void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
gpu_merge_sort<false>(s, d, in, out, axis_);
}
void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
// We direct arg partition to sort for now
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
gpu_merge_sort<true>(s, d, in, out, axis_);
}
void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
// We direct partition to sort for now
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
gpu_merge_sort<false>(s, d, in, out, axis_);
}
} // namespace mlx::core