awni's commit files

Author: Awni Hannun
Date: 2023-11-29 10:30:41 -08:00
parent e411fcae68
commit 8ca7f9e8e9
130 changed files with 30159 additions and 0 deletions


@@ -0,0 +1,26 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
)
if (NOT MLX_METAL_PATH)
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
target_compile_definitions(
mlx PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")
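
The METAL_PATH definition above bakes the build-time location of mlx.metallib into the library. Below is a minimal sketch (not part of this commit) of how a runtime could use it as a fallback when no metallib sits next to the shared library; the helper name and the fallback logic are illustrative assumptions, but register_library and get_colocated_mtllib_path are the Device APIs declared later in this commit.

#include <string>
#include "mlx/backend/metal/device.h"

// Illustrative helper (assumption, not the commit's device.cpp): prefer a
// metallib colocated with the shared library, fall back to METAL_PATH.
void register_default_library(mlx::core::metal::Device& d) {
  std::string path = mlx::core::metal::get_colocated_mtllib_path("mlx");
  if (path.empty()) {
#ifdef METAL_PATH
    path = METAL_PATH;  // e.g. "<build dir>/kernels/mlx.metallib"
#endif
  }
  d.register_library("mlx", path);
}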

mlx/backend/metal/copy.cpp

@@ -0,0 +1,113 @@
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}
void copy_gpu_inplace(
const array& in,
array& out,
CopyType ctype,
const Stream& s) {
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(in, out);
auto& strides_in = strides[0];
auto& strides_out = strides[1];
auto& d = metal::device(s.device);
std::ostringstream kname;
switch (ctype) {
case CopyType::Scalar:
kname << "scopy";
break;
case CopyType::Vector:
kname << "vcopy";
break;
case CopyType::General:
kname << "gcopy";
break;
case CopyType::GeneralGeneral:
kname << "ggcopy";
break;
}
kname << type_to_name(in) << type_to_name(out);
if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
size_t ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 3);
if (ctype == CopyType::GeneralGeneral) {
compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 4);
}
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 2);
if (ctype == CopyType::GeneralGeneral) {
compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 3);
}
}
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(
&ndim, sizeof(int), (ctype == CopyType::GeneralGeneral) ? 5 : 4);
}
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
int rest = in.size() / (dim0 * dim1);
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}
} // namespace mlx::core
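
A standalone host-side sketch (not part of the commit) of how copy_gpu_inplace sizes the dispatch grid for a general copy: the last two collapsed dimensions map to grid x and y, and everything else is folded into z. The shape below is invented; get_block_dims, which picks the threadgroup shape, is not shown in this hunk.

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> shape = {2, 3, 4, 5};  // 4-D after collapsing contiguous dims
  size_t size = 1;
  for (int s : shape) size *= s;

  size_t ndim = shape.size();
  int dim0 = ndim > 0 ? shape[ndim - 1] : 1;  // fastest-moving axis -> grid x
  int dim1 = ndim > 1 ? shape[ndim - 2] : 1;  // next axis            -> grid y
  int rest = size / (dim0 * dim1);            // all remaining axes   -> grid z

  std::printf("grid = (%d, %d, %d)\n", dim0, dim1, rest);  // grid = (5, 4, 6)
  return 0;
}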


@@ -0,0 +1,81 @@
#pragma once
#include <Metal/Metal.hpp>
#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>
#include <dlfcn.h>
#include <filesystem>
#include "mlx/device.h"
namespace fs = std::filesystem;
namespace mlx::core::metal {
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return mtllib_path;
}
class Device {
public:
Device();
Device(const Device&) = delete;
Device& operator=(const Device&) = delete;
~Device();
MTL::Device* mtl_device() {
return device_;
};
void new_queue(int index);
MTL::CommandBuffer* new_command_buffer(int index);
MTL::CommandBuffer* get_command_buffer(int index);
int get_command_buffer_ops(int index);
void increment_command_buffer_ops(int index);
void commit_command_buffer(int index);
MTL::ComputeCommandEncoder* get_command_encoder(int index);
void end_encoding(int index);
void register_library(
const std::string& lib_name,
const std::string& lib_path);
void register_library(
const std::string& lib_name,
const std::function<std::string(const std::string&)>& lib_path_func =
get_colocated_mtllib_path);
MTL::ComputePipelineState* get_kernel(
const std::string& name,
const std::string& lib_name = "mlx");
MTL::ArgumentEncoder* argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
private:
NS::AutoreleasePool* pool_;
MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
std::unordered_map<int32_t, MTL::ComputeCommandEncoder*> encoder_map_;
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
std::unordered_map<std::string, MTL::Library*> library_map_;
std::mutex mtx_;
};
Device& device(mlx::core::Device);
NS::AutoreleasePool*& thread_autorelease_pool();
} // namespace mlx::core::metal
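
A usage sketch (illustrative, not from the commit) of the two register_library overloads declared above: one takes an explicit metallib path, the other resolves it from the library name, defaulting to get_colocated_mtllib_path. The paths and kernel name below are made up.

#include "mlx/backend/metal/device.h"

void example(mlx::core::metal::Device& d) {
  // Explicit path, e.g. the METAL_PATH baked in at build time.
  d.register_library("mlx", "/path/to/mlx.metallib");  // path is illustrative

  // Name only: the default lib_path_func maps "my_ext" to
  // "<directory containing this shared library>/my_ext.metallib".
  d.register_library("my_ext");

  // Kernels are then fetched by name from a registered library.
  auto* kernel = d.get_kernel("vcopyfloat32float32", "mlx");  // name illustrative
  (void)kernel;
}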


@@ -0,0 +1,296 @@
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include "mlx/backend/common/binary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
} // namespace
void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& src = inputs[0];
int nidx = inputs.size() - 1;
if (nidx > METAL_MAX_INDEX_ARRAYS) {
std::ostringstream msg;
msg << "[Gather::eval_gpu] Gathering with more than "
<< METAL_MAX_INDEX_ARRAYS << " index arrays not yet supported.";
throw std::runtime_error(msg.str());
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
std::ostringstream kname;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
size_t slice_size = 1;
for (auto s : slice_sizes_) {
slice_size *= s;
}
size_t ndim = src.ndim();
size_t nthreads = out.size();
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
// Make the argument buffer to store the indices for the
// `Indices` struct in kernels/indexing.metal
std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[0]->setIndex(0);
arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[0]->setArrayLength(nidx);
// Shapes
arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[1]->setIndex(nidx + 1);
// Strides
arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[2]->setIndex(nidx + 2);
// Indices ndim
arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
arg_descs[3]->setIndex(nidx + 3);
// Get the argument encoder
auto arg_enc = d.argument_encoder(arg_descs);
// Allocate and fill buffers for shapes and strides
int idx_ndim = nidx ? inputs[1].ndim() : 0;
// One row of idx_ndim entries per index array
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * nidx * idx_ndim);
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * nidx * idx_ndim);
for (int i = 0; i < nidx; ++i) {
std::copy(
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end(),
static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
std::copy(
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end(),
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
}
// Allocate the argument buffer
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
// Register data with the encoder
arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
for (int i = 0; i < nidx; ++i) {
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
}
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), MTL::ResourceUsageRead);
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), MTL::ResourceUsageRead);
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
// Set all the buffers
set_array_buffer(compute_encoder, src, 0);
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6);
compute_encoder->setBytes(&slice_size, sizeof(size_t), 7);
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 8);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Cleanup temporaries
arg_enc->release();
d.get_command_buffer(s.index)->addCompletedHandler(
[arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
allocator::free(arg_buf);
allocator::free(idx_shapes_buf);
allocator::free(idx_strides_buf);
});
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
if (size_of(out.dtype()) == 8) {
std::ostringstream msg;
msg << "[Scatter::eval_gpu] Does not support " << out.dtype();
throw std::invalid_argument(msg.str());
}
int nidx = axes_.size();
if (nidx > METAL_MAX_INDEX_ARRAYS) {
std::ostringstream msg;
msg << "[Scatter::eval_gpu] Gathering with more than "
<< METAL_MAX_INDEX_ARRAYS << " index arrays not yet supported.";
throw std::runtime_error(msg.str());
}
// Copy src into out
auto copy_type =
inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
copy_gpu(inputs[0], out, copy_type);
// Get stream
auto& s = stream();
auto& d = metal::device(s.device);
// Get kernel name
std::ostringstream kname;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
kname << "scatter" << type_to_name(out) << idx_type_name;
switch (reduce_type_) {
case Scatter::None:
kname << "_none";
break;
case Scatter::Sum:
kname << "_sum";
break;
case Scatter::Prod:
kname << "_prod";
break;
case Scatter::Max:
kname << "_max";
break;
case Scatter::Min:
kname << "_min";
break;
}
kname << "_" << nidx;
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto& upd = inputs.back();
size_t nthreads = upd.size();
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
// Make the argument buffer to store the indices for the
// `Indices` struct in kernels/indexing.metal
std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[0]->setIndex(0);
arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[0]->setArrayLength(nidx);
// Shapes
arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[1]->setIndex(nidx + 1);
// Strides
arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[2]->setIndex(nidx + 2);
// Indices ndim
arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
arg_descs[3]->setIndex(nidx + 3);
// Get the argument encoder
auto arg_enc = d.argument_encoder(arg_descs);
// Allocate and fill buffers for shapes and strides
int idx_ndim = nidx ? inputs[1].ndim() : 0;
// One row of idx_ndim entries per index array
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * nidx * idx_ndim);
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * nidx * idx_ndim);
for (int i = 0; i < nidx; ++i) {
std::copy(
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end(),
static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
std::copy(
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end(),
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
}
// Allocate the argument buffer
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
// Register data with the encoder
arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
for (int i = 0; i < nidx; ++i) {
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
}
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), MTL::ResourceUsageRead);
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), MTL::ResourceUsageRead);
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 0);
size_t upd_ndim = upd.ndim();
size_t upd_size = 1;
for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i);
}
set_array_buffer(compute_encoder, upd, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
compute_encoder->setBytes(upd.strides().data(), upd_ndim * sizeof(size_t), 4);
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
size_t out_ndim = out.ndim();
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
compute_encoder->setBytes(out.strides().data(), out_ndim * sizeof(size_t), 8);
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Cleanup temporaries
arg_enc->release();
d.get_command_buffer(s.index)->addCompletedHandler(
[arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
allocator::free(arg_buf);
allocator::free(idx_shapes_buf);
allocator::free(idx_strides_buf);
});
}
} // namespace mlx::core
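
A standalone sketch (not part of the commit) of the flat layout both Gather and Scatter use for idx_shapes_buf and idx_strides_buf: index array i occupies elements [i * idx_ndim, (i + 1) * idx_ndim), so the kernels implicitly assume every index array has the same number of dimensions. The shapes below are invented.

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // Two index arrays, both 2-D (idx_ndim = 2), e.g. shapes {3, 1} and {1, 4}.
  std::vector<std::vector<int>> idx_shapes = {{3, 1}, {1, 4}};
  int nidx = idx_shapes.size();
  int idx_ndim = idx_shapes[0].size();

  std::vector<int> packed(nidx * idx_ndim);
  for (int i = 0; i < nidx; ++i) {
    std::copy(
        idx_shapes[i].begin(),
        idx_shapes[i].end(),
        packed.data() + i * idx_ndim);
  }
  for (int v : packed) {
    std::printf("%d ", v);  // prints: 3 1 1 4
  }
  std::printf("\n");
  return 0;
}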


@@ -0,0 +1,320 @@
#pragma once
#include <metal_atomic>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Atomic utils
///////////////////////////////////////////////////////////////////////////////
#pragma METAL internals : enable
template <typename T>
constexpr constant bool is_metal_atomic = _disjunction<
is_same<T, int>,
is_same<T, uint>,
is_same<T, ulong>,
is_same<T, float>>::value;
#pragma METAL internals : disable
template <typename T, typename = void>
struct mlx_atomic {
atomic<uint> val;
};
template <typename T>
struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
atomic<T> val;
};
///////////////////////////////////////////////////////////////////////////////
// Native metal atomics
///////////////////////////////////////////////////////////////////////////////
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_min_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_max_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_add_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_mul_explicit(device mlx_atomic<T>* object, T val, int offset) {
T expected = mlx_atomic_load_explicit(object, offset);
while (!mlx_atomic_compare_exchange_weak_explicit(
object, &expected, val * expected, offset)) {
}
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
device mlx_atomic<T>* object,
thread T* expected,
T val,
int offset) {
return atomic_compare_exchange_weak_explicit(
&(object[offset].val),
expected,
val,
memory_order_relaxed,
memory_order_relaxed);
}
// Specialization for float since it does not support atomic_fetch_min_explicit
template <>
METAL_FUNC void mlx_atomic_fetch_min_explicit<float>(
device mlx_atomic<float>* object,
float val,
int offset) {
float expected = mlx_atomic_load_explicit(object, offset);
while (val < expected) {
if (mlx_atomic_compare_exchange_weak_explicit(
object, &expected, val, offset)) {
return;
}
}
}
// Specialization for float since it does not support atomic_fetch_max_explicit
template <>
METAL_FUNC void mlx_atomic_fetch_max_explicit<float>(
device mlx_atomic<float>* object,
float val,
int offset) {
float expected = mlx_atomic_load_explicit(object, offset);
while (val > expected) {
if (mlx_atomic_compare_exchange_weak_explicit(
object, &expected, val, offset)) {
return;
}
}
}
///////////////////////////////////////////////////////////////////////////////
// Custom atomics
///////////////////////////////////////////////////////////////////////////////
namespace {
template <typename T>
constexpr constant uint packing_size = sizeof(uint) / sizeof(T);
template <typename T>
union uint_or_packed {
T val[packing_size<T>];
uint bits;
};
template <typename T, typename Op>
struct mlx_atomic_update_helper {
uint operator()(uint_or_packed<T> init, T update, int elem_offset) {
Op op;
init.val[elem_offset] = op(update, init.val[elem_offset]);
return init.bits;
}
};
template <typename T, typename Op>
METAL_FUNC void mlx_atomic_update_and_store(
device mlx_atomic<T>* object,
T update,
int offset) {
int pack_offset = offset / packing_size<T>;
int elem_offset = offset % packing_size<T>;
mlx_atomic_update_helper<T, Op> helper;
uint_or_packed<T> expected;
expected.bits =
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
while (Op::condition(update, expected.val[elem_offset]) &&
!mlx_atomic_compare_exchange_weak_explicit(
object,
&(expected.bits),
helper(expected, update, elem_offset),
pack_offset)) {
}
}
template <typename T>
struct __None {
static bool condition(T a, T b) {
#pragma unused(a)
#pragma unused(b)
return true;
}
T operator()(T a, T b) {
#pragma unused(b)
return a;
}
};
template <typename T>
struct __Add {
static bool condition(T a, T b) {
#pragma unused(a)
#pragma unused(b)
return true;
}
T operator()(T a, T b) {
return a + b;
}
};
template <typename T>
struct __Mul {
static bool condition(T a, T b) {
#pragma unused(a)
return b != 0;
}
T operator()(T a, T b) {
return a * b;
}
};
template <typename T>
struct __Max {
static bool condition(T a, T b) {
return a > b;
}
T operator()(T a, T b) {
return max(a, b);
}
};
template <typename T>
struct __Min {
static bool condition(T a, T b) {
return a < b;
}
T operator()(T a, T b) {
return min(a, b);
}
};
} // namespace
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
int pack_offset = offset / packing_size<T>;
int elem_offset = offset % packing_size<T>;
uint_or_packed<T> packed_val;
packed_val.bits =
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
return packed_val.val[elem_offset];
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
int pack_offset = offset / packing_size<T>;
int elem_offset = offset % packing_size<T>;
uint_or_packed<T> identity;
identity.bits = __UINT32_MAX__;
identity.val[elem_offset] = val;
atomic_fetch_and_explicit(
&(object[pack_offset].val), identity.bits, memory_order_relaxed);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
int pack_offset = offset / packing_size<T>;
int elem_offset = offset % packing_size<T>;
uint_or_packed<T> identity;
identity.bits = 0;
identity.val[elem_offset] = val;
atomic_fetch_or_explicit(
&(object[pack_offset].val), identity.bits, memory_order_relaxed);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_min_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_max_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_add_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_mul_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
device mlx_atomic<T>* object,
thread uint* expected,
uint val,
int offset) {
return atomic_compare_exchange_weak_explicit(
&(object[offset].val),
expected,
val,
memory_order_relaxed,
memory_order_relaxed);
}
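
A host-side illustration (not part of the commit) of the packed layout the custom atomics above rely on: packing_size<T> values of a sub-32-bit type share one 32-bit word, so element offset lives in word offset / packing_size<T> at lane offset % packing_size<T>. The little-endian byte-lane view below mirrors what uint_or_packed sees on Apple silicon.

#include <cstdint>
#include <cstdio>

template <typename T>
constexpr int packing_size = sizeof(uint32_t) / sizeof(T);

int main() {
  int offset = 6;  // element index into a logical array of int8 values
  int pack_offset = offset / packing_size<int8_t>;  // 32-bit word index: 1
  int elem_offset = offset % packing_size<int8_t>;  // lane within the word: 2
  // Emulate uint_or_packed: place the value in its byte lane of the word.
  uint32_t bits = uint32_t(uint8_t(42)) << (8 * elem_offset);
  std::printf("word %d, lane %d, bits 0x%08x\n", pack_offset, elem_offset, bits);
  return 0;
}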


@@ -0,0 +1,315 @@
#pragma once
#include <metal_stdlib>
using namespace metal;
#if defined(__HAVE_BFLOAT__)
typedef bfloat bfloat16_t;
#else
/////////////////////////////////////////////////////////////////////////////
// Helpers
/////////////////////////////////////////////////////////////////////////////
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
// Check for nan
if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
_fp_encoding_traits<float>::inf_mask) {
return uint16_t(as_type<uint32_t>(0x7FC0));
}
// Take bits
uint32_t float_bits = as_type<uint32_t>(x);
// Round to nearest even
float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
// Take upper 16 bits
return float_bits >> 16;
}
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
// Upper 16 bits are the data and lower 16 bits are 0s
return as_type<float>((uint32_t)x << 16);
}
struct _MLX_BFloat16;
template <typename T>
static constexpr constant bool can_convert_to_bfloat =
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
template <typename T>
static constexpr constant bool can_convert_from_bfloat =
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
/////////////////////////////////////////////////////////////////////////////
// Bfloat struct
/////////////////////////////////////////////////////////////////////////////
struct _MLX_BFloat16 {
/////////////////////////////////////////////////////////////////////////////
// Constructors
uint16_t bits_;
_MLX_BFloat16() thread = default;
_MLX_BFloat16() threadgroup = default;
_MLX_BFloat16() device = default;
_MLX_BFloat16() constant = default;
struct bits_to_bfloat_struct {};
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
return bits_to_bfloat_struct();
}
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
: bits_(bits) {}
/////////////////////////////////////////////////////////////////////////////
// Conversions to bfloat
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) device
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
/////////////////////////////////////////////////////////////////////////////
// Conversions from bfloat
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const thread {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const threadgroup {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const device {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const constant {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
};
/////////////////////////////////////////////////////////////////////////////
// Bfloat operators
/////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////
// Unary ops
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
return -static_cast<float>(x);
}
/////////////////////////////////////////////////////////////////////////////
// Binary operators
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
}
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
} \
constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
}
/////////////////////////////////////////////////////////////////////////////
// Arithmetic Operators
#define bfloat_binop(_op_, _operator_) \
bfloat_binop_base( \
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
bfloat_binop_helper(_op_, _operator_, float, float, float); \
bfloat_binop_helper(_op_, _operator_, float, half, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
bfloat_binop(+, operator+);
bfloat_binop(-, operator-);
bfloat_binop(*, operator*);
bfloat_binop(/, operator/);
/////////////////////////////////////////////////////////////////////////////
// Comparison ops
#define bfloat_compop(__op__, __operator__) \
bfloat_binop_base( \
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
bfloat_binop_helper(__op__, __operator__, bool, half, float); \
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
bfloat_compop(>, operator>);
bfloat_compop(<, operator<);
bfloat_compop(>=, operator>=);
bfloat_compop(<=, operator<=);
bfloat_compop(==, operator==);
bfloat_compop(!=, operator!=);
#undef bfloat_compop
#undef bfloat_binop_base
#undef bfloat_binop_helper
#undef bfloat_binop
/////////////////////////////////////////////////////////////////////////////
// Inplace Operators
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
addr_space _MLX_BFloat16& lhs, itype rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
} \
constexpr METAL_FUNC addr_space itype& __operator__( \
addr_space itype& lhs, _MLX_BFloat16 rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
}
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
#define bfloat_inplace_op(itype) \
bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
bfloat_inplace_op(float);
bfloat_inplace_op(half);
bfloat_inplace_op(int16_t);
bfloat_inplace_op(int32_t);
bfloat_inplace_op(int64_t);
bfloat_inplace_op(uint16_t);
bfloat_inplace_op(uint32_t);
bfloat_inplace_op(uint64_t);
#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
#undef bfloat_inplace_op
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
}
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
bfloat_inplace_op_helper(__op__, __operator__, device); \
bfloat_inplace_op_helper(__op__, __operator__, thread); \
bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
bfloat_inplace_op_addr_space_helper(+, operator+=);
bfloat_inplace_op_addr_space_helper(-, operator-=);
bfloat_inplace_op_addr_space_helper(*, operator*=);
bfloat_inplace_op_addr_space_helper(/, operator/=);
#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
/////////////////////////////////////////////////////////////////////////////
// Bfloat typedef
/////////////////////////////////////////////////////////////////////////////
typedef struct _MLX_BFloat16 bfloat16_t;
/////////////////////////////////////////////////////////////////////////////
// Bfloat numeric limits
/////////////////////////////////////////////////////////////////////////////
#pragma METAL internals : enable
namespace metal {
template <>
struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
static constexpr constant int digits = 8;
static constexpr constant int digits10 = 2;
static constexpr constant int max_digits10 = 4;
static constexpr constant int radix = 2;
static constexpr constant int min_exponent = -125;
static constexpr constant int min_exponent10 = -37;
static constexpr constant int max_exponent = 128;
static constexpr constant int max_exponent10 = 38;
static constexpr bfloat16_t min() {
return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t lowest() {
return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t max() {
return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t epsilon() {
return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t round_error() {
return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t infinity() {
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t quiet_NaN() {
return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t signaling_NaN() {
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t denorm_min() {
return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat());
}
};
METAL_FUNC bool isnan(_MLX_BFloat16 x) {
return x != x;
}
} // namespace metal
#pragma METAL internals : disable
#endif // defined(__HAVE_BFLOAT__)
#include "mlx/backend/metal/kernels/bf16_math.h"


@@ -0,0 +1,369 @@
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
struct Add {
template <typename T> T operator()(T x, T y) { return x + y; }
};
struct Divide {
template <typename T> T operator()(T x, T y) { return x / y; }
};
struct Equal {
template <typename T> bool operator()(T x, T y) { return x == y; }
};
struct NaNEqual {
template <typename T> bool operator()(T x, T y) {
return x == y || (metal::isnan(x) && metal::isnan(y));
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x == y ||
(metal::isnan(x.real) && metal::isnan(y.real)
&& metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
}
};
struct Greater {
template <typename T> bool operator()(T x, T y) { return x > y; }
};
struct GreaterEqual {
template <typename T> bool operator()(T x, T y) { return x >= y; }
};
struct Less {
template <typename T> bool operator()(T x, T y) { return x < y; }
};
struct LessEqual {
template <typename T> bool operator()(T x, T y) { return x <= y; }
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
constexpr T inf = metal::numeric_limits<T>::infinity();
T maxval = metal::max(x, y);
T minval = metal::min(x, y);
return (minval == -inf || maxval == inf) ? maxval :
(maxval + log1p(metal::exp(minval - maxval)));
};
};
struct Maximum {
template <typename T> T operator()(T x, T y) { return metal::max(x, y); }
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x >= y ? x : y;
}
};
struct Minimum {
template <typename T> T operator()(T x, T y) { return metal::min(x, y); }
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x <= y ? x : y;
}
};
struct Multiply {
template <typename T> T operator()(T x, T y) { return x * y; }
};
struct NotEqual {
template <typename T> bool operator()(T x, T y) { return x != y; }
template <>
bool operator()(complex64_t x, complex64_t y) {
return x.real != y.real || x.imag != y.imag;
}
};
struct Power {
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
return metal::pow(base, exp);
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta;
return {mag * metal::cos(phase), mag * metal::sin(phase)};
}
};
struct Subtract {
template <typename T> T operator()(T x, T y) { return x - y; }
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_s2s(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_ss(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_sv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_vs(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_vv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd1(
device const T* a,
device const T* b,
device U* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd2(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd3(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_op_g_nd(
device const T* a,
device const T* b,
device U* c,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g(
device const T* a,
device const T* b,
device U* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
#define instantiate_binary(name, itype, otype, op, bopt) \
template [[host_name(name)]] \
[[kernel]] void binary_op_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void binary_op_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_g_nd(name, itype, otype, op) \
template [[host_name(name "_1")]] \
[[kernel]] void binary_op_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void binary_op_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void binary_op_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name(name)]] \
[[kernel]] void binary_op_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_all(name, tname, itype, otype, op) \
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
instantiate_binary_g("g" #name #tname, itype, otype, op) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op)
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op)
#define instantiate_binary_types_bool(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
instantiate_binary_all(name, uint16, uint16_t, bool, op) \
instantiate_binary_all(name, uint32, uint32_t, bool, op) \
instantiate_binary_all(name, uint64, uint64_t, bool, op) \
instantiate_binary_all(name, int8, int8_t, bool, op) \
instantiate_binary_all(name, int16, int16_t, bool, op) \
instantiate_binary_all(name, int32, int32_t, bool, op) \
instantiate_binary_all(name, int64, int64_t, bool, op) \
instantiate_binary_all(name, float16, half, bool, op) \
instantiate_binary_all(name, float32, float, bool, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
instantiate_binary_all(name, complex64, complex64_t, bool, op)
instantiate_binary_types(add, Add)
instantiate_binary_float(div, Divide)
instantiate_binary_types_bool(eq, Equal)
instantiate_binary_types_bool(ge, Greater)
instantiate_binary_types_bool(geq, GreaterEqual)
instantiate_binary_types_bool(le, Less)
instantiate_binary_types_bool(leq, LessEqual)
instantiate_binary_types_bool(neq, NotEqual)
instantiate_binary_float(lae, LogAddExp)
instantiate_binary_types(max, Maximum)
instantiate_binary_types(min, Minimum)
instantiate_binary_types(mul, Multiply)
instantiate_binary_types(sub, Subtract)
instantiate_binary_types(pow, Power)
// NaNEqual only needed for floating point types with boolean output
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
instantiate_binary_all(naneq, float32, float, bool, NaNEqual)
instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
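
A host-side check (not part of the commit) of the LogAddExp formulation above: log(exp(x) + exp(y)) is rewritten as maxval + log1p(exp(minval - maxval)), which stays finite when either argument is large; the infinity guard matches the kernel's.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>

float logaddexp(float x, float y) {
  constexpr float inf = std::numeric_limits<float>::infinity();
  float maxval = std::max(x, y);
  float minval = std::min(x, y);
  return (minval == -inf || maxval == inf)
      ? maxval
      : maxval + std::log1p(std::exp(minval - maxval));
}

int main() {
  // Naive log(exp(1000) + exp(999)) overflows; the stable form does not.
  std::printf("%f\n", logaddexp(1000.0f, 999.0f));  // ~1000.31
  return 0;
}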


@@ -0,0 +1,110 @@
#pragma once
#include <metal_stdlib>
using namespace metal;
struct complex64_t;
template <typename T>
static constexpr constant bool can_convert_to_complex64 =
!is_same_v<T, complex64_t> && is_convertible_v<T, float>;
template <typename T>
static constexpr constant bool can_convert_from_complex64 =
!is_same_v<T, complex64_t> &&
(is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);
struct complex64_t {
float real;
float imag;
// Constructors
constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
// Conversions to complex64_t
template <
typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) thread : real(x), imag(0) {}
template <
typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}
template <
typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) device : real(x), imag(0) {}
template <
typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) constant : real(x), imag(0) {}
// Conversions from complex64_t
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const thread {
return static_cast<T>(real);
}
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const threadgroup {
return static_cast<T>(real);
}
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const device {
return static_cast<T>(real);
}
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const constant {
return static_cast<T>(real);
}
};
constexpr complex64_t operator-(complex64_t x) {
return {-x.real, -x.imag};
}
constexpr bool operator>=(complex64_t a, complex64_t b) {
return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
}
constexpr bool operator>(complex64_t a, complex64_t b) {
return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
}
constexpr bool operator<=(complex64_t a, complex64_t b) {
return operator>=(b, a);
}
constexpr bool operator<(complex64_t a, complex64_t b) {
return operator>(b, a);
}
constexpr bool operator==(complex64_t a, complex64_t b) {
return a.real == b.real && a.imag == b.imag;
}
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
return {a.real + b.real, a.imag + b.imag};
}
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
return {a.real - b.real, a.imag - b.imag};
}
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
}
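
A standalone sketch (not part of the commit) of the ordering complex.h defines: comparisons are lexicographic, real part first and then imaginary, which is what the Maximum/Minimum specializations in binary.metal rely on. Only operator>= is reproduced here.

#include <cstdio>

struct complex64_t {
  float real;
  float imag;
};

constexpr bool operator>=(complex64_t a, complex64_t b) {
  return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
}

int main() {
  complex64_t a{1.0f, 5.0f}, b{1.0f, 2.0f}, c{2.0f, -1.0f};
  std::printf("%d %d\n", a >= b, c >= a);  // prints: 1 1
  return 0;
}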


@@ -0,0 +1,14 @@
#pragma once
#ifdef __METAL__
#define MTL_CONST constant
#else
#define MTL_CONST
#endif
static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;


@@ -0,0 +1,479 @@
#pragma once
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/conv_params.h"
#define MLX_MTL_CONST static constant constexpr const
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int vec_size,
int tgp_size,
int tgp_padding = 0>
struct Conv2DInputBlockLoader {
// Destination dimensions
MLX_MTL_CONST int dst_fd = BM;
MLX_MTL_CONST int dst_ld = BK + tgp_padding;
MLX_MTL_CONST int n_vecs = BK / vec_size;
// Stride along block row within the block
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
MLX_MTL_CONST int n_rows = dst_fd / bstride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
const constant MLXConvParams<2>& params;
int weight_h;
int weight_w;
int offsets_n[n_rows];
int offsets_oh[n_rows];
int offsets_ow[n_rows];
/* Constructor */
METAL_FUNC Conv2DInputBlockLoader(
const device T* src_,
threadgroup T* dst_,
const constant MLXConvParams<2>& params_,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / n_vecs),
bj(vec_size * (thread_idx % n_vecs)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bj),
params(params_),
weight_h(0),
weight_w(0) {
int out_n_pixels = params.oS[0] * params.oS[1];
for (int i = 0; i < n_rows; ++i) {
int offset_nhw = tid.y * BM + bi + i * bstride;
offsets_n[i] = offset_nhw / out_n_pixels;
int hw = offset_nhw % out_n_pixels;
offsets_oh[i] = hw / params.oS[1];
offsets_ow[i] = hw % params.oS[1];
}
(void)lid;
}
/* Load from device memory into threadgroup memory - without bounds checking */
METAL_FUNC void load_unsafe() const {
#pragma clang loop unroll(full)
for (short i = 0, is = 0; i < n_rows; ++i, is += bstride) {
int n = offsets_n[i];
int oh = offsets_oh[i];
int ow = offsets_ow[i];
int ih = oh * params.str[0] - params.pad[0] + weight_h * params.dil[0];
int iw = ow * params.str[1] - params.pad[1] + weight_w * params.dil[1];
// Read from input if in bounds
if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
const device T* curr_src = src + n * params.in_strides[0] +
ih * params.in_strides[1] + iw * params.in_strides[2];
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = curr_src[j];
}
}
// Zero pad otherwise
else {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_w < params.wS[1]) {
return;
}
weight_w = 0;
if (++weight_h < params.wS[0]) {
return;
}
weight_h = 0;
src += BK;
}
};
template <
typename T,
int BM,
int BN,
int BK,
int vec_size,
int tgp_size,
int tgp_padding = 0>
struct Conv2DWeightBlockLoader {
// Destination dimensions
MLX_MTL_CONST int dst_fd = BN;
MLX_MTL_CONST int dst_ld = BK + tgp_padding;
MLX_MTL_CONST int n_vecs = BK / vec_size;
// Stride along block row within the block
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
MLX_MTL_CONST int n_rows = dst_fd / bstride;
// Leading dimension for src
const int src_ld;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
const constant MLXConvParams<2>& params;
int weight_h;
int weight_w;
/* Constructor */
METAL_FUNC Conv2DWeightBlockLoader(
const device T* src_,
threadgroup T* dst_,
const constant MLXConvParams<2>& params_,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_.wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / n_vecs),
bj(vec_size * (thread_idx % n_vecs)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj),
params(params_),
weight_h(0),
weight_w(0) {
(void)lid;
(void)tid;
}
/* Load from device memory into threadgroup memory - without bounds checking */
METAL_FUNC void load_unsafe() const {
const device T* curr_src =
src + weight_h * params.wt_strides[1] + weight_w * params.wt_strides[2];
#pragma clang loop unroll(full)
for (short i = 0; i < dst_fd; i += bstride) {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_w < params.wS[1]) {
return;
}
weight_w = 0;
if (++weight_h < params.wS[0]) {
return;
}
weight_h = 0;
src += BK;
}
};
///////////////////////////////////////////////////////////////////////////////
// Transforms
///////////////////////////////////////////////////////////////////////////////
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
};
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
int tgp_padding_a = 0,
int tgp_padding_b = 0,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DBlockMMA {
// Warp tile size along M
MLX_MTL_CONST int TM = BM / (WM * 8);
// Warp tile size along N
MLX_MTL_CONST int TN = BN / (WN * 8);
// Warp tile simdgroup matrix strides along M
MLX_MTL_CONST int TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along N
MLX_MTL_CONST int TN_stride = 8 * WN;
// Leading dimensions of threadgroup A, B blocks
MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
// Strides of A, B along reduction axis
MLX_MTL_CONST short simd_stride_a =
transpose_a ? TM_stride : TM_stride * lda_tgp;
MLX_MTL_CONST short simd_stride_b =
transpose_b ? TN_stride * ldb_tgp : TN_stride;
// Jump between elements
MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
// Offsets within threadgroup
const int tm;
const int tn;
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
short sm;
short sn;
/* Constructor */
METAL_FUNC Conv2DBlockMMA(
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
short qid = simd_lane_id / 4;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
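// sm / sn computed above are this thread's element offsets within an 8x8
// simdgroup matrix tile; each lane owns two elements adjacent along N
// (thread_elements()[0] and [1] in store_result below).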
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Iterate over BK in blocks of 8
#pragma clang loop unroll(full)
for (short kk = 0; kk < BK; kk += 8) {
short2 offset_a =
transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
short2 offset_b =
transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
As__ += simd_stride_a;
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
Bs__ += simd_stride_b;
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
simdgroup_multiply_accumulate(
results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device T* C, const int ldc) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
METAL_FUNC void
store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
if (tm + i * TM_stride + sm < dst_tile_dims.y) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
if (tn + j * TN_stride + sn < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
}
if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DImplicitGEMMKernel {
MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
MLX_MTL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
MLX_MTL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
MLX_MTL_CONST short tgp_size = WM * WN * 32;
MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
using loader_a_t =
Conv2DInputBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_a>;
using loader_b_t =
Conv2DWeightBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_b>;
using mma_t = Conv2DBlockMMA<
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
tgp_padding_a,
tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
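// Assumed implicit GEMM view: each output spatial position is one GEMM row,
// so the GEMM M is batch * oH * oW, the GEMM K is wH * wW * C
// (= params.wt_strides[0]) and the GEMM N is the number of output channels O.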
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const int K = params.wt_strides[0];
const int N = params.O;
B += c_col * K;
C += c_row * N + c_col;
// Prepare threadgroup memory for loading
threadgroup T* As = tgp_memory;
threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
// Prepare threadgroup loading operations
loader_a_t loader_a(A, As, params, tid, lid, simd_gid, simd_lid);
loader_b_t loader_b(B, Bs, params, tid, lid, simd_gid, simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
mma_op.store_result(C, N);
}
};

View File

@@ -0,0 +1,536 @@
#pragma once
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#define MLX_MTL_CONST static constant constexpr const
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BROWS,
int BCOLS,
int BK,
int vec_size,
int tgp_size,
bool transpose,
bool ldK,
int tgp_padding = 0>
struct BlockLoader {
// Destination dimensions
MLX_MTL_CONST int dst_fd = transpose ? BCOLS : BROWS;
MLX_MTL_CONST int dst_ld = (transpose ? BROWS : BCOLS) + tgp_padding;
MLX_MTL_CONST int n_vecs = (transpose ? BROWS : BCOLS) / vec_size;
// Stride along block row within the block
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
// Leading dimension for src
const int src_ld;
// Stride along reduction axis between blocks
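// (BK * src_ld elements when the reduction axis is the strided one for this
// operand, plain BK elements when it is contiguous; see tstride below)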
const int tstride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
/* Constructor */
METAL_FUNC BlockLoader(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tstride(
BK * ((int)(transpose ^ !ldK) * src_ld + (int)(transpose ^ ldK))),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / n_vecs),
bj(vec_size * (thread_idx % n_vecs)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
#pragma clang loop unroll(full)
for (short i = 0; i < dst_fd; i += bstride) {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = src[i * src_ld + j];
}
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = transpose ? src_tile_dim.yx : src_tile_dim.xy;
// Iterate over rows of block
#pragma clang loop unroll(full)
for (short i = 0; i < dst_fd; i += bstride) {
// Row is in bounds, we check against column
if ((bi + i) < src_tile_dim.y) {
// Use fast thread memory for bound checks
short tmp_idx[vec_size];
T tmp_val[vec_size];
// Make sure tmp_idx only contains valid indices
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
}
// Read all valid indices into tmp_val
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[i * src_ld + tmp_idx[j]];
}
// Zero out unneeded values
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
// Row is out of bounds, we just fill tgp memory with zeros
else {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tstride;
}
};
///////////////////////////////////////////////////////////////////////////////
// Transforms
///////////////////////////////////////////////////////////////////////////////
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
};
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
int tgp_padding_a = 0,
int tgp_padding_b = 0,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct BlockMMA {
// Warp tile size along M
MLX_MTL_CONST int TM = BM / (WM * 8);
// Warp tile size along N
MLX_MTL_CONST int TN = BN / (WN * 8);
// Warp tile simdgroup matrix strides along M
MLX_MTL_CONST int TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along N
MLX_MTL_CONST int TN_stride = 8 * WN;
// Leading dimensions of threadgroup A, B blocks
MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
// Strides of A, B along reduction axis
MLX_MTL_CONST short simd_stride_a =
transpose_a ? TM_stride : TM_stride * lda_tgp;
MLX_MTL_CONST short simd_stride_b =
transpose_b ? TN_stride * ldb_tgp : TN_stride;
// Jump between elements
MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
// Offsets within threadgroup
const int tm;
const int tn;
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
short sm;
short sn;
/* Constructor */
METAL_FUNC BlockMMA(
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
short qid = simd_lane_id / 4;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Iterate over BK in blocks of 8
#pragma clang loop unroll(full)
for (short kk = 0; kk < BK; kk += 8) {
short2 offset_a =
transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
short2 offset_b =
transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
As__ += simd_stride_a;
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
Bs__ += simd_stride_b;
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
simdgroup_multiply_accumulate(
results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device T* C, const int ldc) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
METAL_FUNC void
store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
if (tm + i * TM_stride + sm < dst_tile_dims.y) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
if (tn + j * TN_stride + sn < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
}
if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct GEMMKernel {
MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
MLX_MTL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
MLX_MTL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
MLX_MTL_CONST short tgp_size = WM * WN * 32;
MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
using loader_a_t = BlockLoader<
T,
BM,
BK,
BK,
vec_size,
tgp_size,
transpose_a,
true,
tgp_padding_a>;
using loader_b_t = BlockLoader<
T,
BK,
BN,
BK,
vec_size,
tgp_size,
transpose_b,
false,
tgp_padding_b>;
using mma_t = BlockMMA<
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
tgp_padding_a,
tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant int& M [[buffer(3)]],
const constant int& N [[buffer(4)]],
const constant int& K [[buffer(5)]],
const constant int& batch_stride_a [[buffer(6)]],
const constant int& batch_stride_b [[buffer(7)]],
const constant int& batch_stride_c [[buffer(8)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
// Adjust for batch
A += batch_stride_a * tid.z;
B += batch_stride_b * tid.z;
C += batch_stride_c * tid.z;
// Adjust for transpose
const int lda_dev = transpose_a ? M : K;
const int ldb_dev = transpose_b ? K : N;
// Find block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
A += transpose_a ? c_row : c_row * K;
B += transpose_b ? c_col * K : c_col;
C += c_row * N + c_col;
// Prepare threadgroup memory for loading
threadgroup T* As = tgp_memory;
threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
// Prepare threadgroup loading operations
loader_a_t loader_a(A, lda_dev, As, simd_group_id, simd_lane_id);
loader_b_t loader_b(B, ldb_dev, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
mma_t mma_op(simd_group_id, simd_lane_id);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned && K_aligned) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
mma_op.store_result(C, N);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN aligned, K unaligned loop
else if (MN_aligned && !K_aligned) {
// Main loop
int k = 0;
for (; k + BK <= K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Loop tail
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_a.load_safe(short2(K - k, BM));
loader_b.load_safe(short2(BN, K - k));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
// Store results to device memory
mma_op.store_result(C, N);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MNK unaligned loop
else { // Loop over K - unaligned case
short2 src_tile_dims(min(BN, N - c_col), min(BM, M - c_row));
if (src_tile_dims.y == BM && src_tile_dims.x == BN) {
int k = 0;
for (; k + BK <= K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
if (k < K) {
loader_a.load_safe(short2(K - k, BM));
loader_b.load_safe(short2(BN, K - k));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
mma_op.store_result(C, N);
return;
} else {
int k = 0;
for (; k + BK <= K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_safe(short2(BK, src_tile_dims.y));
loader_b.load_safe(short2(src_tile_dims.x, BK));
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
if (k < K) {
loader_a.load_safe(short2(K - k, src_tile_dims.y));
loader_b.load_safe(short2(src_tile_dims.x, K - k));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
threadgroup_barrier(mem_flags::mem_none);
mma_op.store_result_safe(C, N, src_tile_dims);
return;
}
}
}
};

View File

@@ -0,0 +1,302 @@
#include <metal_stdlib>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/bf16.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////
static constant constexpr int SIMD_SIZE = 32;
template <typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel]] void gemv(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const constant int& vector_batch_stride [[buffer(5)]],
const constant int& matrix_batch_stride [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
//      into blocks of (BM * TM, BN * TN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
//      and the corresponding scalars from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
//      These are then summed up across the simdgroup
// 4. Each simdgroup writes its accumulated TM outputs (BM * TM per threadgroup)
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
//      * The blocks that start outside the matrix are never read (thread results remain zero)
//      * The last thread that partially overlaps with the matrix is shifted inwards
//      such that the thread block fits exactly in the matrix
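// Illustrative example: in the bm4_bn32_tm4_tn4 instantiation below, a
// threadgroup owns BM * TM = 16 output rows (4 rows per simdgroup) and the
// loop walks in_vec in chunks of BN * TN = 128 elements.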
// Update batch offsets
in_vec += tid.z * vector_batch_stride;
mat += tid.z * matrix_batch_stride;
out_vec += tid.z * out_vec_size;
// Threadgroup in_vec cache
threadgroup T in_vec_block[BN][TN * 2];
// Thread local accumulation results
thread T result[TM] = {0};
thread T inter[TN];
thread T v_coeff[TN];
// Block position
int out_row = (tid.x * BM + simd_gid) * TM;
// Exit simdgroup if rows out of bound
if(out_row >= out_vec_size)
return;
// Adjust tail simdgroup to ensure in bound reads
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
// Advance matrix
mat += out_row * in_vec_size;
// Loop over in_vec in blocks of BN * TN
for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Prefetch in_vector for threadgroup use
if(simd_gid == 0) {
// Main load loop
if(bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
in_vec_block[simd_lid][tn] = in_vec[bn + tn];
}
} else { // Edgecase
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
in_vec_block[simd_lid][tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load for all rows
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
v_coeff[tn] = in_vec_block[simd_lid][tn];
}
// Per thread work loop
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
// Load for the row
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * in_vec_size + bn + tn];
}
// Accumulate results
for(int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
}
}
// Simdgroup accumulations
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
result[tm] = simd_sum(result[tm]);
}
// Write outputs
if(simd_lid == 0) {
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = result[tm];
}
}
}
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
[[kernel]] void gemv<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* vec [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant int& in_vec_size [[buffer(3)]], \
const constant int& out_vec_size [[buffer(4)]], \
const constant int& vector_batch_stride [[buffer(5)]], \
const constant int& matrix_batch_stride [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) \
instantiate_gemv(name, itype, 4, 32, 4, 4) \
instantiate_gemv(name, itype, 8, 32, 4, 4)
instantiate_gemv_blocks(float32, float)
instantiate_gemv_blocks(float16, half)
instantiate_gemv_blocks(bfloat16, bfloat16_t)
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
template <typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel]] void gemv_t(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const constant int& vector_batch_stride [[buffer(5)]],
const constant int& matrix_batch_stride [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
//      into blocks of (BM * TM, BN * TN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
//      and the corresponding scalars from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
//      These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
//      * The blocks that start outside the matrix are never read (thread results remain zero)
//      * The last thread that partially overlaps with the matrix is shifted inwards
//      such that the thread block fits exactly in the matrix
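// Illustrative example: in the bm8_bn32_tm4_tn4 instantiation below, each
// iteration of the accumulation loop covers BM * TM = 32 rows of in_vec and
// the threadgroup produces BN * TN = 128 output columns; the BM partial sums
// per column are combined through tgp_results at the end.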
// Update batch offsets
in_vec += tid.z * vector_batch_stride;
mat += tid.z * matrix_batch_stride;
out_vec += tid.z * out_vec_size;
// Thread local accumulation results
T result[TN] = {0};
T inter[TN];
T v_coeff[TM];
// Threadgroup accumulation results
threadgroup T tgp_results[BN][BM][TN];
int out_col = (tid.x * BN + lid.x) * TN;
int in_row = lid.y * TM;
// Edgecase handling
if (out_col < out_vec_size) {
// Edgecase handling
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
// Per thread accumulation main loop
int bm = in_row;
for(; bm < in_vec_size; bm += BM * TM) {
// Adding a threadgroup_barrier improves performance slightly
// This is possibly because it helps exploit the cache better
threadgroup_barrier(mem_flags::mem_none);
if(bm + TM <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
} else { // Edgecase handling
for(int tm = 0; bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
}
}
// Threadgroup collection
for(int i = 0; i < TN; i++) {
tgp_results[lid.x][lid.y][i] = result[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if(lid.y == 0 && out_col < out_vec_size) {
// Threadgroup accumulation
#pragma clang loop unroll(full)
for(int i = 1; i < BM; i++) {
for(int j = 0; j < TN; j++) {
result[j] += tgp_results[lid.x][i][j];
}
}
for(int j = 0; j < TN; j++) {
out_vec[out_col + j] = result[j];
}
}
}
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
[[kernel]] void gemv_t<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* vec [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant int& in_vec_size [[buffer(3)]], \
const constant int& out_vec_size [[buffer(4)]], \
const constant int& vector_batch_stride [[buffer(5)]], \
const constant int& matrix_batch_stride [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemv_t_blocks(name, itype) \
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
instantiate_gemv_t(name, itype, 8, 128, 4, 4)
instantiate_gemv_t_blocks(float32, float)
instantiate_gemv_t_blocks(float16, half)
instantiate_gemv_t_blocks(bfloat16, bfloat16_t)

View File

@@ -0,0 +1,226 @@
#include <metal_atomic>
#include <metal_common>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
template <typename T>
inline T softmax_exp(T x) {
// Softmax doesn't need a high precision exponential because x will be in
// (-oo, 0] anyway and the result is subsequently divided by sum(exp(x_i)).
return fast::exp(x);
}
template <typename T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_single_row(
const device T* in,
device T* out,
constant int& axis_size,
threadgroup T* local_max [[threadgroup(0)]],
threadgroup T* local_normalizer [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
int lid = _lid;
T ld[N_READS];
in += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
ld[i] = in[i];
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] =
((lid * N_READS + i) < axis_size) ? in[i] : T(Limits<T>::finite_min);
}
}
if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<T>::finite_min;
local_normalizer[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Get the max
T maxval = Limits<T>::finite_min;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < ld[i]) ? ld[i] : maxval;
}
maxval = simd_max(maxval);
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
maxval = simd_max(local_max[simd_lane_id]);
if (simd_lane_id == 0) {
local_max[0] = maxval;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = local_max[0];
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
T normalizer = 0;
for (int i = 0; i < N_READS; i++) {
T exp_x = softmax_exp(ld[i] - maxval);
ld[i] = exp_x;
normalizer += exp_x;
}
normalizer = simd_sum(normalizer);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
local_normalizer[0] = normalizer;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = 1 / local_normalizer[0];
// Normalize and write to the output
out += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
out[i] = ld[i] * normalizer;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
out[i] = ld[i] * normalizer;
}
}
}
}
template <typename T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_looped(
const device T* in,
device T* out,
constant int& axis_size,
threadgroup T* local_max [[threadgroup(0)]],
threadgroup T* local_normalizer [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
in += gid * axis_size;
// Get the max and the normalizer in one go
T prevmax;
T maxval = Limits<T>::finite_min;
T normalizer = 0;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
T vals[N_READS];
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
vals[i] = in[offset + i];
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] =
(offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
}
}
prevmax = maxval;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < vals[i]) ? vals[i] : maxval;
}
normalizer *= softmax_exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer += softmax_exp(vals[i] - maxval);
}
}
// Now we have partial normalizers for N_READS * ceildiv(axis_size, N_READS *
// lsize) parts. We need to combine them.
// 1. We start by finding the max across simd groups
// 2. We then change the partial normalizers to account for a possible
// change in max
// 3. We sum all normalizers
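// Concretely, two partial (max, normalizer) pairs (m_1, l_1) and (m_2, l_2)
// combine as m = max(m_1, m_2), l = l_1 * exp(m_1 - m) + l_2 * exp(m_2 - m),
// which is what the softmax_exp(prevmax - maxval) rescaling below implements.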
prevmax = maxval;
maxval = simd_max(maxval);
normalizer *= softmax_exp(prevmax - maxval);
normalizer = simd_sum(normalizer);
// Now the normalizer and max value are correct for each simdgroup. We write
// them to shared memory and combine them.
prevmax = maxval;
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = simd_max(local_max[simd_lane_id]);
normalizer *= softmax_exp(prevmax - maxval);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = simd_sum(local_normalizer[simd_lane_id]);
normalizer = 1 / normalizer;
// Finally given the normalizer and max value we can directly write the
// softmax output
out += gid * axis_size;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
if (offset + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
}
} else {
for (int i = 0; i < N_READS; i++) {
if (offset + i < axis_size) {
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
}
}
}
}
}
#define instantiate_softmax_single_row(name, itype) \
template [[host_name("softmax_" #name)]] [[kernel]] void \
softmax_single_row<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
threadgroup itype* local_max [[threadgroup(0)]], \
threadgroup itype* local_normalizer [[threadgroup(1)]], \
uint gid [[threadgroup_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax_looped(name, itype) \
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
softmax_looped<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
threadgroup itype* local_max [[threadgroup(0)]], \
threadgroup itype* local_normalizer [[threadgroup(1)]], \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax(name, itype) \
instantiate_softmax_single_row(name, itype) \
instantiate_softmax_looped(name, itype)
instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)

View File

@@ -0,0 +1,818 @@
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")
using namespace metal;
// Based on the GPU merge sort algorithm at https://github.com/NVIDIA/cccl/tree/main/cub/cub
///////////////////////////////////////////////////////////////////////////////
// Thread-level sort
///////////////////////////////////////////////////////////////////////////////
template <typename T>
METAL_FUNC void thread_swap(thread T& a, thread T& b) {
T w = a;
a = b;
b = w;
}
template <typename T>
struct LessThan {
static constexpr constant T init = Limits<T>::max;
METAL_FUNC bool operator()(T a, T b) {
return a < b;
}
};
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short N_PER_THREAD,
typename CompareOp>
struct ThreadSort {
static METAL_FUNC void sort(
thread val_t (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) {
CompareOp op;
MLX_MTL_LOOP_UNROLL
for(short i = 0; i < N_PER_THREAD; ++i) {
MLX_MTL_LOOP_UNROLL
for(short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
if(op(vals[j + 1], vals[j])) {
thread_swap(vals[j + 1], vals[j]);
thread_swap(idxs[j + 1], idxs[j]);
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// Threadgroup-level sort
///////////////////////////////////////////////////////////////////////////////
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp>
struct BlockMergeSort {
using thread_sort_t = ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
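// Binary search over the merge path: returns how many elements of As belong
// to the first sort_md elements of merge(As, Bs), assuming As and Bs are each
// sorted under CompareOp.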
static METAL_FUNC int merge_partition(
const threadgroup val_t* As,
const threadgroup val_t* Bs,
short A_sz,
short B_sz,
short sort_md) {
CompareOp op;
short A_st = max(0, sort_md - B_sz);
short A_ed = min(sort_md, A_sz);
while(A_st < A_ed) {
short md = A_st + (A_ed - A_st) / 2;
auto a = As[md];
auto b = Bs[sort_md - 1 - md];
if(op(b, a)) {
A_ed = md;
} else {
A_st = md + 1;
}
}
return A_ed;
}
static METAL_FUNC void merge_step(
const threadgroup val_t* As,
const threadgroup val_t* Bs,
const threadgroup idx_t* As_idx,
const threadgroup idx_t* Bs_idx,
short A_sz,
short B_sz,
thread val_t (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) {
CompareOp op;
short a_idx = 0;
short b_idx = 0;
for(int i = 0; i < N_PER_THREAD; ++i) {
auto a = As[a_idx];
auto b = Bs[b_idx];
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
vals[i] = pred ? b : a;
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
b_idx += short(pred);
a_idx += short(!pred);
}
}
static METAL_FUNC void sort(
threadgroup val_t* tgp_vals [[threadgroup(0)]],
threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
int size_sorted_axis,
uint3 lid [[thread_position_in_threadgroup]]) {
// Get thread location
int idx = lid.x * N_PER_THREAD;
// Load from shared memory
thread val_t thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD];
for(int i = 0; i < N_PER_THREAD; ++i) {
thread_vals[i] = tgp_vals[idx + i];
if(ARG_SORT) {
thread_idxs[i] = tgp_idxs[idx + i];
}
}
// Per thread sort
if(idx < size_sorted_axis) {
thread_sort_t::sort(thread_vals, thread_idxs);
}
// Do merges using threadgroup memory
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; merge_threads *= 2) {
// Update threadgroup memory
threadgroup_barrier(mem_flags::mem_threadgroup);
for(int i = 0; i < N_PER_THREAD; ++i) {
tgp_vals[idx + i] = thread_vals[i];
if(ARG_SORT) {
tgp_idxs[idx + i] = thread_idxs[i];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Find location in merge step
int merge_group = lid.x / merge_threads;
int merge_lane = lid.x % merge_threads;
int sort_sz = N_PER_THREAD * merge_threads;
int sort_st = N_PER_THREAD * merge_threads * merge_group;
// As = tgp_vals[A_st:A_ed] is sorted
// Bs = tgp_vals[B_st:B_ed] is sorted
int A_st = sort_st;
int A_ed = sort_st + sort_sz / 2;
int B_st = sort_st + sort_sz / 2;
int B_ed = sort_st + sort_sz;
const threadgroup val_t* As = tgp_vals + A_st;
const threadgroup val_t* Bs = tgp_vals + B_st;
int A_sz = A_ed - A_st;
int B_sz = B_ed - B_st;
// Find a partition of merge elements
// Ci = merge(As[partition:], Bs[sort_md - partition:])
// of size N_PER_THREAD for each merge lane i
// C = [Ci] is sorted
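// Illustrative example: with N_PER_THREAD = 8 and merge_threads = 2, a pair
// of threads merges two sorted runs of 8 values; merge_lane 0 emits merged
// positions [0, 8) and merge_lane 1 emits positions [8, 16).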
int sort_md = N_PER_THREAD * merge_lane;
int partition = merge_partition(
As,
Bs,
A_sz,
B_sz,
sort_md);
As += partition;
Bs += sort_md - partition;
A_sz -= partition;
B_sz -= sort_md - partition;
const threadgroup idx_t* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
const threadgroup idx_t* Bs_idx = ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
// Merge starting at the partition and store results in thread registers
merge_step(
As,
Bs,
As_idx,
Bs_idx,
A_sz,
B_sz,
thread_vals,
thread_idxs);
}
// Write out to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
for(int i = 0; i < N_PER_THREAD; ++i) {
tgp_vals[idx + i] = thread_vals[i];
if(ARG_SORT) {
tgp_idxs[idx + i] = thread_idxs[i];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// Kernel sort
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<T>>
struct KernelMergeSort {
using val_t = T;
using idx_t = uint;
using block_merge_sort_t = BlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort(
const device T* inp,
device U* out,
const constant int& size_sorted_axis,
const constant int& stride_sorted_axis,
const constant int& stride_segment_axis,
threadgroup val_t* tgp_vals,
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
inp += tid.y * stride_segment_axis;
out += tid.y * stride_segment_axis;
// Copy into threadgroup memory
for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis] : val_t(CompareOp::init);
if(ARG_SORT) {
tgp_idxs[i] = i;
}
}
// Sort elements within the block
threadgroup_barrier(mem_flags::mem_threadgroup);
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
for(int i = lid.x; i < size_sorted_axis; i+= BLOCK_THREADS) {
if(ARG_SORT) {
out[i * stride_sorted_axis] = tgp_idxs[i];
} else {
out[i * stride_sorted_axis] = tgp_vals[i];
}
}
}
};
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort(
const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
if(ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
stride_segment_axis,
tgp_vals,
tgp_idxs,
tid,
lid);
} else {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
stride_segment_axis,
tgp_vals,
nullptr,
tid,
lid);
}
}
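// Dummy zero stride: block_sort_nc computes its segment offset with
// elem_to_loc, so it passes this as stride_segment_axis to avoid applying a
// second per-segment offset inside block_sort.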
constant constexpr const int zero_helper = 0;
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc(
const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
inp += block_idx;
out += block_idx;
if(ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
zero_helper,
tgp_vals,
tgp_idxs,
tid,
lid);
} else {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
zero_helper,
tgp_vals,
nullptr,
tid,
lid);
}
}
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_block_sort(name, itname, itype, otname, otype, arg_sort, bn, tn) \
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void block_sort<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& stride_segment_axis [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn "_nc")]] \
[[kernel]] void block_sort_nc<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& nc_dim [[buffer(4)]], \
const device int* nc_shape [[buffer(5)]], \
const device size_t* nc_strides [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort(arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
#define instantiate_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort(block_merge_sort, itname, itype, itname, itype, false, bn, tn)
#define instantiate_block_sort_tn(itname, itype, bn) \
instantiate_block_sort_base(itname, itype, bn, 8) \
instantiate_arg_block_sort_base(itname, itype, bn, 8)
#define instantiate_block_sort_bn(itname, itype) \
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256) \
instantiate_block_sort_tn(itname, itype, 512)
instantiate_block_sort_bn(uint8, uint8_t)
instantiate_block_sort_bn(uint16, uint16_t)
instantiate_block_sort_bn(uint32, uint32_t)
instantiate_block_sort_bn(int8, int8_t)
instantiate_block_sort_bn(int16, int16_t)
instantiate_block_sort_bn(int32, int32_t)
instantiate_block_sort_bn(float16, half)
instantiate_block_sort_bn(float32, float)
instantiate_block_sort_bn(bfloat16, bfloat16_t)
#define instantiate_block_sort_long(itname, itype) \
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256)
instantiate_block_sort_long(uint64, uint64_t)
instantiate_block_sort_long(int64, int64_t)
///////////////////////////////////////////////////////////////////////////////
// Multi block merge sort
///////////////////////////////////////////////////////////////////////////////
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
struct KernelMultiBlockMergeSort {
using block_merge_sort_t = BlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort(
const device val_t* inp,
device val_t* out_vals,
device idx_t* out_idxs,
const constant int& size_sorted_axis,
const constant int& stride_sorted_axis,
threadgroup val_t* tgp_vals,
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.x tells us the block index along the sorted axis (segment offsets are applied by the caller)
int base_idx = tid.x * N_PER_BLOCK;
// Copy into threadgroup memory
for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
int idx = base_idx + i;
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] : val_t(CompareOp::init);
tgp_idxs[i] = idx;
}
// Sort elements within the block
threadgroup_barrier(mem_flags::mem_threadgroup);
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
for(int i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
int idx = base_idx + i;
if(idx < size_sorted_axis) {
out_vals[idx] = tgp_vals[i];
out_idxs[idx] = tgp_idxs[i];
}
}
}
static METAL_FUNC int merge_partition(
const device val_t* As,
const device val_t* Bs,
int A_sz,
int B_sz,
int sort_md) {
CompareOp op;
int A_st = max(0, sort_md - B_sz);
int A_ed = min(sort_md, A_sz);
while(A_st < A_ed) {
int md = A_st + (A_ed - A_st) / 2;
auto a = As[md];
auto b = Bs[sort_md - 1 - md];
if(op(b, a)) {
A_ed = md;
} else {
A_st = md + 1;
}
}
return A_ed;
}
};
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort(
const device val_t* inp [[buffer(0)]],
device val_t* out_vals [[buffer(1)]],
device idx_t* out_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
inp += block_idx;
out_vals += tid.y * size_sorted_axis;
out_idxs += tid.y * size_sorted_axis;
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out_vals,
out_idxs,
size_sorted_axis,
stride_sorted_axis,
tgp_vals,
tgp_idxs,
tid,
lid);
}
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partiton(
device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals [[buffer(1)]],
const device idx_t* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
block_partitions += tid.y * tgp_dims.x;
dev_vals += tid.y * size_sorted_axis;
dev_idxs += tid.y * size_sorted_axis;
// Find location in merge step
int merge_group = lid.x / merge_tiles;
int merge_lane = lid.x % merge_tiles;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
int A_st = min(size_sorted_axis, sort_st);
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
int B_st = A_ed;
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
int partition = sort_kernel::merge_partition(
dev_vals + A_st,
dev_vals + B_st,
A_ed - A_st,
B_ed - B_st,
partition_at);
block_partitions[lid.x] = A_st + partition;
}
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_merge(
const device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals_in [[buffer(1)]],
const device idx_t* dev_idxs_in [[buffer(2)]],
device val_t* dev_vals_out [[buffer(3)]],
device idx_t* dev_idxs_out [[buffer(4)]],
const constant int& size_sorted_axis [[buffer(5)]],
const constant int& merge_tiles [[buffer(6)]],
const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
using block_sort_t = typename sort_kernel::block_merge_sort_t;
block_partitions += tid.y * (num_tiles + 1);
dev_vals_in += tid.y * size_sorted_axis;
dev_idxs_in += tid.y * size_sorted_axis;
dev_vals_out += tid.y * size_sorted_axis;
dev_idxs_out += tid.y * size_sorted_axis;
int block_idx = tid.x;
int merge_group = block_idx / merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;
int A_st = block_partitions[block_idx + 0];
int A_ed = block_partitions[block_idx + 1];
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md - A_st);
int B_ed = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
if((block_idx % merge_tiles) == merge_tiles - 1) {
A_ed = min(size_sorted_axis, sort_st + sort_sz/2);
B_ed = min(size_sorted_axis, sort_st + sort_sz);
}
int A_sz = A_ed - A_st;
int B_sz = B_ed - B_st;
// Load from global memory
thread val_t thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD];
for(int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x;
if(idx < (A_sz + B_sz)) {
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] : dev_vals_in[B_st + idx - A_sz];
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx] : dev_idxs_in[B_st + idx - A_sz];
} else {
thread_vals[i] = CompareOp::init;
thread_idxs[i] = 0;
}
}
// Write to shared memory
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
threadgroup_barrier(mem_flags::mem_threadgroup);
for(int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x;
tgp_vals[idx] = thread_vals[i];
tgp_idxs[idx] = thread_idxs[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Merge
int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));
int A_st_local = block_sort_t::merge_partition(
tgp_vals,
tgp_vals + A_sz,
A_sz,
B_sz,
sort_md_local);
int A_ed_local = A_sz;
int B_st_local = sort_md_local - A_st_local;
int B_ed_local = B_sz;
int A_sz_local = A_ed_local - A_st_local;
int B_sz_local = B_ed_local - B_st_local;
// Do merge
block_sort_t::merge_step(
tgp_vals + A_st_local,
tgp_vals + A_ed_local + B_st_local,
tgp_idxs + A_st_local,
tgp_idxs + A_ed_local + B_st_local,
A_sz_local,
B_sz_local,
thread_vals,
thread_idxs);
threadgroup_barrier(mem_flags::mem_threadgroup);
for(int i = 0; i < N_PER_THREAD; ++i) {
int idx = lid.x * N_PER_THREAD;
tgp_vals[idx + i] = thread_vals[i];
tgp_idxs[idx + i] = thread_idxs[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
for(int i = lid.x; i < sort_kernel::N_PER_BLOCK; i+= BLOCK_THREADS) {
int idx = base_idx + i;
if(idx < size_sorted_axis) {
dev_vals_out[idx] = tgp_vals[i];
dev_idxs_out[idx] = tgp_idxs[i];
}
}
}
#define instantiate_multi_block_sort(vtname, vtype, itname, itype, arg_sort, bn, tn) \
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
const device vtype* inp [[buffer(0)]], \
device vtype* out_vals [[buffer(1)]], \
device itype* out_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& stride_sorted_axis [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("mb_block_partiton_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void mb_block_partiton<vtype, itype, arg_sort, bn, tn>( \
device itype* block_partitions [[buffer(0)]], \
const device vtype* dev_vals [[buffer(1)]], \
const device itype* dev_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& merge_tiles [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_dims [[threads_per_threadgroup]]); \
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
const device itype* block_partitions [[buffer(0)]], \
const device vtype* dev_vals_in [[buffer(1)]], \
const device itype* dev_idxs_in [[buffer(2)]], \
device vtype* dev_vals_out [[buffer(3)]], \
device itype* dev_idxs_out [[buffer(4)]], \
const constant int& size_sorted_axis [[buffer(5)]], \
const constant int& merge_tiles [[buffer(6)]], \
const constant int& num_tiles [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_multi_block_sort_base(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)
instantiate_multi_block_sort_base(uint8, uint8_t)
instantiate_multi_block_sort_base(uint16, uint16_t)
instantiate_multi_block_sort_base(uint32, uint32_t)
instantiate_multi_block_sort_base(int8, int8_t)
instantiate_multi_block_sort_base(int16, int16_t)
instantiate_multi_block_sort_base(int32, int32_t)
instantiate_multi_block_sort_base(float16, half)
instantiate_multi_block_sort_base(float32, float)
instantiate_multi_block_sort_base(bfloat16, bfloat16_t)
#define instantiate_multi_block_sort_long(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)
instantiate_multi_block_sort_long(uint64, uint64_t)
instantiate_multi_block_sort_long(int64, int64_t)

View File

@@ -0,0 +1,244 @@
#pragma once
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/complex.h"
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////
template <typename U>
struct Limits {
static const constant U max;
static const constant U min;
static const constant U finite_max;
static const constant U finite_min;
};
#define instantiate_default_limit(type) \
template <> \
struct Limits<type> { \
static constexpr constant type max = metal::numeric_limits<type>::max(); \
static constexpr constant type min = metal::numeric_limits<type>::min(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
metal::numeric_limits<type>::min(); \
};
instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);
#define instantiate_float_limit(type) \
template <> \
struct Limits<type> { \
static constexpr constant type max = \
metal::numeric_limits<type>::infinity(); \
static constexpr constant type min = \
-metal::numeric_limits<type>::infinity(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
-metal::numeric_limits<type>::max(); \
};
instantiate_float_limit(half);
instantiate_float_limit(float);
instantiate_float_limit(bfloat16_t);
template <>
struct Limits<bool> {
static constexpr constant bool max = true;
static constexpr constant bool min = false;
};
///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////
inline size_t elem_to_loc(
uint elem,
device const int* shape,
device const size_t* strides,
int ndim) {
size_t loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
return loc;
}
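// Illustrative example (hypothetical values): elem_to_loc maps a flat
// row-major element index to a memory offset. With shape = {2, 3} and
// strides = {1, 2} (a transposed 2x3 array), elem = 4 refers to (row 1, col 1)
// and maps to loc = (4 % 3) * 2 + (1 % 2) * 1 = 3; with row-contiguous
// strides = {3, 1} the same elem maps to loc = 4.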
inline size_t elem_to_loc(
uint elem,
constant const int* shape,
constant const size_t* strides,
int ndim) {
size_t loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
return loc;
}
template <int NDIM>
inline uint2 elem_to_loc_2_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM]) {
uint2 loc = {
static_cast<uint>(
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
static_cast<uint>(
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
for (int d = NDIM - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
elem.z /= shape[d];
}
return loc;
}
template <int NDIM>
inline size_t elem_to_loc_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t strides[NDIM]) {
size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
for (int d = NDIM - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
elem.z /= shape[d];
}
return loc;
}
inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
return elem * stride;
}
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
return elem.x * strides[1] + elem.y * strides[0];
}
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
}
// Non-templated version to handle arbitrary dims
inline size_t elem_to_loc(
uint3 elem,
constant const int* shape,
constant const size_t* strides,
int ndim) {
size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
for (int d = ndim - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
elem.z /= shape[d];
}
return loc;
}
inline uint2 elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
int ndim) {
uint2 loc = {
static_cast<uint>(
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
static_cast<uint>(
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
elem.z /= shape[d];
}
return loc;
}
template <int NDIM>
inline uint elem_to_loc_nd(
uint elem,
device const int* shape,
device const size_t* strides);
template <>
inline uint elem_to_loc_nd<1>(
uint elem,
device const int* shape,
device const size_t* strides) {
return (elem % shape[0]) * strides[0];
}
template <>
inline uint elem_to_loc_nd<2>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
template <>
inline uint elem_to_loc_nd<3>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[2]) * strides[2];
elem /= shape[2];
loc += (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
template <>
inline uint elem_to_loc_nd<4>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[3]) * strides[3];
elem /= shape[3];
loc += (elem % shape[2]) * strides[2];
elem /= shape[2];
loc += (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
///////////////////////////////////////////////////////////////////////////////
// Calculation utils
///////////////////////////////////////////////////////////////////////////////
/** Compute ceil((float)N/(float)M) */
inline size_t ceildiv(size_t N, size_t M) {
return (N + M - 1) / M;
}
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
inline float log1p(float x) {
float xp1 = 1.0f + x;
return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f));
}
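// The branch above follows the standard compensated log1p trick: when x is so
// small that 1 + x rounds to 1, log1p(x) ~= x; otherwise the factor
// x / (xp1 - 1) cancels the rounding error introduced when forming xp1, so
// x * (log(xp1) / (xp1 - 1)) stays accurate for small x.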
inline bfloat16_t log1p(bfloat16_t x) {
float xp1 = 1.0f + static_cast<float>(x);
bfloat16_t ret =
(xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
return ret;
}

View File

@@ -0,0 +1,446 @@
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
bool use_mps() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_USE_MPS")) {
return std::string(buff_str) != "OFF";
} else {
return false;
}
};
static bool use_mps_ = get_val();
return use_mps_;
}
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
inline void mps_matmul(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies) {
MPS::DataType mps_dtype = MPS::DataTypeFloat32;
if (out.dtype() == float16) {
mps_dtype = MPS::DataTypeFloat16;
} else if (out.dtype() == bfloat16) {
mps_dtype = MPS::DataTypeBFloat16;
}
// Use batched MPSMatrixMultiplication if batch_size_out > 1
// We only accept the following cases:
// 1. Both a, b have batch_size_out matrices worth of data
// 2. Only one of a or b has batch_size_out matrices worth of data and
// the other has one matrix worth of data
// The matrix dimensions of a and b are sure to be regularly strided
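// Illustrative example (hypothetical shapes): with a of shape (8, M, K) and b
// of shape (K, N), batch_size_a = 8 and batch_size_b = 1; the branch below
// sets matrix_stride_b = 0 so every output matrix reuses the single b, and
// promotes both batch sizes to batch_size_out = 8.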
if (batch_size_out > 1) {
// No broadcasting defaults
auto batch_size_a = a.data_size() / (M * K);
auto batch_size_b = b.data_size() / (K * N);
auto matrix_stride_a = M * K;
auto matrix_stride_b = K * N;
auto matrix_stride_out = M * N;
// At this point, batch_size_a, batch_size_b show the number of matrices
// in data, no broadcasted strides considered
if (batch_size_out == std::max(batch_size_a, batch_size_b)) {
// Handle simple broadcasting
if (std::min(batch_size_a, batch_size_b) == 1) {
matrix_stride_a = (batch_size_a == 1) ? 0 : matrix_stride_a;
matrix_stride_b = (batch_size_b == 1) ? 0 : matrix_stride_b;
batch_size_a = batch_size_out;
batch_size_b = batch_size_out;
}
// Only proceed if broadcasting between a and b is simple
// At this point, batch_size_a, batch_size_b show the number of matrices
// after broadcasting
if (batch_size_a == batch_size_b) {
auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
(M * K) / lda,
lda,
batch_size_a,
lda * a.itemsize(),
(matrix_stride_a * a.itemsize()),
mps_dtype);
auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
(K * N) / ldb,
ldb,
batch_size_b,
ldb * b.itemsize(),
(matrix_stride_b * b.itemsize()),
mps_dtype);
auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
M,
N,
batch_size_out,
N * out.itemsize(),
matrix_stride_out * out.itemsize(),
mps_dtype);
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
auto kernel = MPS::MatrixMultiplication::alloc()->init(
d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0);
auto command_buffer = d.get_command_buffer(s.index);
kernel->setBatchSize(batch_size_out);
kernel->setBatchStart(0);
kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
command_buffer->addCompletedHandler(
[a_mat, b_mat, out_mat, kernel, copies](
MTL::CommandBuffer*) mutable {
a_mat->release();
b_mat->release();
out_mat->release();
kernel->release();
copies.clear();
});
return;
}
}
}
// Otherwise schedule as many MPSMatrixMultiplication calls as needed
auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
a.data_size() / lda, lda, lda * a.itemsize(), mps_dtype);
auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
b.data_size() / ldb, ldb, ldb * b.itemsize(), mps_dtype);
auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
batch_size_out * M, N, N * out.itemsize(), mps_dtype);
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
auto kernel = MPS::MatrixMultiplication::alloc()->init(
d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0);
auto command_buffer = d.get_command_buffer(s.index);
for (int i = 0; i < batch_size_out; ++i) {
auto a_row = elem_to_loc(M * K * i, a.shape(), a.strides()) / lda;
auto b_row = elem_to_loc(K * N * i, b.shape(), b.strides()) / ldb;
kernel->setLeftMatrixOrigin({a_row, 0, 0});
kernel->setRightMatrixOrigin({b_row, 0, 0});
kernel->setResultMatrixOrigin({i * static_cast<size_t>(M), 0, 0});
kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
}
command_buffer->addCompletedHandler(
[a_mat, b_mat, out_mat, kernel, copies](MTL::CommandBuffer*) mutable {
a_mat->release();
b_mat->release();
out_mat->release();
kernel->release();
copies.clear();
});
}
} // namespace
void mlx_matmul(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies) {
// Account for batch sizes and basic broadcasting
int batch_size_a = a.data_size() / (M * K);
int batch_size_b = b.data_size() / (K * N);
int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K;
int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
int matrix_stride_out = M * N;
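// Illustrative example (hypothetical shapes): for a of shape (4, M, K)
// multiplied by a single (K, N) b, batch_size_a = 4 and batch_size_b = 1, so
// matrix_stride_a stays M * K while matrix_stride_b becomes 0 and the same b
// is read for every output matrix in the batch.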
// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int wm = 2, wn = 2;
if ((size_t)batch_size_out * M * N >= 2ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
std::ostringstream kname;
kname << "gemm_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n')
<< "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Launch only 1 kernel in the case of simple batching / broadcasting
if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
(batch_size_a == batch_size_b ||
std::min(batch_size_a, batch_size_b) == 1)) {
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims =
MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, batch_size_out);
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(&M, sizeof(int), 3);
compute_encoder->setBytes(&N, sizeof(int), 4);
compute_encoder->setBytes(&K, sizeof(int), 5);
compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6);
compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7);
compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
  } else { // Otherwise launch one kernel per batch entry with explicit offsets
for (int i = 0; i < batch_size_out; ++i) {
auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, 1);
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0);
compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 2);
compute_encoder->setBytes(&M, sizeof(int), 3);
compute_encoder->setBytes(&N, sizeof(int), 4);
compute_encoder->setBytes(&K, sizeof(int), 5);
compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6);
compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7);
compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (!is_floating_point(out.dtype())) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
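// Roughly: a row-contiguous (M, K) matrix has strides {K, 1} and is returned
// as (false, K, arr); a column-contiguous one has strides {1, M} and is
// returned as (true, M, arr); anything else is first copied into a
// row-contiguous temporary.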
auto [a_transposed, a_cols, a] = check_transpose(a_pre);
auto [b_transposed, b_cols, b] = check_transpose(b_pre);
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
auto batch_size_out = out.size() / (M * N);
// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;
auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !b_transposed : a_transposed;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int batch_size_mat = mat.data_size() / (mat_cols * mat_rows);
int stride_mat = batch_size_mat == batch_size_out ? mat_cols * mat_rows : 0;
int batch_size_vec = vec.data_size() / in_vector_len;
int stride_vec = batch_size_vec == batch_size_out ? in_vector_len : 0;
// Determine dispatch kernel
int tm = 4, tn = 4;
int bm, bn, n_out_per_tgp;
std::ostringstream kname;
if (transpose_mat) {
bm = 8;
bn = 8;
if (out_vector_len >= 24576) {
bn = 128;
} else if (out_vector_len >= 16384) {
bn = 64;
} else if (out_vector_len >= 8192) {
bn = 16;
}
// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;
n_out_per_tgp = bn * tn;
kname << "gemv_t_" << type_to_name(out);
} else {
bm = out_vector_len >= 4096 ? 8 : 4;
bn = 32;
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
n_out_per_tgp = bm * tm;
kname << "gemv_" << type_to_name(out);
}
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(bn, bm, 1);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
set_array_buffer(compute_encoder, mat, 0);
set_array_buffer(compute_encoder, vec, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 3);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&stride_vec, sizeof(int), 5);
compute_encoder->setBytes(&stride_mat, sizeof(int), 6);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
d.end_encoding(s.index);
if (use_mps()) {
mps_matmul(
s,
d,
a,
b,
out,
M,
N,
K,
batch_size_out,
a_cols,
b_cols,
a_transposed,
b_transposed,
copies);
return;
}
mlx_matmul(
s,
d,
a,
b,
out,
M,
N,
K,
batch_size_out,
a_cols,
b_cols,
a_transposed,
b_transposed,
copies);
}
} // namespace mlx::core

View File

@@ -0,0 +1,29 @@
#include <algorithm>
#include <cassert>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"
namespace mlx::core {
void mlx_matmul(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies);
} // namespace mlx::core

View File

@@ -0,0 +1,368 @@
#pragma once
#include <Metal/Metal.hpp>
#define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol)
#define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor)
namespace MTL::Private::Class {
_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor);
_MTL_PRIVATE_DEF_CLS(MPSMatrix);
_MTL_PRIVATE_DEF_CLS(MPSVectorDescriptor);
_MTL_PRIVATE_DEF_CLS(MPSVector);
_MTL_PRIVATE_DEF_CLS(MPSKernel);
_MTL_PRIVATE_DEF_CLS(MPSMatrixMultiplication);
_MTL_PRIVATE_DEF_CLS(MPSMatrixVectorMultiplication);
} // namespace MTL::Private::Class
namespace MTL::Private::Selector {
_MTL_PRIVATE_DEF_SEL(
matrixDescriptorWithRows_columns_rowBytes_dataType,
"matrixDescriptorWithRows:columns:rowBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(
matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType,
"matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(rows, "rows");
_MTL_PRIVATE_DEF_SEL(initWithBuffer_descriptor, "initWithBuffer:descriptor:");
_MTL_PRIVATE_DEF_SEL(
initWithDevice_,
"initWithDevice:transposeLeft:transposeRight:"
"resultRows:resultColumns:interiorColumns:alpha:beta:");
_MTL_PRIVATE_DEF_SEL(
encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix,
"encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:");
_MTL_PRIVATE_DEF_SEL(setLeftMatrixOrigin_, "setLeftMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setRightMatrixOrigin_, "setRightMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setResultMatrixOrigin_, "setResultMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setBatchStart_, "setBatchStart:");
_MTL_PRIVATE_DEF_SEL(setBatchSize_, "setBatchSize:");
_MTL_PRIVATE_DEF_SEL(
vectorDescriptorWithLength_dataType,
"vectorDescriptorWithLength:dataType:");
_MTL_PRIVATE_DEF_SEL(
vectorDescriptorWithLength_vectors_vectorBytes_dataType,
"vectorDescriptorWithLength:vectors:vectorBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(
initWithDevice_transpose_rows_columns_alpha_beta,
"initWithDevice:transpose:rows:columns:alpha:beta:");
_MTL_PRIVATE_DEF_SEL(
encodeToCommandBuffer_inputMatrix_inputVector_resultVector,
"encodeToCommandBuffer:inputMatrix:inputVector:resultVector:");
} // namespace MTL::Private::Selector
namespace MPS {
typedef enum DataType : uint32_t {
DataTypeFloatBit = 0x10000000,
DataTypeAlternateEncodingBit = 0x80000000,
DataTypeFloat16 = DataTypeFloatBit | 16,
DataTypeFloat32 = DataTypeFloatBit | 32,
DataTypeBFloat16 = DataTypeAlternateEncodingBit | DataTypeFloat16
} DataType;
class MatrixDescriptor : public NS::Copying<MatrixDescriptor> {
public:
static class MatrixDescriptor* matrixDescriptor(
NS::UInteger rows,
NS::UInteger columns,
NS::UInteger rowBytes,
NS::UInteger dataType);
static class MatrixDescriptor* matrixDescriptor(
NS::UInteger rows,
NS::UInteger columns,
NS::UInteger matrices,
NS::UInteger rowBytes,
NS::UInteger matrixBytes,
NS::UInteger dataType);
NS::UInteger rows() const;
};
class Matrix : public NS::Referencing<Matrix> {
public:
static class Matrix* alloc();
Matrix* init(MTL::Buffer* buffer, MatrixDescriptor* descriptor);
Matrix* init(const MTL::Buffer* buffer, MatrixDescriptor* descriptor);
};
class Kernel : public NS::Referencing<Kernel> {
public:
NS::String* label() const;
MTL::Device* device() const;
};
class MatrixMultiplication
: public NS::Referencing<MatrixMultiplication, Kernel> {
public:
static class MatrixMultiplication* alloc();
MatrixMultiplication* init(
MTL::Device* device,
bool transposeLeft,
bool transposeRight,
NS::UInteger resultRows,
NS::UInteger resultColumns,
NS::UInteger interiorColumns,
double alpha,
double beta);
void encodeToCommandBuffer(
MTL::CommandBuffer* commandBuffer,
Matrix* leftMatrix,
Matrix* rightMatrix,
Matrix* resultMatrix);
void setLeftMatrixOrigin(MTL::Origin origin);
void setRightMatrixOrigin(MTL::Origin origin);
void setResultMatrixOrigin(MTL::Origin origin);
void setBatchStart(NS::UInteger batchStart);
void setBatchSize(NS::UInteger batchSize);
};
class VectorDescriptor : public NS::Copying<VectorDescriptor> {
public:
static class VectorDescriptor* vectorDescriptor(
NS::UInteger length,
NS::UInteger dataType);
static class VectorDescriptor* vectorDescriptor(
NS::UInteger length,
NS::UInteger vectors,
NS::UInteger vectorBytes,
NS::UInteger dataType);
};
class Vector : public NS::Referencing<Vector> {
public:
static class Vector* alloc();
Vector* init(MTL::Buffer* buffer, VectorDescriptor* descriptor);
Vector* init(const MTL::Buffer* buffer, VectorDescriptor* descriptor);
};
class MatrixVectorMultiplication
: public NS::Referencing<MatrixVectorMultiplication, Kernel> {
public:
static class MatrixVectorMultiplication* alloc();
MatrixVectorMultiplication* init(
MTL::Device* device,
bool transpose,
NS::UInteger rows,
NS::UInteger columns,
double alpha,
double beta);
void encodeToCommandBuffer(
MTL::CommandBuffer* commandBuffer,
Matrix* inputMatrix,
Vector* inputVector,
Vector* resultVector);
};
_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
NS::UInteger rows,
NS::UInteger columns,
NS::UInteger rowBytes,
NS::UInteger dataType) {
return Object::sendMessage<MatrixDescriptor*>(
_MPS_PRIVATE_CLS(MPSMatrixDescriptor),
_MPS_PRIVATE_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType),
rows,
columns,
rowBytes,
dataType);
}
_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
NS::UInteger rows,
NS::UInteger columns,
NS::UInteger matrices,
NS::UInteger rowBytes,
NS::UInteger matrixBytes,
NS::UInteger dataType) {
return Object::sendMessage<MatrixDescriptor*>(
_MPS_PRIVATE_CLS(MPSMatrixDescriptor),
_MPS_PRIVATE_SEL(
matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType),
rows,
columns,
matrices,
rowBytes,
matrixBytes,
dataType);
}
_MTL_INLINE NS::UInteger MatrixDescriptor::rows() const {
return Object::sendMessage<NS::UInteger>(this, _MPS_PRIVATE_SEL(rows));
}
_MTL_INLINE Matrix* Matrix::alloc() {
return NS::Object::alloc<Matrix>(_MPS_PRIVATE_CLS(MPSMatrix));
}
_MTL_INLINE Matrix* Matrix::init(
MTL::Buffer* buffer,
MatrixDescriptor* descriptor) {
return Object::sendMessage<Matrix*>(
this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
}
_MTL_INLINE Matrix* Matrix::init(
const MTL::Buffer* buffer,
MatrixDescriptor* descriptor) {
return init(const_cast<MTL::Buffer*>(buffer), descriptor);
}
_MTL_INLINE NS::String* Kernel::label() const {
return Object::sendMessage<NS::String*>(this, _MPS_PRIVATE_SEL(label));
}
_MTL_INLINE MTL::Device* Kernel::device() const {
return Object::sendMessage<MTL::Device*>(this, _MPS_PRIVATE_SEL(device));
}
_MTL_INLINE MatrixMultiplication* MatrixMultiplication::alloc() {
return NS::Object::alloc<MatrixMultiplication>(
_MPS_PRIVATE_CLS(MPSMatrixMultiplication));
}
_MTL_INLINE MatrixMultiplication* MatrixMultiplication::init(
MTL::Device* device,
bool transposeLeft,
bool transposeRight,
NS::UInteger resultRows,
NS::UInteger resultColumns,
NS::UInteger interiorColumns,
double alpha,
double beta) {
return Object::sendMessage<MatrixMultiplication*>(
this,
_MPS_PRIVATE_SEL(initWithDevice_),
device,
transposeLeft,
transposeRight,
resultRows,
resultColumns,
interiorColumns,
alpha,
beta);
}
_MTL_INLINE void MatrixMultiplication::encodeToCommandBuffer(
MTL::CommandBuffer* commandBuffer,
Matrix* leftMatrix,
Matrix* rightMatrix,
Matrix* resultMatrix) {
return Object::sendMessage<void>(
this,
_MPS_PRIVATE_SEL(
encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix),
commandBuffer,
leftMatrix,
rightMatrix,
resultMatrix);
}
_MTL_INLINE void MatrixMultiplication::setLeftMatrixOrigin(MTL::Origin origin) {
Object::sendMessage<void>(
this, _MPS_PRIVATE_SEL(setLeftMatrixOrigin_), origin);
}
_MTL_INLINE void MatrixMultiplication::setRightMatrixOrigin(
MTL::Origin origin) {
Object::sendMessage<void>(
this, _MPS_PRIVATE_SEL(setRightMatrixOrigin_), origin);
}
_MTL_INLINE void MatrixMultiplication::setResultMatrixOrigin(
MTL::Origin origin) {
Object::sendMessage<void>(
this, _MPS_PRIVATE_SEL(setResultMatrixOrigin_), origin);
}
_MTL_INLINE void MatrixMultiplication::setBatchStart(NS::UInteger batchStart) {
Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchStart_), batchStart);
}
_MTL_INLINE void MatrixMultiplication::setBatchSize(NS::UInteger batchSize) {
Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchSize_), batchSize);
}
_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
NS::UInteger length,
NS::UInteger dataType) {
return Object::sendMessage<VectorDescriptor*>(
_MPS_PRIVATE_CLS(MPSVectorDescriptor),
_MPS_PRIVATE_SEL(vectorDescriptorWithLength_dataType),
length,
dataType);
}
_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
NS::UInteger length,
NS::UInteger vectors,
NS::UInteger vectorBytes,
NS::UInteger dataType) {
return Object::sendMessage<VectorDescriptor*>(
_MPS_PRIVATE_CLS(MPSVectorDescriptor),
_MPS_PRIVATE_SEL(vectorDescriptorWithLength_vectors_vectorBytes_dataType),
length,
vectors,
vectorBytes,
dataType);
}
_MTL_INLINE Vector* Vector::alloc() {
return NS::Object::alloc<Vector>(_MPS_PRIVATE_CLS(MPSVector));
}
_MTL_INLINE Vector* Vector::init(
MTL::Buffer* buffer,
VectorDescriptor* descriptor) {
return Object::sendMessage<Vector*>(
this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
}
_MTL_INLINE Vector* Vector::init(
const MTL::Buffer* buffer,
VectorDescriptor* descriptor) {
return init(const_cast<MTL::Buffer*>(buffer), descriptor);
}
_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::alloc() {
return NS::Object::alloc<MatrixVectorMultiplication>(
_MPS_PRIVATE_CLS(MPSMatrixVectorMultiplication));
}
_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::init(
MTL::Device* device,
bool transpose,
NS::UInteger rows,
NS::UInteger columns,
double alpha,
double beta) {
return Object::sendMessage<MatrixVectorMultiplication*>(
this,
_MPS_PRIVATE_SEL(initWithDevice_transpose_rows_columns_alpha_beta),
device,
transpose,
rows,
columns,
alpha,
beta);
}
_MTL_INLINE void MatrixVectorMultiplication::encodeToCommandBuffer(
MTL::CommandBuffer* commandBuffer,
Matrix* inputMatrix,
Vector* inputVector,
Vector* resultVector) {
return Object::sendMessage<void>(
this,
_MPS_PRIVATE_SEL(
encodeToCommandBuffer_inputMatrix_inputVector_resultVector),
commandBuffer,
inputMatrix,
inputVector,
resultVector);
}
} // namespace MPS

View File

@@ -0,0 +1,604 @@
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include "mlx/backend/common/binary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
void binary_op(
const std::vector<array>& inputs,
array& out,
const std::string op) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
std::ostringstream kname;
switch (bopt) {
case ScalarScalar:
kname << "ss";
break;
case ScalarVector:
kname << "sv";
break;
case VectorScalar:
kname << "vs";
break;
case VectorVector:
kname << "vv";
break;
case General:
kname << "g";
break;
}
kname << op << type_to_name(a);
if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
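// As an illustration, adding two contiguous float32 arrays selects the kernel
// "vvaddfloat32", while a general add whose collapsed shape has two dims
// selects "gaddfloat32_2".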
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, out, 2);
if (bopt == General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 6);
}
// Launch up to 3D grid of threads
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
int rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads = bopt == General ? out.size() : out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}
void unary_op(
const std::vector<array>& inputs,
array& out,
const std::string op) {
auto& in = inputs[0];
bool contig = in.flags().contiguous;
if (contig) {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
std::string tname = type_to_name(in);
std::string opt_name = contig ? "v" : "g";
auto kernel = d.get_kernel(opt_name + op + tname);
size_t nthreads = contig ? in.data_size() : in.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
if (!contig) {
compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
compute_encoder->setBytes(
in.strides().data(), in.ndim() * sizeof(size_t), 3);
int ndim = in.ndim();
compute_encoder->setBytes(&ndim, sizeof(int), 4);
}
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
} // namespace
void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "abs");
}
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "add");
}
template <typename T>
void arange_set_scalars(T start, T next, MTL::ComputeCommandEncoder* enc) {
enc->setBytes(&start, sizeof(T), 0);
T step = next - start;
enc->setBytes(&step, sizeof(T), 1);
}
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto kernel = d.get_kernel("arange" + type_to_name(out));
size_t nthreads = out.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size group_dims = MTL::Size(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
switch (out.dtype()) {
case bool_: // unsupported
throw std::runtime_error("[Arange::eval_gpu] Does not support bool");
case uint8:
arange_set_scalars<uint8_t>(start_, start_ + step_, compute_encoder);
break;
case uint16:
arange_set_scalars<uint16_t>(start_, start_ + step_, compute_encoder);
break;
case uint32:
arange_set_scalars<uint32_t>(start_, start_ + step_, compute_encoder);
break;
case uint64:
arange_set_scalars<uint64_t>(start_, start_ + step_, compute_encoder);
break;
case int8:
arange_set_scalars<int8_t>(start_, start_ + step_, compute_encoder);
break;
case int16:
arange_set_scalars<int16_t>(start_, start_ + step_, compute_encoder);
break;
case int32:
arange_set_scalars<int32_t>(start_, start_ + step_, compute_encoder);
break;
case int64:
arange_set_scalars<int64_t>(start_, start_ + step_, compute_encoder);
break;
case float16:
arange_set_scalars<float16_t>(start_, start_ + step_, compute_encoder);
break;
case float32:
arange_set_scalars<float>(start_, start_ + step_, compute_encoder);
break;
case bfloat16:
throw std::runtime_error("[Arange::eval_gpu] Does not support bfloat16");
case complex64:
throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
}
set_array_buffer(compute_encoder, out, 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arccos");
}
void ArcCosh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arccosh");
}
void ArcSin::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arcsin");
}
void ArcSinh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arcsinh");
}
void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arctan");
}
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arctanh");
}
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
std::string op_name;
switch (reduce_type_) {
case ArgReduce::ArgMin:
op_name = "argmin_";
break;
case ArgReduce::ArgMax:
op_name = "argmax_";
break;
}
// Prepare the shapes, strides and axis arguments.
std::vector<size_t> in_strides = in.strides();
std::vector<int> shape = in.shape();
std::vector<size_t> out_strides = out.strides();
size_t axis_stride = in_strides[axis_];
size_t axis_size = shape[axis_];
if (out_strides.size() == in_strides.size()) {
out_strides.erase(out_strides.begin() + axis_);
}
in_strides.erase(in_strides.begin() + axis_);
shape.erase(shape.begin() + axis_);
size_t ndim = shape.size();
// ArgReduce
int simd_size = 32;
int n_reads = 4;
auto compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name + type_to_name(in));
NS::UInteger thread_group_size = std::min(
(axis_size + n_reads - 1) / n_reads,
kernel->maxTotalThreadsPerThreadgroup());
// round up to the closest number divisible by simd_size
thread_group_size =
(thread_group_size + simd_size - 1) / simd_size * simd_size;
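// e.g. (illustrative) axis_size = 100 with n_reads = 4 gives 25 threads,
// which rounds up to 32, one full simdgroup.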
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = out.size() * thread_group_size;
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
compute_encoder->setThreadgroupMemoryLength(
simd_size * (sizeof(uint32_t) + in.itemsize()), 0);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
CopyType ctype =
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
copy_gpu(inputs[0], out, ctype);
}
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
std::vector<int> sizes;
sizes.push_back(0);
for (auto& p : inputs) {
sizes.push_back(p.shape(axis_));
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
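// After the partial sum, sizes holds the running offsets along axis_; e.g.
// (illustrative) inputs with axis sizes 2, 3, 4 give sizes = {0, 2, 5, 9}.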
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto strides = out.strides();
auto flags = out.flags();
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
}
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Cos::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "cos");
}
void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "cosh");
}
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "div");
}
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
}
void Erf::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "erf");
}
void ErfInv::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "erfinv");
}
void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "exp");
}
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0];
CopyType ctype;
if (in.data_size() == 1) {
ctype = CopyType::Scalar;
} else if (in.flags().contiguous) {
ctype = CopyType::Vector;
} else {
ctype = CopyType::General;
}
copy_gpu(in, out, ctype);
}
void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "ge");
}
void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "geq");
}
void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "le");
}
void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "leq");
}
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (base_) {
case Base::e:
unary_op(inputs, out, "log");
break;
case Base::two:
unary_op(inputs, out, "log2");
break;
case Base::ten:
unary_op(inputs, out, "log10");
break;
}
}
void Log1p::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "log1p");
}
void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "lnot");
}
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "lae");
}
void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "max");
}
void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "min");
}
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "mul");
}
void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "neg");
}
void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "neq");
}
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];
auto& val = inputs[1];
// Padding value must be a scalar
assert(val.size() == 1);
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
// Fill output with val
copy_gpu(val, out, CopyType::Scalar, stream());
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes_.size(); i++) {
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
data_offset += out.strides()[ax] * low_pad_size_[i];
}
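// e.g. (illustrative) padding a (4, 5) input with low_pad_size_ = {1, 2} into
// a row-contiguous output of shape (6, 9) gives
// data_offset = 9 * 1 + 1 * 2 = 11.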
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
}
void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "pow");
}
void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0];
size_t num_keys = keys.size() / 2;
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc_or_wait(out.nbytes()));
size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
size_t half_size = out_per_key / 2;
bool odd = out_per_key % 2;
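// out_per_key counts 4-byte words per key; e.g. (illustrative) 7 uint32
// outputs per key give bytes_per_key = 28, out_per_key = 7, half_size = 3 and
// odd = 1.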
auto& s = stream();
auto& d = metal::device(s.device);
std::string kname = keys.flags().row_contiguous ? "rbitsc" : "rbits";
auto kernel = d.get_kernel(kname);
// Organize into a grid of num_keys x (half_size + odd) threads
MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
auto nthreads = std::min(num_keys * (half_size + odd), thread_group_size);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, keys, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&odd, sizeof(bool), 2);
compute_encoder->setBytes(&bytes_per_key, sizeof(size_t), 3);
if (!keys.flags().row_contiguous) {
int ndim = keys.ndim();
compute_encoder->setBytes(&ndim, sizeof(int), 4);
compute_encoder->setBytes(
keys.shape().data(), keys.ndim() * sizeof(int), 5);
compute_encoder->setBytes(
keys.strides().data(), keys.ndim() * sizeof(size_t), 6);
}
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (in.flags().row_contiguous) {
auto flags = in.flags();
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
} else {
copy_gpu(in, out, CopyType::General);
}
}
void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sigmoid");
}
void Sign::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sign");
}
void Sin::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sin");
}
void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sinh");
}
void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "square");
}
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
if (recip_) {
unary_op(inputs, out, "rsqrt");
} else {
unary_op(inputs, out, "sqrt");
}
}
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "sub");
}
void Tan::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "tan");
}
void Tanh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "tanh");
}
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
} // namespace mlx::core

130
mlx/backend/metal/scan.cpp Normal file
View File

@@ -0,0 +1,130 @@
#include <cassert>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
// Ensure contiguity
std::vector<array> copies;
auto in = inputs[0];
if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
in = arr_copy;
}
std::ostringstream kname;
if (in.strides()[axis_] == 1) {
kname << "contiguous_scan_";
if (reverse_) {
kname << "reverse_";
}
kname << ((inclusive_) ? "inclusive_" : "exclusive_");
switch (reduce_type_) {
case Scan::Sum:
kname << "sum_";
break;
case Scan::Prod:
kname << "prod_";
break;
case Scan::Max:
kname << "max_";
break;
case Scan::Min:
kname << "min_";
break;
}
kname << type_to_name(in) << "_" << type_to_name(out);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
size_t size = in.shape(axis_);
compute_encoder->setBytes(&size, sizeof(size_t), 2);
// Compute the thread grid
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
int elements_per_simd = n_reads * 32;
int thread_groups = in.size() / size;
int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (size < n_reads * 1024) {
thread_group_size = ((size + elements_per_simd - 1) / elements_per_simd) *
elements_per_simd;
} else if (size < n_reads * 2048) {
thread_group_size =
((size / 2 + elements_per_simd - 1) / elements_per_simd) *
elements_per_simd;
}
thread_group_size = std::min(
thread_group_size,
static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));
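// e.g. (illustrative) a float32 scan over size = 1000 uses n_reads = 4 and
// elements_per_simd = 128, so thread_group_size rounds up to 1024 (capped by
// the pipeline maximum) and a single threadgroup covers the whole scanned
// axis.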
MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
kname << "strided_scan_";
if (reverse_) {
kname << "reverse_";
}
kname << ((inclusive_) ? "inclusive_" : "exclusive_");
switch (reduce_type_) {
case Scan::Sum:
kname << "sum_";
break;
case Scan::Prod:
kname << "prod_";
break;
case Scan::Max:
kname << "max_";
break;
case Scan::Min:
kname << "min_";
break;
}
kname << type_to_name(in) << "_" << type_to_name(out);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
size_t size = in.shape(axis_);
size_t stride = in.strides()[axis_];
compute_encoder->setBytes(&size, sizeof(size_t), 2);
compute_encoder->setBytes(&stride, sizeof(size_t), 3);
// Compute the thread grid
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
int tile_x = 32;
int tile_y = 32;
int elements_per_tile_x = tile_x * n_reads;
int grid_y = in.size() / size / stride;
int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
if (copies.size() > 0) {
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
}
} // namespace mlx::core

167
mlx/backend/metal/utils.h Normal file
View File

@@ -0,0 +1,167 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
namespace mlx::core {
namespace {
void set_array_buffer(
MTL::ComputeCommandEncoder* compute_encoder,
MTL::ArgumentEncoder* enc,
const array& a,
int idx) {
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
enc->setBuffer(a_buf, offset, idx);
// MTL::Resource usage through argument buffer needs to be explicitly
// flagged to enable hazard tracking
compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
}
void set_array_buffer(
MTL::ComputeCommandEncoder* enc,
const array& a,
int idx) {
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
enc->setBuffer(a_buf, offset, idx);
}
std::string type_to_name(const array& a) {
std::string tname;
switch (a.dtype()) {
case bool_:
tname = "bool_";
break;
case uint8:
tname = "uint8";
break;
case uint16:
tname = "uint16";
break;
case uint32:
tname = "uint32";
break;
case uint64:
tname = "uint64";
break;
case int8:
tname = "int8";
break;
case int16:
tname = "int16";
break;
case int32:
tname = "int32";
break;
case int64:
tname = "int64";
break;
case float16:
tname = "float16";
break;
case float32:
tname = "float32";
break;
case bfloat16:
tname = "bfloat16";
break;
case complex64:
tname = "complex64";
break;
}
return tname;
}
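// Pick power-of-two block dims whose product is at most 1024 (2^10) threads,
// growing dim0, dim1 and dim2 round-robin while each still fits; e.g.
// (illustrative) dim0 = 100, dim1 = 100, dim2 = 1 yields {32, 32, 1}.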
MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
int pows[3] = {0, 0, 0};
int sum = 0;
while (true) {
int presum = sum;
// Check all the pows
if (dim0 >= (1 << (pows[0] + 1))) {
pows[0]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim1 >= (1 << (pows[1] + 1))) {
pows[1]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim2 >= (1 << (pows[2] + 1))) {
pows[2]++;
sum++;
}
if (sum == presum || sum == 10) {
break;
}
}
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}
// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
if (xs[0].ndim() > 0) {
to_collapse.push_back(0);
for (int i = 1; i < xs[0].ndim(); i++) {
bool contiguous = true;
for (auto& x : xs) {
if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) {
contiguous = false;
}
if (!contiguous) {
break;
}
}
if (!contiguous) {
to_collapse.push_back(-1);
}
to_collapse.push_back(i);
}
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<size_t>> out_strides(xs.size());
for (int i = 0; i < to_collapse.size(); i++) {
int current_shape = xs[0].shape()[to_collapse[i]];
while (to_collapse[++i] != -1) {
current_shape *= xs[0].shape()[to_collapse[i]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < xs.size(); j++) {
out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]);
}
}
return std::make_tuple(out_shape, out_strides);
}
template <typename... Arrays>
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(Arrays... xs) {
return collapse_contiguous_dims(
std::vector<array>{std::forward<Arrays>(xs)...});
}
} // namespace
} // namespace mlx::core