Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-22 02:58:16 +08:00)

Commit: awni's commit files
26  mlx/backend/metal/CMakeLists.txt  Normal file
@@ -0,0 +1,26 @@
target_sources(
  mlx
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
)

if (NOT MLX_METAL_PATH)
  set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
endif()

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)

target_compile_definitions(
  mlx PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")
113  mlx/backend/metal/copy.cpp  Normal file
@@ -0,0 +1,113 @@
#include <sstream>

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
  if (ctype == CopyType::Vector) {
    out.set_data(
        allocator::malloc_or_wait(in.data_size() * out.itemsize()),
        in.data_size(),
        in.strides(),
        in.flags());
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }
  if (ctype == CopyType::GeneralGeneral) {
    ctype = CopyType::General;
  }
  copy_gpu_inplace(in, out, ctype, s);
}

void copy_gpu(const array& in, array& out, CopyType ctype) {
  copy_gpu(in, out, ctype, out.primitive().stream());
}

void copy_gpu_inplace(
    const array& in,
    array& out,
    CopyType ctype,
    const Stream& s) {
  // Try to collapse contiguous dims
  auto [shape, strides] = collapse_contiguous_dims(in, out);
  auto& strides_in = strides[0];
  auto& strides_out = strides[1];

  auto& d = metal::device(s.device);
  std::ostringstream kname;
  switch (ctype) {
    case CopyType::Scalar:
      kname << "scopy";
      break;
    case CopyType::Vector:
      kname << "vcopy";
      break;
    case CopyType::General:
      kname << "gcopy";
      break;
    case CopyType::GeneralGeneral:
      kname << "ggcopy";
      break;
  }
  kname << type_to_name(in) << type_to_name(out);
  if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
      shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
    kname << "_" << shape.size();
  }
  auto kernel = d.get_kernel(kname.str());
  auto compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, in, 0);
  set_array_buffer(compute_encoder, out, 1);

  if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
    size_t ndim = shape.size();
    if (ndim > 3) {
      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
      compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 3);
      if (ctype == CopyType::GeneralGeneral) {
        compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 4);
      }
    } else {
      // The shape is implicit in the grid for <= 3D
      compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 2);
      if (ctype == CopyType::GeneralGeneral) {
        compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 3);
      }
    }

    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
      compute_encoder->setBytes(
          &ndim, sizeof(int), (ctype == CopyType::GeneralGeneral) ? 5 : 4);
    }

    int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
    int rest = in.size() / (dim0 * dim1);

    // NB assuming thread_group_size is a power of 2 larger than 32 x 32
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size != 1024) {
      throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
    }
    auto group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  } else {
    size_t nthreads = out.data_size();
    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size > nthreads) {
      thread_group_size = nthreads;
    }
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
}

} // namespace mlx::core
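For readers without a Metal background, the sketch below (standard C++, not part of this commit) mirrors what the gcopy/ggcopy kernels dispatched above compute: each element's flat index is unraveled into per-dimension coordinates and mapped through separate input and output strides. The function name and signature are illustrative only.

// Illustrative CPU reference for the strided "general copy" the Metal kernels
// perform; assumes row-major iteration over `shape` with independent strides.
#include <cstddef>
#include <vector>

template <typename T>
void copy_general_cpu(
    const T* in,
    T* out,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides_in,
    const std::vector<size_t>& strides_out) {
  size_t size = 1;
  for (int d : shape) {
    size *= d;
  }
  for (size_t i = 0; i < size; ++i) {
    // Unravel the flat index into coordinates, then apply each stride set,
    // mirroring the elem_to_loc helpers used by the kernels.
    size_t rem = i, src = 0, dst = 0;
    for (int d = static_cast<int>(shape.size()) - 1; d >= 0; --d) {
      size_t coord = rem % shape[d];
      rem /= shape[d];
      src += coord * strides_in[d];
      dst += coord * strides_out[d];
    }
    out[dst] = in[src];
  }
}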
81  mlx/backend/metal/device.h  Normal file
@@ -0,0 +1,81 @@
#pragma once

#include <Metal/Metal.hpp>
#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>

#include <dlfcn.h>
#include <filesystem>

#include "mlx/device.h"

namespace fs = std::filesystem;

namespace mlx::core::metal {

inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
  Dl_info info;
  std::string mtllib_path;
  std::string lib_ext = lib_name + ".metallib";

  int success = dladdr((void*)get_colocated_mtllib_path, &info);
  if (success) {
    auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
    mtllib_path = mtllib.c_str();
  }

  return mtllib_path;
}

class Device {
 public:
  Device();
  Device(const Device&) = delete;
  Device& operator=(const Device&) = delete;
  ~Device();

  MTL::Device* mtl_device() {
    return device_;
  };

  void new_queue(int index);
  MTL::CommandBuffer* new_command_buffer(int index);
  MTL::CommandBuffer* get_command_buffer(int index);
  int get_command_buffer_ops(int index);
  void increment_command_buffer_ops(int index);
  void commit_command_buffer(int index);
  MTL::ComputeCommandEncoder* get_command_encoder(int index);
  void end_encoding(int index);

  void register_library(
      const std::string& lib_name,
      const std::string& lib_path);
  void register_library(
      const std::string& lib_name,
      const std::function<std::string(const std::string&)>& lib_path_func =
          get_colocated_mtllib_path);

  MTL::ComputePipelineState* get_kernel(
      const std::string& name,
      const std::string& lib_name = "mlx");

  MTL::ArgumentEncoder* argument_encoder(
      const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;

 private:
  NS::AutoreleasePool* pool_;
  MTL::Device* device_;
  std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
  std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
  std::unordered_map<int32_t, MTL::ComputeCommandEncoder*> encoder_map_;
  std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
  std::unordered_map<std::string, MTL::Library*> library_map_;
  std::mutex mtx_;
};

Device& device(mlx::core::Device);
NS::AutoreleasePool*& thread_autorelease_pool();

} // namespace mlx::core::metal
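get_colocated_mtllib_path above relies on dladdr to discover which binary the current function was loaded from, then looks for the .metallib next to it. A minimal standalone illustration of that technique (assuming a POSIX system; the anchor function and the "mlx.metallib" filename here are just for the example):

// Sketch of the dladdr trick: find the file that contains a known symbol and
// build a path relative to it. Link with -ldl on Linux (not needed on macOS).
#include <dlfcn.h>
#include <filesystem>
#include <iostream>

static void anchor() {}

int main() {
  Dl_info info;
  if (dladdr(reinterpret_cast<void*>(&anchor), &info) != 0) {
    auto dir = std::filesystem::path(info.dli_fname).remove_filename();
    std::cout << "colocated path: " << (dir / "mlx.metallib") << "\n";
  }
  return 0;
}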
296  mlx/backend/metal/indexing.cpp  Normal file
@@ -0,0 +1,296 @@
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>

#include "mlx/backend/common/binary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

namespace {

static constexpr int METAL_MAX_INDEX_ARRAYS = 10;

} // namespace

void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& src = inputs[0];
  int nidx = inputs.size() - 1;

  if (nidx > METAL_MAX_INDEX_ARRAYS) {
    std::ostringstream msg;
    msg << "[Gather::eval_gpu] Gathering with more than "
        << METAL_MAX_INDEX_ARRAYS << " index arrays not yet supported.";
    throw std::runtime_error(msg.str());
  }

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& s = stream();
  auto& d = metal::device(s.device);

  std::ostringstream kname;
  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
  kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;

  auto compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());

  size_t slice_size = 1;
  for (auto s : slice_sizes_) {
    slice_size *= s;
  }

  size_t ndim = src.ndim();
  size_t nthreads = out.size();
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  if (thread_group_size > nthreads) {
    thread_group_size = nthreads;
  }

  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);

  compute_encoder->setComputePipelineState(kernel);

  // Make the argument buffer to store the indices for the
  // `Indices` struct in kernels/indexing.metal
  std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
  arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
  arg_descs[0]->setIndex(0);
  arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
  arg_descs[0]->setArrayLength(nidx);

  // Shapes
  arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
  arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
  arg_descs[1]->setIndex(nidx + 1);

  // Strides
  arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
  arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
  arg_descs[2]->setIndex(nidx + 2);

  // Indices ndim
  arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
  arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
  arg_descs[3]->setIndex(nidx + 3);

  // Get the argument encoder
  auto arg_enc = d.argument_encoder(arg_descs);

  // Allocate and fill buffers for shapes and strides
  int idx_ndim = nidx ? inputs[1].ndim() : 0;
  auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
  auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
  for (int i = 0; i < nidx; ++i) {
    std::copy(
        inputs[i + 1].shape().begin(),
        inputs[i + 1].shape().end(),
        static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
    std::copy(
        inputs[i + 1].strides().begin(),
        inputs[i + 1].strides().end(),
        static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
  }

  // Allocate the argument buffer
  auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());

  // Register data with the encoder
  arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
  for (int i = 0; i < nidx; ++i) {
    set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
  }
  arg_enc->setBuffer(
      static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
  compute_encoder->useResource(
      static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), MTL::ResourceUsageRead);
  arg_enc->setBuffer(
      static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
  compute_encoder->useResource(
      static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), MTL::ResourceUsageRead);
  *static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;

  // Set all the buffers
  set_array_buffer(compute_encoder, src, 0);
  compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 1);
  set_array_buffer(compute_encoder, out, 2);
  compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3);
  compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4);
  compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
  compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6);
  compute_encoder->setBytes(&slice_size, sizeof(size_t), 7);
  compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 8);

  compute_encoder->dispatchThreads(grid_dims, group_dims);

  // Cleanup temporaries
  arg_enc->release();
  d.get_command_buffer(s.index)->addCompletedHandler(
      [arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
        allocator::free(arg_buf);
        allocator::free(idx_shapes_buf);
        allocator::free(idx_strides_buf);
      });
}

void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
  if (size_of(out.dtype()) == 8) {
    std::ostringstream msg;
    msg << "[Scatter::eval_gpu] Does not support " << out.dtype();
    throw std::invalid_argument(msg.str());
  }

  int nidx = axes_.size();
  if (nidx > METAL_MAX_INDEX_ARRAYS) {
    std::ostringstream msg;
    msg << "[Scatter::eval_gpu] Scattering with more than "
        << METAL_MAX_INDEX_ARRAYS << " index arrays not yet supported.";
    throw std::runtime_error(msg.str());
  }

  // Copy src into out
  auto copy_type =
      inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
  copy_gpu(inputs[0], out, copy_type);

  // Get stream
  auto& s = stream();
  auto& d = metal::device(s.device);

  // Get kernel name
  std::ostringstream kname;
  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
  kname << "scatter" << type_to_name(out) << idx_type_name;
  switch (reduce_type_) {
    case Scatter::None:
      kname << "_none";
      break;
    case Scatter::Sum:
      kname << "_sum";
      break;
    case Scatter::Prod:
      kname << "_prod";
      break;
    case Scatter::Max:
      kname << "_max";
      break;
    case Scatter::Min:
      kname << "_min";
      break;
  }
  kname << "_" << nidx;

  auto compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());

  auto& upd = inputs.back();
  size_t nthreads = upd.size();
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  if (thread_group_size > nthreads) {
    thread_group_size = nthreads;
  }

  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);

  compute_encoder->setComputePipelineState(kernel);

  // Make the argument buffer to store the indices for the
  // `Indices` struct in kernels/indexing.metal
  std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
  arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
  arg_descs[0]->setIndex(0);
  arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
  arg_descs[0]->setArrayLength(nidx);

  // Shapes
  arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
  arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
  arg_descs[1]->setIndex(nidx + 1);

  // Strides
  arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
  arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
  arg_descs[2]->setIndex(nidx + 2);

  // Indices ndim
  arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
  arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
  arg_descs[3]->setIndex(nidx + 3);

  // Get the argument encoder
  auto arg_enc = d.argument_encoder(arg_descs);

  // Allocate and fill buffers for shapes and strides
  int idx_ndim = nidx ? inputs[1].ndim() : 0;
  auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
  auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
  for (int i = 0; i < nidx; ++i) {
    std::copy(
        inputs[i + 1].shape().begin(),
        inputs[i + 1].shape().end(),
        static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
    std::copy(
        inputs[i + 1].strides().begin(),
        inputs[i + 1].strides().end(),
        static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
  }

  // Allocate the argument buffer
  auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());

  // Register data with the encoder
  arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
  for (int i = 0; i < nidx; ++i) {
    set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
  }
  arg_enc->setBuffer(
      static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
  compute_encoder->useResource(
      static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), MTL::ResourceUsageRead);
  arg_enc->setBuffer(
      static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
  compute_encoder->useResource(
      static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), MTL::ResourceUsageRead);
  *static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;

  compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 0);
  size_t upd_ndim = upd.ndim();
  size_t upd_size = 1;
  for (int i = idx_ndim; i < upd.ndim(); ++i) {
    upd_size *= upd.shape(i);
  }
  set_array_buffer(compute_encoder, upd, 1);
  set_array_buffer(compute_encoder, out, 2);
  compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
  compute_encoder->setBytes(upd.strides().data(), upd_ndim * sizeof(size_t), 4);
  compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
  compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);

  size_t out_ndim = out.ndim();
  compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
  compute_encoder->setBytes(out.strides().data(), out_ndim * sizeof(size_t), 8);
  compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
  compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);

  compute_encoder->dispatchThreads(grid_dims, group_dims);

  // Cleanup temporaries
  arg_enc->release();
  d.get_command_buffer(s.index)->addCompletedHandler(
      [arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
        allocator::free(arg_buf);
        allocator::free(idx_shapes_buf);
        allocator::free(idx_strides_buf);
      });
}

} // namespace mlx::core
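As a plain-C++ reference point (not MLX code), the simplest case the gather kernel above handles looks like this: one index array selecting rows of a 2-D source with a slice size equal to the row length. The kernel generalizes this to several index arrays, arbitrary axes, and arbitrary slice shapes; the helper below is purely illustrative.

// Illustrative CPU gather: out[i, :] = src[indices[i], :].
#include <cstddef>
#include <vector>

template <typename T, typename IdxT>
std::vector<T> gather_rows(
    const std::vector<T>& src,
    size_t row_len,
    const std::vector<IdxT>& indices) {
  std::vector<T> out(indices.size() * row_len);
  for (size_t i = 0; i < indices.size(); ++i) {
    // Each output slice is a contiguous copy of the selected source row.
    for (size_t j = 0; j < row_len; ++j) {
      out[i * row_len + j] = src[static_cast<size_t>(indices[i]) * row_len + j];
    }
  }
  return out;
}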
320  mlx/backend/metal/kernels/atomic.h  Normal file
@@ -0,0 +1,320 @@
#pragma once

#include <metal_atomic>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Atomic utils
///////////////////////////////////////////////////////////////////////////////

#pragma METAL internals : enable
template <typename T>
constexpr constant bool is_metal_atomic = _disjunction<
    is_same<T, int>,
    is_same<T, uint>,
    is_same<T, ulong>,
    is_same<T, float>>::value;

#pragma METAL internals : disable

template <typename T, typename = void>
struct mlx_atomic {
  atomic<uint> val;
};

template <typename T>
struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
  atomic<T> val;
};

///////////////////////////////////////////////////////////////////////////////
// Native metal atomics
///////////////////////////////////////////////////////////////////////////////

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
  return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, int offset) {
  atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
  atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
  atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_min_explicit(device mlx_atomic<T>* object, T val, int offset) {
  atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_max_explicit(device mlx_atomic<T>* object, T val, int offset) {
  atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_add_explicit(device mlx_atomic<T>* object, T val, int offset) {
  atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_mul_explicit(device mlx_atomic<T>* object, T val, int offset) {
  T expected = mlx_atomic_load_explicit(object, offset);
  while (!mlx_atomic_compare_exchange_weak_explicit(
      object, &expected, val * expected, offset)) {
  }
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
    device mlx_atomic<T>* object,
    thread T* expected,
    T val,
    int offset) {
  return atomic_compare_exchange_weak_explicit(
      &(object[offset].val),
      expected,
      val,
      memory_order_relaxed,
      memory_order_relaxed);
}

// Specialization for float since it does not have atomic_fetch_min_explicit
template <>
METAL_FUNC void mlx_atomic_fetch_min_explicit<float>(
    device mlx_atomic<float>* object,
    float val,
    int offset) {
  float expected = mlx_atomic_load_explicit(object, offset);
  while (val < expected) {
    if (mlx_atomic_compare_exchange_weak_explicit(
            object, &expected, val, offset)) {
      return;
    }
  }
}

// Specialization for float since it does not have atomic_fetch_max_explicit
template <>
METAL_FUNC void mlx_atomic_fetch_max_explicit<float>(
    device mlx_atomic<float>* object,
    float val,
    int offset) {
  float expected = mlx_atomic_load_explicit(object, offset);
  while (val > expected) {
    if (mlx_atomic_compare_exchange_weak_explicit(
            object, &expected, val, offset)) {
      return;
    }
  }
}

///////////////////////////////////////////////////////////////////////////////
// Custom atomics
///////////////////////////////////////////////////////////////////////////////

namespace {

template <typename T>
constexpr constant uint packing_size = sizeof(uint) / sizeof(T);

template <typename T>
union uint_or_packed {
  T val[packing_size<T>];
  uint bits;
};

template <typename T, typename Op>
struct mlx_atomic_update_helper {
  uint operator()(uint_or_packed<T> init, T update, int elem_offset) {
    Op op;
    init.val[elem_offset] = op(update, init.val[elem_offset]);
    return init.bits;
  }
};

template <typename T, typename Op>
METAL_FUNC void mlx_atomic_update_and_store(
    device mlx_atomic<T>* object,
    T update,
    int offset) {
  int pack_offset = offset / packing_size<T>;
  int elem_offset = offset % packing_size<T>;

  mlx_atomic_update_helper<T, Op> helper;
  uint_or_packed<T> expected;
  expected.bits =
      atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);

  while (Op::condition(update, expected.val[elem_offset]) &&
         !mlx_atomic_compare_exchange_weak_explicit(
             object,
             &(expected.bits),
             helper(expected, update, elem_offset),
             pack_offset)) {
  }
}

template <typename T>
struct __None {
  static bool condition(T a, T b) {
#pragma unused(a)
#pragma unused(b)
    return true;
  }

  T operator()(T a, T b) {
#pragma unused(b)
    return a;
  }
};

template <typename T>
struct __Add {
  static bool condition(T a, T b) {
#pragma unused(a)
#pragma unused(b)
    return true;
  }

  T operator()(T a, T b) {
    return a + b;
  }
};

template <typename T>
struct __Mul {
  static bool condition(T a, T b) {
#pragma unused(a)
    return b != 0;
  }

  T operator()(T a, T b) {
    return a * b;
  }
};

template <typename T>
struct __Max {
  static bool condition(T a, T b) {
    return a > b;
  }

  T operator()(T a, T b) {
    return max(a, b);
  }
};

template <typename T>
struct __Min {
  static bool condition(T a, T b) {
    return a < b;
  }

  T operator()(T a, T b) {
    return min(a, b);
  }
};

} // namespace

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
  int pack_offset = offset / sizeof(T);
  int elem_offset = offset % sizeof(T);
  uint_or_packed<T> packed_val;
  packed_val.bits =
      atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
  return packed_val.val[elem_offset];
}

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, int offset) {
  mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
}

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
  int pack_offset = offset / packing_size<T>;
  int elem_offset = offset % packing_size<T>;
  uint_or_packed<T> identity;
  identity.bits = __UINT32_MAX__;
  identity.val[elem_offset] = val;

  atomic_fetch_and_explicit(
      &(object[pack_offset].val), identity.bits, memory_order_relaxed);
}

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
  int pack_offset = offset / packing_size<T>;
  int elem_offset = offset % packing_size<T>;
  uint_or_packed<T> identity;
  identity.bits = 0;
  identity.val[elem_offset] = val;

  atomic_fetch_or_explicit(
      &(object[pack_offset].val), identity.bits, memory_order_relaxed);
}

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_min_explicit(device mlx_atomic<T>* object, T val, int offset) {
  mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
}

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_max_explicit(device mlx_atomic<T>* object, T val, int offset) {
  mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
}

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_add_explicit(device mlx_atomic<T>* object, T val, int offset) {
  mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
}

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_mul_explicit(device mlx_atomic<T>* object, T val, int offset) {
  mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
}

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
    device mlx_atomic<T>* object,
    thread uint* expected,
    uint val,
    int offset) {
  return atomic_compare_exchange_weak_explicit(
      &(object[offset].val),
      expected,
      val,
      memory_order_relaxed,
      memory_order_relaxed);
}
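The "custom atomics" above emulate atomics on sub-32-bit types (for example bfloat16 or int8) by packing them into a 32-bit word and retrying a compare-exchange until the update lands. The same idea in standard C++, shown for an atomic max on one byte lane of a std::atomic<uint32_t>, is sketched below; it is illustrative only and not MLX code.

// Emulated per-lane atomic max over a packed 32-bit word via CAS retry.
#include <atomic>
#include <cstdint>
#include <cstring>

void atomic_fetch_max_u8(std::atomic<uint32_t>& word, uint8_t val, int lane) {
  uint32_t expected = word.load(std::memory_order_relaxed);
  for (;;) {
    uint8_t lanes[4];
    std::memcpy(lanes, &expected, 4);      // unpack the four 8-bit lanes
    if (val <= lanes[lane]) {
      return;                              // nothing to do (Op::condition)
    }
    lanes[lane] = val;                     // update just this lane
    uint32_t desired;
    std::memcpy(&desired, lanes, 4);
    if (word.compare_exchange_weak(
            expected, desired, std::memory_order_relaxed)) {
      return;                              // CAS succeeded
    }
    // CAS failed: `expected` now holds the current value; retry.
  }
}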
315  mlx/backend/metal/kernels/bf16.h  Normal file
@@ -0,0 +1,315 @@
#pragma once

#include <metal_stdlib>

using namespace metal;

#if defined(__HAVE_BFLOAT__)

typedef bfloat bfloat16_t;

#else

/////////////////////////////////////////////////////////////////////////////
// Helpers
/////////////////////////////////////////////////////////////////////////////

constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
  // Check for nan
  if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
      _fp_encoding_traits<float>::inf_mask) {
    return uint16_t(as_type<uint32_t>(0x7FC0));
  }
  // Take bits
  uint32_t float_bits = as_type<uint32_t>(x);

  // Round to nearest even
  float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);

  // Take upper 16 bits
  return float_bits >> 16;
}

constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
  // Upper 16 bits are the data and lower 16 bits are 0s
  return as_type<float>((uint32_t)x << 16);
}

struct _MLX_BFloat16;

template <typename T>
static constexpr constant bool can_convert_to_bfloat =
    !is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;

template <typename T>
static constexpr constant bool can_convert_from_bfloat =
    !is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;

/////////////////////////////////////////////////////////////////////////////
// Bfloat struct
/////////////////////////////////////////////////////////////////////////////

struct _MLX_BFloat16 {
  /////////////////////////////////////////////////////////////////////////////
  // Constructors
  uint16_t bits_;
  _MLX_BFloat16() thread = default;
  _MLX_BFloat16() threadgroup = default;
  _MLX_BFloat16() device = default;
  _MLX_BFloat16() constant = default;

  struct bits_to_bfloat_struct {};
  static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
    return bits_to_bfloat_struct();
  }
  constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
      : bits_(bits) {}

  /////////////////////////////////////////////////////////////////////////////
  // Conversions to bfloat

  template <
      typename T,
      typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) thread
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) device
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) constant
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  /////////////////////////////////////////////////////////////////////////////
  // Conversions from bfloat

  template <
      typename T,
      typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const thread {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const threadgroup {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const device {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const constant {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }
};

/////////////////////////////////////////////////////////////////////////////
// Bfloat operators
/////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////
// Unary ops
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
  return -static_cast<float>(x);
}

/////////////////////////////////////////////////////////////////////////////
// Binary operators
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
  constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) {           \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);          \
  }

#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype)    \
  constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);        \
  }                                                                       \
  constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);        \
  }

/////////////////////////////////////////////////////////////////////////////
// Arithmetic Operators
#define bfloat_binop(_op_, _operator_)                                       \
  bfloat_binop_base(                                                         \
      _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
  bfloat_binop_helper(_op_, _operator_, float, float, float);                \
  bfloat_binop_helper(_op_, _operator_, float, half, float);                 \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float);      \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float);     \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float);      \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);

bfloat_binop(+, operator+);
bfloat_binop(-, operator-);
bfloat_binop(*, operator*);
bfloat_binop(/, operator/);

/////////////////////////////////////////////////////////////////////////////
// Comparison ops
#define bfloat_compop(__op__, __operator__)                             \
  bfloat_binop_base(                                                    \
      __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
  bfloat_binop_helper(__op__, __operator__, bool, float, float);        \
  bfloat_binop_helper(__op__, __operator__, bool, half, float);         \
  bfloat_binop_helper(__op__, __operator__, bool, int32_t, float);      \
  bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float);     \
  bfloat_binop_helper(__op__, __operator__, bool, int64_t, float);      \
  bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);

bfloat_compop(>, operator>);
bfloat_compop(<, operator<);
bfloat_compop(>=, operator>=);
bfloat_compop(<=, operator<=);
bfloat_compop(==, operator==);
bfloat_compop(!=, operator!=);

#undef bfloat_compop
#undef bfloat_binop_base
#undef bfloat_binop_helper
#undef bfloat_binop

/////////////////////////////////////////////////////////////////////////////
// Inplace Operators
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
  constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__(            \
      addr_space _MLX_BFloat16& lhs, itype rhs) {                         \
    lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs);         \
    return lhs;                                                           \
  }                                                                       \
  constexpr METAL_FUNC addr_space itype& __operator__(                    \
      addr_space itype& lhs, _MLX_BFloat16 rhs) {                         \
    lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs);         \
    return lhs;                                                           \
  }

#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
  bfloat_inplace_op_helper(__op__, __operator__, itype, device);         \
  bfloat_inplace_op_helper(__op__, __operator__, itype, thread);         \
  bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);

#define bfloat_inplace_op(itype)                             \
  bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
  bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
  bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
  bfloat_inplace_op_addr_space_helper(/, operator/=, itype);

bfloat_inplace_op(float);
bfloat_inplace_op(half);
bfloat_inplace_op(int16_t);
bfloat_inplace_op(int32_t);
bfloat_inplace_op(int64_t);
bfloat_inplace_op(uint16_t);
bfloat_inplace_op(uint32_t);
bfloat_inplace_op(uint64_t);

#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
#undef bfloat_inplace_op

#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
  constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__(     \
      addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) {          \
    lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs);  \
    return lhs;                                                    \
  }

#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
  bfloat_inplace_op_helper(__op__, __operator__, device);         \
  bfloat_inplace_op_helper(__op__, __operator__, thread);         \
  bfloat_inplace_op_helper(__op__, __operator__, threadgroup);

bfloat_inplace_op_addr_space_helper(+, operator+=);
bfloat_inplace_op_addr_space_helper(-, operator-=);
bfloat_inplace_op_addr_space_helper(*, operator*=);
bfloat_inplace_op_addr_space_helper(/, operator/=);

#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper

/////////////////////////////////////////////////////////////////////////////
// Bfloat typedef
/////////////////////////////////////////////////////////////////////////////

typedef struct _MLX_BFloat16 bfloat16_t;

/////////////////////////////////////////////////////////////////////////////
// Bfloat numeric limits
/////////////////////////////////////////////////////////////////////////////

#pragma METAL internals : enable

namespace metal {

template <>
struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
  static constexpr constant int digits = 8;
  static constexpr constant int digits10 = 2;
  static constexpr constant int max_digits10 = 4;
  static constexpr constant int radix = 2;
  static constexpr constant int min_exponent = -125;
  static constexpr constant int min_exponent10 = -37;
  static constexpr constant int max_exponent = 128;
  static constexpr constant int max_exponent10 = 38;

  static constexpr bfloat16_t min() {
    return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat());
  }
  static constexpr bfloat16_t lowest() {
    return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat());
  }
  static constexpr bfloat16_t max() {
    return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat());
  }
  static constexpr bfloat16_t epsilon() {
    return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat());
  }
  static constexpr bfloat16_t round_error() {
    return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat());
  }
  static constexpr bfloat16_t infinity() {
    return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
  }
  static constexpr bfloat16_t quiet_NaN() {
    return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat());
  }
  static constexpr bfloat16_t signaling_NaN() {
    return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
  }
  static constexpr bfloat16_t denorm_min() {
    return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat());
  }
};

METAL_FUNC bool isnan(_MLX_BFloat16 x) {
  return x != x;
}

} // namespace metal

#pragma METAL internals : disable

#endif // defined(__HAVE_BFLOAT__)

#include "mlx/backend/metal/kernels/bf16_math.h"
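When the hardware bfloat type is unavailable, the header stores bfloat16 as the upper 16 bits of a float and converts with round-to-nearest-even. The same bit manipulation in standard C++ is shown below for reference; it is illustrative and not part of the commit.

// Truncate a float to bfloat16 bits with round-to-nearest-even, and widen
// back by appending 16 zero bits, mirroring float_to_bfloat_bits above.
#include <cstdint>
#include <cstring>

uint16_t float_to_bf16_bits(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, 4);
  if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
    return 0x7FC0;                         // quiet NaN
  }
  bits += ((bits >> 16) & 1u) + 0x7FFFu;   // round to nearest even
  return static_cast<uint16_t>(bits >> 16);
}

float bf16_bits_to_float(uint16_t x) {
  uint32_t bits = static_cast<uint32_t>(x) << 16;
  float out;
  std::memcpy(&out, &bits, 4);
  return out;
}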
369  mlx/backend/metal/kernels/binary.metal  Normal file
@@ -0,0 +1,369 @@
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
struct Add {
|
||||
template <typename T> T operator()(T x, T y) { return x + y; }
|
||||
};
|
||||
|
||||
struct Divide {
|
||||
template <typename T> T operator()(T x, T y) { return x / y; }
|
||||
};
|
||||
|
||||
struct Equal {
|
||||
template <typename T> bool operator()(T x, T y) { return x == y; }
|
||||
};
|
||||
|
||||
struct NaNEqual {
|
||||
template <typename T> bool operator()(T x, T y) {
|
||||
return x == y || (metal::isnan(x) && metal::isnan(y));
|
||||
}
|
||||
template <>
|
||||
bool operator()(complex64_t x, complex64_t y) {
|
||||
return x == y ||
|
||||
(metal::isnan(x.real) && metal::isnan(y.real)
|
||||
&& metal::isnan(x.imag) && metal::isnan(y.imag)) ||
|
||||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
|
||||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
|
||||
}
|
||||
};
|
||||
|
||||
struct Greater {
|
||||
template <typename T> bool operator()(T x, T y) { return x > y; }
|
||||
};
|
||||
|
||||
struct GreaterEqual {
|
||||
template <typename T> bool operator()(T x, T y) { return x >= y; }
|
||||
};
|
||||
|
||||
struct Less {
|
||||
template <typename T> bool operator()(T x, T y) { return x < y; }
|
||||
};
|
||||
|
||||
struct LessEqual {
|
||||
template <typename T> bool operator()(T x, T y) { return x <= y; }
|
||||
};
|
||||
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
constexpr T inf = metal::numeric_limits<T>::infinity();
|
||||
T maxval = metal::max(x, y);
|
||||
T minval = metal::min(x, y);
|
||||
return (minval == -inf || maxval == inf) ? maxval :
|
||||
(maxval + log1p(metal::exp(minval - maxval)));
|
||||
};
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
template <typename T> T operator()(T x, T y) { return metal::max(x, y); }
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
return x >= y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Minimum {
|
||||
template <typename T> T operator()(T x, T y) { return metal::min(x, y); }
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
return x <= y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
template <typename T> T operator()(T x, T y) { return x * y; }
|
||||
};
|
||||
|
||||
struct NotEqual {
|
||||
template <typename T> bool operator()(T x, T y) { return x != y; }
|
||||
template <>
|
||||
bool operator()(complex64_t x, complex64_t y) {
|
||||
return x.real != y.real || x.imag != y.imag;
|
||||
}
|
||||
};
|
||||
|
||||
struct Power {
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
return metal::pow(base, exp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
T res = 1;
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
auto x_theta = metal::atan(x.imag / x.real);
|
||||
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
|
||||
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
|
||||
auto phase = y.imag * x_ln_r + y.real * x_theta;
|
||||
return {mag * metal::cos(phase), mag * metal::sin(phase)};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct Subtract {
|
||||
template <typename T> T operator()(T x, T y) { return x - y; }
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_s2s(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[0], b[0]);
|
||||
}
|
||||
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_ss(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[0], b[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_sv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[0], b[index]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_vs(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[index], b[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_vv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[index], b[index]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_g_nd1(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1(index, b_stride);
|
||||
c[index] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_g_nd2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_2(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2(index, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_g_nd3(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int DIM>
|
||||
[[kernel]] void binary_op_g_nd(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const int shape[DIM],
|
||||
constant const size_t a_strides[DIM],
|
||||
constant const size_t b_strides[DIM],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_g(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
|
||||
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
|
||||
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
||||
}
|
||||
|
||||
#define instantiate_binary(name, itype, otype, op, bopt) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void binary_op_##bopt<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
|
||||
template [[host_name(name "_" #dims)]] \
|
||||
[[kernel]] void binary_op_g_nd<itype, otype, op, dims>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_binary_g_nd(name, itype, otype, op) \
|
||||
template [[host_name(name "_1")]] \
|
||||
[[kernel]] void binary_op_g_nd1<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t& a_stride, \
|
||||
constant const size_t& b_stride, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] \
|
||||
  [[kernel]] void binary_op_g_nd2<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      constant const size_t a_strides[2], \
      constant const size_t b_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
  template [[host_name(name "_3")]] \
  [[kernel]] void binary_op_g_nd3<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      constant const size_t a_strides[3], \
      constant const size_t b_strides[3], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  instantiate_binary_g_dim(name, itype, otype, op, 4) \
  instantiate_binary_g_dim(name, itype, otype, op, 5)

#define instantiate_binary_g(name, itype, otype, op) \
  template [[host_name(name)]] \
  [[kernel]] void binary_op_g<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      constant const int* shape, \
      constant const size_t* a_strides, \
      constant const size_t* b_strides, \
      constant const int& ndim, \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);

#define instantiate_binary_all(name, tname, itype, otype, op) \
  instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
  instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
  instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
  instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
  instantiate_binary_g("g" #name #tname, itype, otype, op) \
  instantiate_binary_g_nd("g" #name #tname, itype, otype, op)

#define instantiate_binary_float(name, op) \
  instantiate_binary_all(name, float16, half, half, op) \
  instantiate_binary_all(name, float32, float, float, op) \
  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)

#define instantiate_binary_types(name, op) \
  instantiate_binary_all(name, bool_, bool, bool, op) \
  instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
  instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
  instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
  instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
  instantiate_binary_all(name, int8, int8_t, int8_t, op) \
  instantiate_binary_all(name, int16, int16_t, int16_t, op) \
  instantiate_binary_all(name, int32, int32_t, int32_t, op) \
  instantiate_binary_all(name, int64, int64_t, int64_t, op) \
  instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
  instantiate_binary_float(name, op)

#define instantiate_binary_types_bool(name, op) \
  instantiate_binary_all(name, bool_, bool, bool, op) \
  instantiate_binary_all(name, uint8, uint8_t, bool, op) \
  instantiate_binary_all(name, uint16, uint16_t, bool, op) \
  instantiate_binary_all(name, uint32, uint32_t, bool, op) \
  instantiate_binary_all(name, uint64, uint64_t, bool, op) \
  instantiate_binary_all(name, int8, int8_t, bool, op) \
  instantiate_binary_all(name, int16, int16_t, bool, op) \
  instantiate_binary_all(name, int32, int32_t, bool, op) \
  instantiate_binary_all(name, int64, int64_t, bool, op) \
  instantiate_binary_all(name, float16, half, bool, op) \
  instantiate_binary_all(name, float32, float, bool, op) \
  instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
  instantiate_binary_all(name, complex64, complex64_t, bool, op)

instantiate_binary_types(add, Add)
instantiate_binary_float(div, Divide)
instantiate_binary_types_bool(eq, Equal)
instantiate_binary_types_bool(ge, Greater)
instantiate_binary_types_bool(geq, GreaterEqual)
instantiate_binary_types_bool(le, Less)
instantiate_binary_types_bool(leq, LessEqual)
instantiate_binary_types_bool(neq, NotEqual)
instantiate_binary_float(lae, LogAddExp)
instantiate_binary_types(max, Maximum)
instantiate_binary_types(min, Minimum)
instantiate_binary_types(mul, Multiply)
instantiate_binary_types(sub, Subtract)
instantiate_binary_types(pow, Power)

// NaNEqual only needed for floating point types with boolean output
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
instantiate_binary_all(naneq, float32, float, bool, NaNEqual)
instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
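A rough host-side sketch of how one of these instantiations gets selected: the kernel name is simply the mode prefix ("ss"/"sv"/"vs"/"vv"/"g"), the op name, and the dtype name concatenated, mirroring instantiate_binary_all above. The variable names below are assumptions; the real dispatch lives in the host binary-op code.

    // Hypothetical sketch: pick the vector-vector float32 add kernel.
    std::ostringstream kname;
    kname << "vv" << "add" << "float32"; // mode + op + dtype
    auto kernel = d.get_kernel(kname.str()); // assumes a metal::Device d is in scope
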
110
mlx/backend/metal/kernels/complex.h
Normal file
110
mlx/backend/metal/kernels/complex.h
Normal file
@@ -0,0 +1,110 @@
#pragma once

#include <metal_stdlib>

using namespace metal;

struct complex64_t;

template <typename T>
static constexpr constant bool can_convert_to_complex64 =
    !is_same_v<T, complex64_t> && is_convertible_v<T, float>;

template <typename T>
static constexpr constant bool can_convert_from_complex64 =
    !is_same_v<T, complex64_t> &&
    (is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);

struct complex64_t {
  float real;
  float imag;

  // Constructors
  constexpr complex64_t(float real, float imag) : real(real), imag(imag){};

  // Conversions to complex64_t
  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) thread : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) device : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) constant : real(x), imag(0) {}

  // Conversions from complex64_t
  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const thread {
    return static_cast<T>(real);
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const threadgroup {
    return static_cast<T>(real);
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const device {
    return static_cast<T>(real);
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const constant {
    return static_cast<T>(real);
  }
};

constexpr complex64_t operator-(complex64_t x) {
  return {-x.real, -x.imag};
}

constexpr bool operator>=(complex64_t a, complex64_t b) {
  return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
}

constexpr bool operator>(complex64_t a, complex64_t b) {
  return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
}

constexpr bool operator<=(complex64_t a, complex64_t b) {
  return operator>=(b, a);
}

constexpr bool operator<(complex64_t a, complex64_t b) {
  return operator>(b, a);
}

constexpr bool operator==(complex64_t a, complex64_t b) {
  return a.real == b.real && a.imag == b.imag;
}

constexpr complex64_t operator+(complex64_t a, complex64_t b) {
  return {a.real + b.real, a.imag + b.imag};
}

constexpr complex64_t operator-(complex64_t a, complex64_t b) {
  return {a.real - b.real, a.imag - b.imag};
}

constexpr complex64_t operator*(complex64_t a, complex64_t b) {
  return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
}
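Note that the comparison operators above order complex numbers lexicographically: real parts are compared first and imaginary parts only break ties, which is what the comparison and min/max kernels instantiated for complex64 rely on. For example:

    // complex64_t(1, 5) < complex64_t(2, 0) -> true  (1 < 2)
    // complex64_t(1, 5) < complex64_t(1, 7) -> true  (real parts tie, 5 < 7)
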
14
mlx/backend/metal/kernels/defines.h
Normal file
14
mlx/backend/metal/kernels/defines.h
Normal file
@@ -0,0 +1,14 @@
#pragma once

#ifdef __METAL__
#define MTL_CONST constant
#else
#define MTL_CONST
#endif

static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
479
mlx/backend/metal/kernels/gemm/conv.h
Normal file
479
mlx/backend/metal/kernels/gemm/conv.h
Normal file
@@ -0,0 +1,479 @@
#pragma once

#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/conv_params.h"

#define MLX_MTL_CONST static constant constexpr const

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int vec_size,
    int tgp_size,
    int tgp_padding = 0>
struct Conv2DInputBlockLoader {
  // Destination dimensions
  MLX_MTL_CONST int dst_fd = BM;
  MLX_MTL_CONST int dst_ld = BK + tgp_padding;
  MLX_MTL_CONST int n_vecs = BK / vec_size;

  // Stride along block row within the block
  MLX_MTL_CONST int bstride = tgp_size / n_vecs;
  MLX_MTL_CONST int n_rows = dst_fd / bstride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>& params;

  int weight_h;
  int weight_w;

  int offsets_n[n_rows];
  int offsets_oh[n_rows];
  int offsets_ow[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoader(
      const device T* src_,
      threadgroup T* dst_,
      const constant MLXConvParams<2>& params_,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / n_vecs),
        bj(vec_size * (thread_idx % n_vecs)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bj),
        params(params_),
        weight_h(0),
        weight_w(0) {
    int out_n_pixels = params.oS[0] * params.oS[1];

    for (int i = 0; i < n_rows; ++i) {
      int offset_nhw = tid.y * BM + bi + i * bstride;
      offsets_n[i] = offset_nhw / out_n_pixels;
      int hw = offset_nhw % out_n_pixels;
      offsets_oh[i] = hw / params.oS[1];
      offsets_ow[i] = hw % params.oS[1];
    }

    (void)lid;
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
#pragma clang loop unroll(full)
    for (short i = 0, is = 0; i < n_rows; ++i, is += bstride) {
      int n = offsets_n[i];
      int oh = offsets_oh[i];
      int ow = offsets_ow[i];

      int ih = oh * params.str[0] - params.pad[0] + weight_h * params.dil[0];
      int iw = ow * params.str[1] - params.pad[1] + weight_w * params.dil[1];

      // Read from input if in bounds
      if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
        const device T* curr_src = src + n * params.in_strides[0] +
            ih * params.in_strides[1] + iw * params.in_strides[2];

#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = curr_src[j];
        }
      }

      // Zero pad otherwise
      else {
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    if (++weight_w < params.wS[1]) {
      return;
    }

    weight_w = 0;

    if (++weight_h < params.wS[0]) {
      return;
    }

    weight_h = 0;

    src += BK;
  }
};

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int vec_size,
    int tgp_size,
    int tgp_padding = 0>
struct Conv2DWeightBlockLoader {
  // Destination dimensions
  MLX_MTL_CONST int dst_fd = BN;
  MLX_MTL_CONST int dst_ld = BK + tgp_padding;
  MLX_MTL_CONST int n_vecs = BK / vec_size;

  // Stride along block row within the block
  MLX_MTL_CONST int bstride = tgp_size / n_vecs;
  MLX_MTL_CONST int n_rows = dst_fd / bstride;

  // Leading dimension for src
  const int src_ld;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>& params;

  int weight_h;
  int weight_w;

  /* Constructor */
  METAL_FUNC Conv2DWeightBlockLoader(
      const device T* src_,
      threadgroup T* dst_,
      const constant MLXConvParams<2>& params_,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_.wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / n_vecs),
        bj(vec_size * (thread_idx % n_vecs)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj),
        params(params_),
        weight_h(0),
        weight_w(0) {
    (void)lid;
    (void)tid;
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    const device T* curr_src =
        src + weight_h * params.wt_strides[1] + weight_w * params.wt_strides[2];
#pragma clang loop unroll(full)
    for (short i = 0; i < dst_fd; i += bstride) {
#pragma clang loop unroll(full)
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = curr_src[i * src_ld + j];
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    if (++weight_w < params.wS[1]) {
      return;
    }

    weight_w = 0;

    if (++weight_h < params.wS[0]) {
      return;
    }

    weight_h = 0;

    src += BK;
  }
};

///////////////////////////////////////////////////////////////////////////////
// Transforms
///////////////////////////////////////////////////////////////////////////////

template <typename OutT, typename InT>
struct TransformNone {
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }
};

template <typename T>
struct AccumHelper {
  typedef float accum_type;
};

///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    int tgp_padding_a = 0,
    int tgp_padding_b = 0,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DBlockMMA {
  // Warp tile size along M
  MLX_MTL_CONST int TM = BM / (WM * 8);
  // Warp tile size along N
  MLX_MTL_CONST int TN = BN / (WN * 8);

  // Warp tile simdgroup matrix strides along M
  MLX_MTL_CONST int TM_stride = 8 * WM;
  // Warp tile simdgroup matrix strides along N
  MLX_MTL_CONST int TN_stride = 8 * WN;

  // Leading dimensions of threadgroup A, B blocks
  MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
  MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;

  // Strides of A, B along reduction axis
  MLX_MTL_CONST short simd_stride_a =
      transpose_a ? TM_stride : TM_stride * lda_tgp;
  MLX_MTL_CONST short simd_stride_b =
      transpose_b ? TN_stride * ldb_tgp : TN_stride;

  // Jump between elements
  MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
  MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;

  // Offsets within threadgroup
  const int tm;
  const int tn;

  // Simdgroup matrices
  simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
  simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
  simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
      simdgroup_matrix<AccumType, 8, 8>(0)};

  short sm;
  short sn;

  /* Constructor */
  METAL_FUNC Conv2DBlockMMA(
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
    short qid = simd_lane_id / 4;
    sm = (qid & 4) + (simd_lane_id / 2) % 4;
    sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
  }

  /* (BM, BK) X (BK, BN) multiply accumulate function */
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
    // Iterate over BK in blocks of 8
#pragma clang loop unroll(full)
    for (short kk = 0; kk < BK; kk += 8) {
      short2 offset_a =
          transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
      short2 offset_b =
          transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);

      const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
      const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;

      simdgroup_barrier(mem_flags::mem_none);
      // Load elements from threadgroup A as simdgroup matrices
#pragma clang loop unroll(full)
      for (short i = 0; i < TM; i++) {
        Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
        Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
        As__ += simd_stride_a;
      }

      simdgroup_barrier(mem_flags::mem_none);
      // Load elements from threadgroup B as simdgroup matrices
#pragma clang loop unroll(full)
      for (short j = 0; j < TN; j++) {
        Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
        Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
        Bs__ += simd_stride_b;
      }

      simdgroup_barrier(mem_flags::mem_none);
      // Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
      for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
        for (short j = 0; j < TN; j++) {
          simdgroup_multiply_accumulate(
              results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
        }
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(device T* C, const int ldc) const {
#pragma clang loop unroll(full)
    for (int i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
      for (int j = 0; j < TN; j++) {
        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
            Epilogue::apply(results[i * TN + j].thread_elements()[0]);
        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
            Epilogue::apply(results[i * TN + j].thread_elements()[1]);
      }
    }
  }

  METAL_FUNC void
  store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
#pragma clang loop unroll(full)
    for (int i = 0; i < TM; i++) {
      if (tm + i * TM_stride + sm < dst_tile_dims.y) {
#pragma clang loop unroll(full)
        for (int j = 0; j < TN; j++) {
          if (tn + j * TN_stride + sn < dst_tile_dims.x) {
            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
                Epilogue::apply(results[i * TN + j].thread_elements()[0]);
          }

          if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
                Epilogue::apply(results[i * TN + j].thread_elements()[1]);
          }
        }
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DImplicitGEMMKernel {
  MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
  MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
  MLX_MTL_CONST short tgp_mem_size_a =
      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
  MLX_MTL_CONST short tgp_mem_size_b =
      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
  MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;

  MLX_MTL_CONST short tgp_size = WM * WN * 32;
  MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;

  using loader_a_t =
      Conv2DInputBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_a>;
  using loader_b_t =
      Conv2DWeightBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_b>;
  using mma_t = Conv2DBlockMMA<
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      tgp_padding_a,
      tgp_padding_b,
      AccumType,
      Epilogue>;

  /* Main kernel function */
  static METAL_FUNC void run(
      const device T* A [[buffer(0)]],
      const device T* B [[buffer(1)]],
      device T* C [[buffer(2)]],
      const constant MLXConvParams<2>& params [[buffer(3)]],
      threadgroup T* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {
    const int c_row = tid.y * BM;
    const int c_col = tid.x * BN;
    const int K = params.wt_strides[0];
    const int N = params.O;

    B += c_col * K;
    C += c_row * N + c_col;

    // Prepare threadgroup memory for loading
    threadgroup T* As = tgp_memory;
    threadgroup T* Bs = tgp_memory + tgp_mem_size_a;

    // Prepare threadgroup loading operations
    loader_a_t loader_a(A, As, params, tid, lid, simd_gid, simd_lid);
    loader_b_t loader_b(B, Bs, params, tid, lid, simd_gid, simd_lid);

    // Prepare threadgroup mma operation
    mma_t mma_op(simd_gid, simd_lid);

    for (int k = 0; k < K; k += BK) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Store results to device memory
    mma_op.store_result(C, N);
  }
};
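For orientation, the implicit GEMM formulation above maps the convolution onto a matrix multiply whose sizes can be read off MLXConvParams. This is an illustrative sketch assuming contiguous NHWC inputs and weights, not code from the host launcher:

    // Implicit GEMM problem size (illustrative)
    // M: one row per output pixel across the batch
    // N: one column per output channel
    // K: one reduction element per (weight_h, weight_w, in_channel)
    int M = params.N * params.oS[0] * params.oS[1];
    int N = params.O;
    int K = params.wt_strides[0]; // wS[0] * wS[1] * C for contiguous weights
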
536
mlx/backend/metal/kernels/gemm/gemm.h
Normal file
536
mlx/backend/metal/kernels/gemm/gemm.h
Normal file
@@ -0,0 +1,536 @@
#pragma once

#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

#define MLX_MTL_CONST static constant constexpr const

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BROWS,
    int BCOLS,
    int BK,
    int vec_size,
    int tgp_size,
    bool transpose,
    bool ldK,
    int tgp_padding = 0>
struct BlockLoader {
  // Destination dimensions
  MLX_MTL_CONST int dst_fd = transpose ? BCOLS : BROWS;
  MLX_MTL_CONST int dst_ld = (transpose ? BROWS : BCOLS) + tgp_padding;
  MLX_MTL_CONST int n_vecs = (transpose ? BROWS : BCOLS) / vec_size;

  // Stride along block row within the block
  MLX_MTL_CONST int bstride = tgp_size / n_vecs;

  // Leading dimension for src
  const int src_ld;
  // Stride along reduction axis between blocks
  const int tstride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  /* Constructor */
  METAL_FUNC BlockLoader(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tstride(
            BK * ((int)(transpose ^ !ldK) * src_ld + (int)(transpose ^ ldK))),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / n_vecs),
        bj(vec_size * (thread_idx % n_vecs)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj) {}

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
#pragma clang loop unroll(full)
    for (short i = 0; i < dst_fd; i += bstride) {
#pragma clang loop unroll(full)
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = src[i * src_ld + j];
      }
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    src_tile_dim = transpose ? src_tile_dim.yx : src_tile_dim.xy;

    // Iterate over rows of block
#pragma clang loop unroll(full)
    for (short i = 0; i < dst_fd; i += bstride) {
      // Row is in bounds, we check against column
      if ((bi + i) < src_tile_dim.y) {
        // Use fast thread memory for bound checks
        short tmp_idx[vec_size];
        T tmp_val[vec_size];

        // Make sure tmp_idx only contains valid indices
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; j++) {
          tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
        }

        // Read all valid indices into tmp_val
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; j++) {
          tmp_val[j] = src[i * src_ld + tmp_idx[j]];
        }

        // Zero out unneeded values
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; j++) {
          tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
        }

        // Copy values to threadgroup memory
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = tmp_val[j];
        }
      }

      // Row is out of bounds, we just fill tgp memory with zeros
      else {
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    src += tstride;
  }
};

///////////////////////////////////////////////////////////////////////////////
// Transforms
///////////////////////////////////////////////////////////////////////////////

template <typename OutT, typename InT>
struct TransformNone {
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }
};

template <typename T>
struct AccumHelper {
  typedef float accum_type;
};

///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    int tgp_padding_a = 0,
    int tgp_padding_b = 0,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<T, AccumType>>
struct BlockMMA {
  // Warp tile size along M
  MLX_MTL_CONST int TM = BM / (WM * 8);
  // Warp tile size along N
  MLX_MTL_CONST int TN = BN / (WN * 8);

  // Warp tile simdgroup matrix strides along M
  MLX_MTL_CONST int TM_stride = 8 * WM;
  // Warp tile simdgroup matrix strides along N
  MLX_MTL_CONST int TN_stride = 8 * WN;

  // Leading dimensions of threadgroup A, B blocks
  MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
  MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;

  // Strides of A, B along reduction axis
  MLX_MTL_CONST short simd_stride_a =
      transpose_a ? TM_stride : TM_stride * lda_tgp;
  MLX_MTL_CONST short simd_stride_b =
      transpose_b ? TN_stride * ldb_tgp : TN_stride;

  // Jump between elements
  MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
  MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;

  // Offsets within threadgroup
  const int tm;
  const int tn;

  // Simdgroup matrices
  simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
  simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
  simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
      simdgroup_matrix<AccumType, 8, 8>(0)};

  short sm;
  short sn;

  /* Constructor */
  METAL_FUNC BlockMMA(
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
    short qid = simd_lane_id / 4;
    sm = (qid & 4) + (simd_lane_id / 2) % 4;
    sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
  }

  /* (BM, BK) X (BK, BN) multiply accumulate function */
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
    // Iterate over BK in blocks of 8
#pragma clang loop unroll(full)
    for (short kk = 0; kk < BK; kk += 8) {
      short2 offset_a =
          transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
      short2 offset_b =
          transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);

      const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
      const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;

      simdgroup_barrier(mem_flags::mem_none);
      // Load elements from threadgroup A as simdgroup matrices
#pragma clang loop unroll(full)
      for (short i = 0; i < TM; i++) {
        Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
        Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
        As__ += simd_stride_a;
      }

      simdgroup_barrier(mem_flags::mem_none);
      // Load elements from threadgroup B as simdgroup matrices
#pragma clang loop unroll(full)
      for (short j = 0; j < TN; j++) {
        Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
        Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
        Bs__ += simd_stride_b;
      }

      simdgroup_barrier(mem_flags::mem_none);
      // Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
      for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
        for (short j = 0; j < TN; j++) {
          simdgroup_multiply_accumulate(
              results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
        }
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(device T* C, const int ldc) const {
#pragma clang loop unroll(full)
    for (int i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
      for (int j = 0; j < TN; j++) {
        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
            Epilogue::apply(results[i * TN + j].thread_elements()[0]);
        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
            Epilogue::apply(results[i * TN + j].thread_elements()[1]);
      }
    }
  }

  METAL_FUNC void
  store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
#pragma clang loop unroll(full)
    for (int i = 0; i < TM; i++) {
      if (tm + i * TM_stride + sm < dst_tile_dims.y) {
#pragma clang loop unroll(full)
        for (int j = 0; j < TN; j++) {
          if (tn + j * TN_stride + sn < dst_tile_dims.x) {
            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
                Epilogue::apply(results[i * TN + j].thread_elements()[0]);
          }

          if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
                Epilogue::apply(results[i * TN + j].thread_elements()[1]);
          }
        }
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<T, AccumType>>
struct GEMMKernel {
  MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
  MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
  MLX_MTL_CONST short tgp_mem_size_a =
      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
  MLX_MTL_CONST short tgp_mem_size_b =
      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
  MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;

  MLX_MTL_CONST short tgp_size = WM * WN * 32;
  MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;

  using loader_a_t = BlockLoader<
      T,
      BM,
      BK,
      BK,
      vec_size,
      tgp_size,
      transpose_a,
      true,
      tgp_padding_a>;
  using loader_b_t = BlockLoader<
      T,
      BK,
      BN,
      BK,
      vec_size,
      tgp_size,
      transpose_b,
      false,
      tgp_padding_b>;
  using mma_t = BlockMMA<
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      tgp_padding_a,
      tgp_padding_b,
      AccumType,
      Epilogue>;

  /* Main kernel function */
  static METAL_FUNC void run(
      const device T* A [[buffer(0)]],
      const device T* B [[buffer(1)]],
      device T* C [[buffer(2)]],
      const constant int& M [[buffer(3)]],
      const constant int& N [[buffer(4)]],
      const constant int& K [[buffer(5)]],
      const constant int& batch_stride_a [[buffer(6)]],
      const constant int& batch_stride_b [[buffer(7)]],
      const constant int& batch_stride_c [[buffer(8)]],
      threadgroup T* tgp_memory [[threadgroup(0)]],
      uint simd_lane_id [[thread_index_in_simdgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Pacifying compiler
    (void)lid;

    // Adjust for batch
    A += batch_stride_a * tid.z;
    B += batch_stride_b * tid.z;
    C += batch_stride_c * tid.z;

    // Adjust for transpose
    const int lda_dev = transpose_a ? M : K;
    const int ldb_dev = transpose_b ? K : N;

    // Find block in A, B, C
    const int c_row = tid.y * BM;
    const int c_col = tid.x * BN;

    A += transpose_a ? c_row : c_row * K;
    B += transpose_b ? c_col * K : c_col;
    C += c_row * N + c_col;

    // Prepare threadgroup memory for loading
    threadgroup T* As = tgp_memory;
    threadgroup T* Bs = tgp_memory + tgp_mem_size_a;

    // Prepare threadgroup loading operations
    loader_a_t loader_a(A, lda_dev, As, simd_group_id, simd_lane_id);
    loader_b_t loader_b(B, ldb_dev, Bs, simd_group_id, simd_lane_id);

    // Prepare threadgroup mma operation
    mma_t mma_op(simd_group_id, simd_lane_id);

    ///////////////////////////////////////////////////////////////////////////
    // MNK aligned loop
    if (MN_aligned && K_aligned) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }

      threadgroup_barrier(mem_flags::mem_none);

      // Store results to device memory
      mma_op.store_result(C, N);
      return;
    }
    ///////////////////////////////////////////////////////////////////////////
    // MN aligned, K unaligned loop
    else if (MN_aligned && !K_aligned) {
      // Main loop
      int k = 0;
      for (; k + BK <= K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }

      // Loop tail
      threadgroup_barrier(mem_flags::mem_threadgroup);

      loader_a.load_safe(short2(K - k, BM));
      loader_b.load_safe(short2(BN, K - k));

      threadgroup_barrier(mem_flags::mem_threadgroup);

      mma_op.mma(As, Bs);

      // Store results to device memory
      mma_op.store_result(C, N);
      return;
    }
    ///////////////////////////////////////////////////////////////////////////
    // MNK unaligned loop
    else { // Loop over K - unaligned case

      short2 src_tile_dims(min(BN, N - c_col), min(BM, M - c_row));

      if (src_tile_dims.y == BM && src_tile_dims.x == BN) {
        int k = 0;
        for (; k + BK <= K; k += BK) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          // Load elements into threadgroup
          loader_a.load_unsafe();
          loader_b.load_unsafe();

          threadgroup_barrier(mem_flags::mem_threadgroup);

          // Multiply and accumulate threadgroup elements
          mma_op.mma(As, Bs);

          // Prepare for next iteration
          loader_a.next();
          loader_b.next();
        }

        threadgroup_barrier(mem_flags::mem_none);

        if (k < K) {
          loader_a.load_safe(short2(K - k, BM));
          loader_b.load_safe(short2(BN, K - k));

          threadgroup_barrier(mem_flags::mem_threadgroup);

          mma_op.mma(As, Bs);
        }

        mma_op.store_result(C, N);
        return;

      } else {
        int k = 0;
        for (; k + BK <= K; k += BK) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          // Load elements into threadgroup
          loader_a.load_safe(short2(BK, src_tile_dims.y));
          loader_b.load_safe(short2(src_tile_dims.x, BK));

          threadgroup_barrier(mem_flags::mem_threadgroup);

          // Multiply and accumulate threadgroup elements
          mma_op.mma(As, Bs);

          // Prepare for next iteration
          loader_a.next();
          loader_b.next();
        }

        threadgroup_barrier(mem_flags::mem_none);

        if (k < K) {
          loader_a.load_safe(short2(K - k, src_tile_dims.y));
          loader_b.load_safe(short2(src_tile_dims.x, K - k));

          threadgroup_barrier(mem_flags::mem_threadgroup);

          mma_op.mma(As, Bs);
        }

        threadgroup_barrier(mem_flags::mem_none);
        mma_op.store_result_safe(C, N, src_tile_dims);

        return;
      }
    }
  }
};
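A minimal sketch of how a host launcher might size the dispatch for GEMMKernel, assuming the metal-cpp API and a (BM, BN) output tile per threadgroup; the names and actual launch logic are illustrative, not the host matmul code:

    // One threadgroup per (BM x BN) tile of C, one grid layer per batch entry.
    MTL::Size group_dims = MTL::Size(32, WN, WM); // WM * WN * 32 threads = tgp_size
    MTL::Size grid_dims = MTL::Size(
        (N + BN - 1) / BN, // tiles along columns -> tid.x
        (M + BM - 1) / BM, // tiles along rows    -> tid.y
        batch_size);       // batch               -> tid.z
    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
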
302
mlx/backend/metal/kernels/gemv.metal
Normal file
302
mlx/backend/metal/kernels/gemv.metal
Normal file
@@ -0,0 +1,302 @@
#include <metal_stdlib>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/bf16.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////

static constant constexpr int SIMD_SIZE = 32;

template <typename T,
          const int BM, /* Threadgroup rows (in threads) */
          const int BN, /* Threadgroup cols (in threads) */
          const int TM, /* Thread rows (in elements) */
          const int TN> /* Thread cols (in elements) */
[[kernel]] void gemv(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(2)]],
    const constant int& in_vec_size [[buffer(3)]],
    const constant int& out_vec_size [[buffer(4)]],
    const constant int& vector_batch_stride [[buffer(5)]],
    const constant int& matrix_batch_stride [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");

  // - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
  //   into blocks of (BM * TM, BN * TN) divided among threadgroups
  // - Every thread works on a block of (TM, TN)
  // - We assume each threadgroup is launched with (BN, BM, 1) threads
  //
  // 1. A thread loads TN elements each from mat along TM contiguous rows
  //    and the corresponding scalar from the vector
  // 2. The thread then multiplies and adds to accumulate its local result for the block
  // 3. At the end, each thread has accumulated results over all blocks across the rows
  //    These are then summed up across the threadgroup
  // 4. Each threadgroup writes its accumulated BN * TN outputs
  //
  // Edge case handling:
  // - The threadgroup with the largest tid will have blocks that exceed the matrix
  //   * The blocks that start outside the matrix are never read (thread results remain zero)
  //   * The last thread that partially overlaps with the matrix is shifted inwards
  //     such that the thread block fits exactly in the matrix

  // Update batch offsets
  in_vec += tid.z * vector_batch_stride;
  mat += tid.z * matrix_batch_stride;
  out_vec += tid.z * out_vec_size;

  // Threadgroup in_vec cache
  threadgroup T in_vec_block[BN][TN * 2];

  // Thread local accumulation results
  thread T result[TM] = {0};
  thread T inter[TN];
  thread T v_coeff[TN];

  // Block position
  int out_row = (tid.x * BM + simd_gid) * TM;

  // Exit simdgroup if rows out of bound
  if(out_row >= out_vec_size)
    return;

  // Adjust tail simdgroup to ensure in bound reads
  out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;

  // Advance matrix
  mat += out_row * in_vec_size;

  // Loop over in_vec in blocks of BN * TN
  for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Prefetch in_vector for threadgroup use
    if(simd_gid == 0) {
      // Main load loop
      if(bn + TN <= in_vec_size) {
        #pragma clang loop unroll(full)
        for(int tn = 0; tn < TN; tn++) {
          in_vec_block[simd_lid][tn] = in_vec[bn + tn];
        }
      } else { // Edge case
        #pragma clang loop unroll(full)
        for(int tn = 0; tn < TN; tn++) {
          in_vec_block[simd_lid][tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Load for all rows
    #pragma clang loop unroll(full)
    for(int tn = 0; tn < TN; tn++) {
      v_coeff[tn] = in_vec_block[simd_lid][tn];
    }

    // Per thread work loop
    #pragma clang loop unroll(full)
    for(int tm = 0; tm < TM; tm++) {
      // Load for the row
      for(int tn = 0; tn < TN; tn++) {
        inter[tn] = mat[tm * in_vec_size + bn + tn];
      }

      // Accumulate results
      for(int tn = 0; tn < TN; tn++) {
        result[tm] += inter[tn] * v_coeff[tn];
      }
    }
  }

  // Simdgroup accumulations
  #pragma clang loop unroll(full)
  for(int tm = 0; tm < TM; tm++) {
    result[tm] = simd_sum(result[tm]);
  }

  // Write outputs
  if(simd_lid == 0) {
    #pragma clang loop unroll(full)
    for(int tm = 0; tm < TM; tm++) {
      out_vec[out_row + tm] = result[tm];
    }
  }

}

#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
  template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
  [[kernel]] void gemv<itype, bm, bn, tm, tn>( \
      const device itype* mat [[buffer(0)]], \
      const device itype* vec [[buffer(1)]], \
      device itype* out [[buffer(2)]], \
      const constant int& in_vec_size [[buffer(3)]], \
      const constant int& out_vec_size [[buffer(4)]], \
      const constant int& vector_batch_stride [[buffer(5)]], \
      const constant int& matrix_batch_stride [[buffer(6)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_gemv_blocks(name, itype) \
  instantiate_gemv(name, itype, 4, 32, 1, 4) \
  instantiate_gemv(name, itype, 4, 32, 4, 4) \
  instantiate_gemv(name, itype, 8, 32, 4, 4)

instantiate_gemv_blocks(float32, float)
instantiate_gemv_blocks(float16, half)
instantiate_gemv_blocks(bfloat16, bfloat16_t)

///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////

template <typename T,
          const int BM, /* Threadgroup rows (in threads) */
          const int BN, /* Threadgroup cols (in threads) */
          const int TM, /* Thread rows (in elements) */
          const int TN> /* Thread cols (in elements) */
[[kernel]] void gemv_t(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(2)]],
    const constant int& in_vec_size [[buffer(3)]],
    const constant int& out_vec_size [[buffer(4)]],
    const constant int& vector_batch_stride [[buffer(5)]],
    const constant int& matrix_batch_stride [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {

  // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
  //   into blocks of (BM * TM, BN * TN) divided among threadgroups
  // - Every thread works on a block of (TM, TN)
  // - We assume each threadgroup is launched with (BN, BM, 1) threads
  //
  // 1. A thread loads TN elements each from mat along TM contiguous rows
  //    and the corresponding scalar from the vector
  // 2. The thread then multiplies and adds to accumulate its local result for the block
  // 3. At the end, each thread has accumulated results over all blocks across the rows
  //    These are then summed up across the threadgroup
  // 4. Each threadgroup writes its accumulated BN * TN outputs
  //
  // Edge case handling:
  // - The threadgroup with the largest tid will have blocks that exceed the matrix
  //   * The blocks that start outside the matrix are never read (thread results remain zero)
  //   * The last thread that partially overlaps with the matrix is shifted inwards
  //     such that the thread block fits exactly in the matrix

  // Update batch offsets
  in_vec += tid.z * vector_batch_stride;
  mat += tid.z * matrix_batch_stride;
  out_vec += tid.z * out_vec_size;

  // Thread local accumulation results
  T result[TN] = {0};
  T inter[TN];
  T v_coeff[TM];

  // Threadgroup accumulation results
  threadgroup T tgp_results[BN][BM][TM];

  int out_col = (tid.x * BN + lid.x) * TN;
  int in_row = lid.y * TM;

  // Edge case handling
  if (out_col < out_vec_size) {
    // Edge case handling
    out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;

    // Per thread accumulation main loop
    int bm = in_row;
    for(; bm < in_vec_size; bm += BM * TM) {
      // Adding a threadgroup_barrier improves performance slightly
      // This is possibly because it helps exploit the cache better
      threadgroup_barrier(mem_flags::mem_none);

      if(bm + TM <= in_vec_size) {

        #pragma clang loop unroll(full)
        for(int tm = 0; tm < TM; tm++) {
          v_coeff[tm] = in_vec[bm + tm];
        }

        #pragma clang loop unroll(full)
        for(int tm = 0; tm < TM; tm++) {
          for(int tn = 0; tn < TN; tn++) {
            inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
          }
          for(int tn = 0; tn < TN; tn++) {
            result[tn] += v_coeff[tm] * inter[tn];
          }
        }

      } else { // Edge case handling
        for(int tm = 0; bm + tm < in_vec_size; tm++) {
          v_coeff[tm] = in_vec[bm + tm];

          for(int tn = 0; tn < TN; tn++) {
            inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
          }
          for(int tn = 0; tn < TN; tn++) {
            result[tn] += v_coeff[tm] * inter[tn];
          }
        }
      }
    }

  }

  // Threadgroup collection
  for(int i = 0; i < TN; i++) {
    tgp_results[lid.x][lid.y][i] = result[i];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  if(lid.y == 0 && out_col < out_vec_size) {
    // Threadgroup accumulation
    #pragma clang loop unroll(full)
    for(int i = 1; i < BM; i++) {
      for(int j = 0; j < TN; j++) {
        result[j] += tgp_results[lid.x][i][j];
      }
    }

    for(int j = 0; j < TN; j++) {
      out_vec[out_col + j] = result[j];
    }
  }

}

#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
  template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
  [[kernel]] void gemv_t<itype, bm, bn, tm, tn>( \
      const device itype* mat [[buffer(0)]], \
      const device itype* vec [[buffer(1)]], \
      device itype* out [[buffer(2)]], \
      const constant int& in_vec_size [[buffer(3)]], \
      const constant int& out_vec_size [[buffer(4)]], \
      const constant int& vector_batch_stride [[buffer(5)]], \
      const constant int& matrix_batch_stride [[buffer(6)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);

#define instantiate_gemv_t_blocks(name, itype) \
  instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
  instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
  instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
  instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
  instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
  instantiate_gemv_t(name, itype, 8, 128, 4, 4)

instantiate_gemv_t_blocks(float32, float)
instantiate_gemv_t_blocks(float16, half)
instantiate_gemv_t_blocks(bfloat16, bfloat16_t)
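The host side selects one of these specializations by assembling the host_name string from the chosen block sizes, for example gemv_t<float, 8, 32, 4, 4> is registered as "gemv_t_float32_bm8_bn32_tm4_tn4"; which block shape gets used for a given problem size is a host-side decision in the matmul code.
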
226
mlx/backend/metal/kernels/softmax.metal
Normal file
226
mlx/backend/metal/kernels/softmax.metal
Normal file
@@ -0,0 +1,226 @@
|
||||
#include <metal_atomic>
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
template <typename T>
|
||||
inline T softmax_exp(T x) {
|
||||
// Softmax doesn't need high precision exponential cause it is gonna be x
|
||||
// will be in (-oo, 0] anyway and subsequently it will be divided by
|
||||
// sum(exp(x_i)).
|
||||
return fast::exp(x);
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
[[kernel]] void softmax_single_row(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant int& axis_size,
|
||||
threadgroup T* local_max [[threadgroup(0)]],
|
||||
threadgroup T* local_normalizer [[threadgroup(1)]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint _lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
int lid = _lid;
|
||||
|
||||
T ld[N_READS];
|
||||
|
||||
in += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
ld[i] = in[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] =
|
||||
((lid * N_READS + i) < axis_size) ? in[i] : T(Limits<T>::finite_min);
|
||||
}
|
||||
}
|
||||
if (simd_group_id == 0) {
|
||||
local_max[simd_lane_id] = Limits<T>::finite_min;
|
||||
local_normalizer[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Get the max
|
||||
T maxval = Limits<T>::finite_min;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
maxval = (maxval < ld[i]) ? ld[i] : maxval;
|
||||
}
|
||||
maxval = simd_max(maxval);
|
||||
if (simd_lane_id == 0) {
|
||||
local_max[simd_group_id] = maxval;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id == 0) {
|
||||
maxval = simd_max(local_max[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
local_max[0] = maxval;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
maxval = local_max[0];
|
||||
|
||||
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
|
||||
T normalizer = 0;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
T exp_x = softmax_exp(ld[i] - maxval);
|
||||
ld[i] = exp_x;
|
||||
normalizer += exp_x;
|
||||
}
|
||||
normalizer = simd_sum(normalizer);
|
||||
if (simd_lane_id == 0) {
|
||||
local_normalizer[simd_group_id] = normalizer;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id == 0) {
|
||||
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
local_normalizer[0] = normalizer;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
normalizer = 1 / local_normalizer[0];
|
||||
|
||||
// Normalize and write to the output
|
||||
out += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
out[i] = ld[i] * normalizer;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
out[i] = ld[i] * normalizer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
[[kernel]] void softmax_looped(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant int& axis_size,
|
||||
threadgroup T* local_max [[threadgroup(0)]],
|
||||
threadgroup T* local_normalizer [[threadgroup(1)]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
in += gid * axis_size;
|
||||
|
||||
// Get the max and the normalizer in one go
|
||||
T prevmax;
|
||||
T maxval = Limits<T>::finite_min;
|
||||
T normalizer = 0;
|
||||
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
|
||||
r++) {
|
||||
int offset = r * lsize * N_READS + lid * N_READS;
|
||||
T vals[N_READS];
|
||||
if (offset + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = in[offset + i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] =
|
||||
(offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
|
||||
}
|
||||
}
|
||||
prevmax = maxval;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
maxval = (maxval < vals[i]) ? vals[i] : maxval;
|
||||
}
|
||||
normalizer *= softmax_exp(prevmax - maxval);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer += softmax_exp(vals[i] - maxval);
|
||||
}
|
||||
}
|
||||
// Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS *
|
||||
// lsize) parts. We need to combine them.
|
||||
// 1. We start by finding the max across simd groups
|
||||
// 2. We then change the partial normalizers to account for a possible
|
||||
// change in max
|
||||
// 3. We sum all normalizers
|
||||
prevmax = maxval;
|
||||
maxval = simd_max(maxval);
|
||||
normalizer *= softmax_exp(prevmax - maxval);
|
||||
normalizer = simd_sum(normalizer);
|
||||
|
||||
// Now the normalizer and max value is correct for each simdgroup. We write
|
||||
// them to shared memory and combine them.
|
||||
prevmax = maxval;
|
||||
if (simd_lane_id == 0) {
|
||||
local_max[simd_group_id] = maxval;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
maxval = simd_max(local_max[simd_lane_id]);
|
||||
normalizer *= softmax_exp(prevmax - maxval);
|
||||
if (simd_lane_id == 0) {
|
||||
local_normalizer[simd_group_id] = normalizer;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Finally given the normalizer and max value we can directly write the
|
||||
// softmax output
|
||||
out += gid * axis_size;
|
||||
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
|
||||
r++) {
|
||||
int offset = r * lsize * N_READS + lid * N_READS;
|
||||
if (offset + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if (offset + i < axis_size) {
|
||||
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_softmax_single_row(name, itype) \
|
||||
template [[host_name("softmax_" #name)]] [[kernel]] void \
|
||||
softmax_single_row<itype>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
threadgroup itype* local_max [[threadgroup(0)]], \
|
||||
threadgroup itype* local_normalizer [[threadgroup(1)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint _lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_softmax_looped(name, itype) \
|
||||
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
|
||||
softmax_looped<itype>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
threadgroup itype* local_max [[threadgroup(0)]], \
|
||||
threadgroup itype* local_normalizer [[threadgroup(1)]], \
|
||||
uint gid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_softmax(name, itype) \
|
||||
instantiate_softmax_single_row(name, itype) \
|
||||
instantiate_softmax_looped(name, itype)
|
||||
|
||||
instantiate_softmax(float32, float) instantiate_softmax(float16, half)
|
||||
instantiate_softmax(bfloat16, bfloat16_t)
|
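The looped kernel combines per-chunk (max, normalizer) pairs with the usual online-softmax rescaling: whenever a larger maximum is adopted, the old partial normalizer is multiplied by exp(old_max - new_max) before new terms are accumulated. A rough host-side C++ sketch of that combining rule (illustrative names, not MLX code):

#include <cmath>
#include <cstdio>
#include <vector>

// One partial result: the maximum seen so far and the sum of exp(x - maxval).
struct SoftmaxPartial {
  float maxval;
  float normalizer;
};

// Combine two partials exactly like the kernel's rescaling step.
SoftmaxPartial combine(SoftmaxPartial a, SoftmaxPartial b) {
  float m = std::fmax(a.maxval, b.maxval);
  float n = a.normalizer * std::exp(a.maxval - m) +
            b.normalizer * std::exp(b.maxval - m);
  return {m, n};
}

int main() {
  std::vector<float> x = {1.0f, 3.0f, 2.0f, 0.5f};
  // Process in two chunks, as the looped kernel does per thread.
  SoftmaxPartial p0 = {x[0], 1.0f}, p1 = {x[2], 1.0f};
  p0 = combine(p0, {x[1], 1.0f});
  p1 = combine(p1, {x[3], 1.0f});
  SoftmaxPartial total = combine(p0, p1);
  for (float v : x) {
    std::printf("%f\n", std::exp(v - total.maxval) / total.normalizer);
  }
}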
818
mlx/backend/metal/kernels/sort.metal
Normal file
@@ -0,0 +1,818 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Based on the GPU merge sort algorithm at https://github.com/NVIDIA/cccl/tree/main/cub/cub
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Thread-level sort
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void thread_swap(thread T& a, thread T& b) {
|
||||
T w = a;
|
||||
a = b;
|
||||
b = w;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct LessThan {
|
||||
static constexpr constant T init = Limits<T>::max;
|
||||
|
||||
METAL_FUNC bool operator()(T a, T b) {
|
||||
return a < b;
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename val_t,
|
||||
typename idx_t,
|
||||
bool ARG_SORT,
|
||||
short N_PER_THREAD,
|
||||
typename CompareOp>
|
||||
struct ThreadSort {
|
||||
static METAL_FUNC void sort(
|
||||
thread val_t (&vals)[N_PER_THREAD],
|
||||
thread idx_t (&idxs)[N_PER_THREAD]) {
|
||||
|
||||
CompareOp op;
|
||||
|
||||
MLX_MTL_LOOP_UNROLL
|
||||
for(short i = 0; i < N_PER_THREAD; ++i) {
|
||||
MLX_MTL_LOOP_UNROLL
|
||||
for(short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
|
||||
if(op(vals[j + 1], vals[j])) {
|
||||
thread_swap(vals[j + 1], vals[j]);
|
||||
thread_swap(idxs[j + 1], idxs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Threadgroup-level sort
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename val_t,
|
||||
typename idx_t,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD,
|
||||
typename CompareOp>
|
||||
struct BlockMergeSort {
|
||||
using thread_sort_t = ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
|
||||
static METAL_FUNC int merge_partition(
|
||||
const threadgroup val_t* As,
|
||||
const threadgroup val_t* Bs,
|
||||
short A_sz,
|
||||
short B_sz,
|
||||
short sort_md) {
|
||||
|
||||
CompareOp op;
|
||||
|
||||
short A_st = max(0, sort_md - B_sz);
|
||||
short A_ed = min(sort_md, A_sz);
|
||||
|
||||
while(A_st < A_ed) {
|
||||
short md = A_st + (A_ed - A_st) / 2;
|
||||
auto a = As[md];
|
||||
auto b = Bs[sort_md - 1 - md];
|
||||
|
||||
if(op(b, a)) {
|
||||
A_ed = md;
|
||||
} else {
|
||||
A_st = md + 1;
|
||||
}
|
||||
}
|
||||
|
||||
return A_ed;
|
||||
|
||||
}
|
||||
|
||||
static METAL_FUNC void merge_step(
|
||||
const threadgroup val_t* As,
|
||||
const threadgroup val_t* Bs,
|
||||
const threadgroup idx_t* As_idx,
|
||||
const threadgroup idx_t* Bs_idx,
|
||||
short A_sz,
|
||||
short B_sz,
|
||||
thread val_t (&vals)[N_PER_THREAD],
|
||||
thread idx_t (&idxs)[N_PER_THREAD]) {
|
||||
|
||||
CompareOp op;
|
||||
short a_idx = 0;
|
||||
short b_idx = 0;
|
||||
|
||||
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||
auto a = As[a_idx];
|
||||
auto b = Bs[b_idx];
|
||||
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
|
||||
|
||||
vals[i] = pred ? b : a;
|
||||
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
|
||||
|
||||
b_idx += short(pred);
|
||||
a_idx += short(!pred);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static METAL_FUNC void sort(
|
||||
threadgroup val_t* tgp_vals [[threadgroup(0)]],
|
||||
threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
|
||||
int size_sorted_axis,
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
// Get thread location
|
||||
int idx = lid.x * N_PER_THREAD;
|
||||
|
||||
// Load from shared memory
|
||||
thread val_t thread_vals[N_PER_THREAD];
|
||||
thread idx_t thread_idxs[N_PER_THREAD];
|
||||
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||
thread_vals[i] = tgp_vals[idx + i];
|
||||
if(ARG_SORT) {
|
||||
thread_idxs[i] = tgp_idxs[idx + i];
|
||||
}
|
||||
}
|
||||
|
||||
// Per thread sort
|
||||
if(idx < size_sorted_axis) {
|
||||
thread_sort_t::sort(thread_vals, thread_idxs);
|
||||
}
|
||||
|
||||
// Do merges using threadgroup memory
|
||||
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; merge_threads *= 2) {
|
||||
// Update threadgroup memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||
tgp_vals[idx + i] = thread_vals[i];
|
||||
if(ARG_SORT) {
|
||||
tgp_idxs[idx + i] = thread_idxs[i];
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Find location in merge step
|
||||
int merge_group = lid.x / merge_threads;
|
||||
int merge_lane = lid.x % merge_threads;
|
||||
|
||||
int sort_sz = N_PER_THREAD * merge_threads;
|
||||
int sort_st = N_PER_THREAD * merge_threads * merge_group;
|
||||
|
||||
// As = tgp_vals[A_st:A_ed] is sorted
|
||||
// Bs = tgp_vals[B_st:B_ed] is sorted
|
||||
int A_st = sort_st;
|
||||
int A_ed = sort_st + sort_sz / 2;
|
||||
int B_st = sort_st + sort_sz / 2;
|
||||
int B_ed = sort_st + sort_sz;
|
||||
|
||||
const threadgroup val_t* As = tgp_vals + A_st;
|
||||
const threadgroup val_t* Bs = tgp_vals + B_st;
|
||||
int A_sz = A_ed - A_st;
|
||||
int B_sz = B_ed - B_st;
|
||||
|
||||
// Find a partition of merge elements
|
||||
// Ci = merge(As[partition:], Bs[sort_md - partition:])
|
||||
// of size N_PER_THREAD for each merge lane i
|
||||
// C = [Ci] is sorted
|
||||
int sort_md = N_PER_THREAD * merge_lane;
|
||||
int partition = merge_partition(
|
||||
As,
|
||||
Bs,
|
||||
A_sz,
|
||||
B_sz,
|
||||
sort_md);
|
||||
|
||||
As += partition;
|
||||
Bs += sort_md - partition;
|
||||
|
||||
A_sz -= partition;
|
||||
B_sz -= sort_md - partition;
|
||||
|
||||
const threadgroup idx_t* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
|
||||
const threadgroup idx_t* Bs_idx = ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
|
||||
|
||||
// Merge starting at the partition and store results in thread registers
|
||||
merge_step(
|
||||
As,
|
||||
Bs,
|
||||
As_idx,
|
||||
Bs_idx,
|
||||
A_sz,
|
||||
B_sz,
|
||||
thread_vals,
|
||||
thread_idxs);
|
||||
|
||||
}
|
||||
|
||||
// Write out to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||
tgp_vals[idx + i] = thread_vals[i];
|
||||
if(ARG_SORT) {
|
||||
tgp_idxs[idx + i] = thread_idxs[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Kernel sort
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD,
|
||||
typename CompareOp = LessThan<T>>
|
||||
struct KernelMergeSort {
|
||||
using val_t = T;
|
||||
using idx_t = uint;
|
||||
using block_merge_sort_t = BlockMergeSort<
|
||||
val_t,
|
||||
idx_t,
|
||||
ARG_SORT,
|
||||
BLOCK_THREADS,
|
||||
N_PER_THREAD,
|
||||
CompareOp>;
|
||||
|
||||
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
|
||||
|
||||
static METAL_FUNC void block_sort(
|
||||
const device T* inp,
|
||||
device U* out,
|
||||
const constant int& size_sorted_axis,
|
||||
const constant int& stride_sorted_axis,
|
||||
const constant int& stride_segment_axis,
|
||||
threadgroup val_t* tgp_vals,
|
||||
threadgroup idx_t* tgp_idxs,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
// tid.y tells us the segment index
|
||||
inp += tid.y * stride_segment_axis;
|
||||
out += tid.y * stride_segment_axis;
|
||||
|
||||
// Copy into threadgroup memory
|
||||
for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
|
||||
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis] : val_t(CompareOp::init);
|
||||
if(ARG_SORT) {
|
||||
tgp_idxs[i] = i;
|
||||
}
|
||||
}
|
||||
|
||||
// Sort elements within the block
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write output
|
||||
for(int i = lid.x; i < size_sorted_axis; i+= BLOCK_THREADS) {
|
||||
if(ARG_SORT) {
|
||||
out[i * stride_sorted_axis] = tgp_idxs[i];
|
||||
} else {
|
||||
out[i * stride_sorted_axis] = tgp_vals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD>
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort(
|
||||
const device T* inp [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_segment_axis [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
||||
using val_t = typename sort_kernel::val_t;
|
||||
using idx_t = typename sort_kernel::idx_t;
|
||||
|
||||
if(ARG_SORT) {
|
||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
||||
sort_kernel::block_sort(
|
||||
inp,
|
||||
out,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
stride_segment_axis,
|
||||
tgp_vals,
|
||||
tgp_idxs,
|
||||
tid,
|
||||
lid);
|
||||
} else {
|
||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||
sort_kernel::block_sort(
|
||||
inp,
|
||||
out,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
stride_segment_axis,
|
||||
tgp_vals,
|
||||
nullptr,
|
||||
tid,
|
||||
lid);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
constant constexpr const int zero_helper = 0;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD>
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc(
|
||||
const device T* inp [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& nc_dim [[buffer(4)]],
|
||||
const device int* nc_shape [[buffer(5)]],
|
||||
const device size_t* nc_strides [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
||||
using val_t = typename sort_kernel::val_t;
|
||||
using idx_t = typename sort_kernel::idx_t;
|
||||
|
||||
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
|
||||
inp += block_idx;
|
||||
out += block_idx;
|
||||
|
||||
if(ARG_SORT) {
|
||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
||||
sort_kernel::block_sort(
|
||||
inp,
|
||||
out,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
zero_helper,
|
||||
tgp_vals,
|
||||
tgp_idxs,
|
||||
tid,
|
||||
lid);
|
||||
} else {
|
||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||
sort_kernel::block_sort(
|
||||
inp,
|
||||
out,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
zero_helper,
|
||||
tgp_vals,
|
||||
nullptr,
|
||||
tid,
|
||||
lid);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Instantiations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
#define instantiate_block_sort(name, itname, itype, otname, otype, arg_sort, bn, tn) \
|
||||
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn)]] \
|
||||
[[kernel]] void block_sort<itype, otype, arg_sort, bn, tn>( \
|
||||
const device itype* inp [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant int& size_sorted_axis [[buffer(2)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(3)]], \
|
||||
const constant int& stride_segment_axis [[buffer(4)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn "_nc")]] \
|
||||
[[kernel]] void block_sort_nc<itype, otype, arg_sort, bn, tn>( \
|
||||
const device itype* inp [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant int& size_sorted_axis [[buffer(2)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(3)]], \
|
||||
const constant int& nc_dim [[buffer(4)]], \
|
||||
const device int* nc_shape [[buffer(5)]], \
|
||||
const device size_t* nc_strides [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
|
||||
instantiate_block_sort(arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
|
||||
|
||||
#define instantiate_block_sort_base(itname, itype, bn, tn) \
|
||||
instantiate_block_sort(block_merge_sort, itname, itype, itname, itype, false, bn, tn)
|
||||
|
||||
#define instantiate_block_sort_tn(itname, itype, bn) \
|
||||
instantiate_block_sort_base(itname, itype, bn, 8) \
|
||||
instantiate_arg_block_sort_base(itname, itype, bn, 8)
|
||||
|
||||
#define instantiate_block_sort_bn(itname, itype) \
|
||||
instantiate_block_sort_tn(itname, itype, 128) \
|
||||
instantiate_block_sort_tn(itname, itype, 256) \
|
||||
instantiate_block_sort_tn(itname, itype, 512)
|
||||
|
||||
instantiate_block_sort_bn(uint8, uint8_t)
|
||||
instantiate_block_sort_bn(uint16, uint16_t)
|
||||
instantiate_block_sort_bn(uint32, uint32_t)
|
||||
instantiate_block_sort_bn(int8, int8_t)
|
||||
instantiate_block_sort_bn(int16, int16_t)
|
||||
instantiate_block_sort_bn(int32, int32_t)
|
||||
instantiate_block_sort_bn(float16, half)
|
||||
instantiate_block_sort_bn(float32, float)
|
||||
instantiate_block_sort_bn(bfloat16, bfloat16_t)
|
||||
|
||||
#define instantiate_block_sort_long(itname, itype) \
|
||||
instantiate_block_sort_tn(itname, itype, 128) \
|
||||
instantiate_block_sort_tn(itname, itype, 256)
|
||||
|
||||
instantiate_block_sort_long(uint64, uint64_t)
|
||||
instantiate_block_sort_long(int64, int64_t)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Multi block merge sort
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename val_t,
|
||||
typename idx_t,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD,
|
||||
typename CompareOp = LessThan<val_t>>
|
||||
struct KernelMultiBlockMergeSort {
|
||||
using block_merge_sort_t = BlockMergeSort<
|
||||
val_t,
|
||||
idx_t,
|
||||
ARG_SORT,
|
||||
BLOCK_THREADS,
|
||||
N_PER_THREAD,
|
||||
CompareOp>;
|
||||
|
||||
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
|
||||
|
||||
static METAL_FUNC void block_sort(
|
||||
const device val_t* inp,
|
||||
device val_t* out_vals,
|
||||
device idx_t* out_idxs,
|
||||
const constant int& size_sorted_axis,
|
||||
const constant int& stride_sorted_axis,
|
||||
threadgroup val_t* tgp_vals,
|
||||
threadgroup idx_t* tgp_idxs,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
// tid.x tells us the block index along the sorted axis
|
||||
int base_idx = tid.x * N_PER_BLOCK;
|
||||
|
||||
// Copy into threadgroup memory
|
||||
for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
|
||||
int idx = base_idx + i;
|
||||
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] : val_t(CompareOp::init);
|
||||
tgp_idxs[i] = idx;
|
||||
}
|
||||
|
||||
// Sort elements within the block
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write output
|
||||
for(int i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
|
||||
int idx = base_idx + i;
|
||||
if(idx < size_sorted_axis) {
|
||||
out_vals[idx] = tgp_vals[i];
|
||||
out_idxs[idx] = tgp_idxs[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static METAL_FUNC int merge_partition(
|
||||
const device val_t* As,
|
||||
const device val_t* Bs,
|
||||
int A_sz,
|
||||
int B_sz,
|
||||
int sort_md) {
|
||||
|
||||
CompareOp op;
|
||||
|
||||
int A_st = max(0, sort_md - B_sz);
|
||||
int A_ed = min(sort_md, A_sz);
|
||||
|
||||
while(A_st < A_ed) {
|
||||
int md = A_st + (A_ed - A_st) / 2;
|
||||
auto a = As[md];
|
||||
auto b = Bs[sort_md - 1 - md];
|
||||
|
||||
if(op(b, a)) {
|
||||
A_ed = md;
|
||||
} else {
|
||||
A_st = md + 1;
|
||||
}
|
||||
}
|
||||
|
||||
return A_ed;
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename val_t,
|
||||
typename idx_t,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD>
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort(
|
||||
const device val_t* inp [[buffer(0)]],
|
||||
device val_t* out_vals [[buffer(1)]],
|
||||
device idx_t* out_idxs [[buffer(2)]],
|
||||
const constant int& size_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_sorted_axis [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const device int* nc_shape [[buffer(6)]],
|
||||
const device size_t* nc_strides [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
using sort_kernel = KernelMultiBlockMergeSort<val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
||||
|
||||
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
|
||||
inp += block_idx;
|
||||
out_vals += tid.y * size_sorted_axis;
|
||||
out_idxs += tid.y * size_sorted_axis;
|
||||
|
||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
||||
|
||||
sort_kernel::block_sort(
|
||||
inp,
|
||||
out_vals,
|
||||
out_idxs,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
tgp_vals,
|
||||
tgp_idxs,
|
||||
tid,
|
||||
lid);
|
||||
}
|
||||
|
||||
template <
|
||||
typename val_t,
|
||||
typename idx_t,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD>
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partiton(
|
||||
device idx_t* block_partitions [[buffer(0)]],
|
||||
const device val_t* dev_vals [[buffer(1)]],
|
||||
const device idx_t* dev_idxs [[buffer(2)]],
|
||||
const constant int& size_sorted_axis [[buffer(3)]],
|
||||
const constant int& merge_tiles [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 tgp_dims [[threads_per_threadgroup]]) {
|
||||
|
||||
using sort_kernel = KernelMultiBlockMergeSort<
|
||||
val_t,
|
||||
idx_t,
|
||||
ARG_SORT,
|
||||
BLOCK_THREADS,
|
||||
N_PER_THREAD>;
|
||||
|
||||
block_partitions += tid.y * tgp_dims.x;
|
||||
dev_vals += tid.y * size_sorted_axis;
|
||||
dev_idxs += tid.y * size_sorted_axis;
|
||||
|
||||
// Find location in merge step
|
||||
int merge_group = lid.x / merge_tiles;
|
||||
int merge_lane = lid.x % merge_tiles;
|
||||
|
||||
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
|
||||
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
|
||||
|
||||
int A_st = min(size_sorted_axis, sort_st);
|
||||
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
|
||||
int B_st = A_ed;
|
||||
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
|
||||
|
||||
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
|
||||
int partition = sort_kernel::merge_partition(
|
||||
dev_vals + A_st,
|
||||
dev_vals + B_st,
|
||||
A_ed - A_st,
|
||||
B_ed - B_st,
|
||||
partition_at);
|
||||
|
||||
block_partitions[lid.x] = A_st + partition;
|
||||
|
||||
}
|
||||
|
||||
template <
|
||||
typename val_t,
|
||||
typename idx_t,
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD,
|
||||
typename CompareOp = LessThan<val_t>>
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_merge(
|
||||
const device idx_t* block_partitions [[buffer(0)]],
|
||||
const device val_t* dev_vals_in [[buffer(1)]],
|
||||
const device idx_t* dev_idxs_in [[buffer(2)]],
|
||||
device val_t* dev_vals_out [[buffer(3)]],
|
||||
device idx_t* dev_idxs_out [[buffer(4)]],
|
||||
const constant int& size_sorted_axis [[buffer(5)]],
|
||||
const constant int& merge_tiles [[buffer(6)]],
|
||||
const constant int& num_tiles [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
using sort_kernel = KernelMultiBlockMergeSort<
|
||||
val_t,
|
||||
idx_t,
|
||||
ARG_SORT,
|
||||
BLOCK_THREADS,
|
||||
N_PER_THREAD,
|
||||
CompareOp>;
|
||||
|
||||
using block_sort_t = typename sort_kernel::block_merge_sort_t;
|
||||
|
||||
block_partitions += tid.y * (num_tiles + 1);
|
||||
dev_vals_in += tid.y * size_sorted_axis;
|
||||
dev_idxs_in += tid.y * size_sorted_axis;
|
||||
dev_vals_out += tid.y * size_sorted_axis;
|
||||
dev_idxs_out += tid.y * size_sorted_axis;
|
||||
|
||||
int block_idx = tid.x;
|
||||
int merge_group = block_idx / merge_tiles;
|
||||
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
|
||||
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
|
||||
int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;
|
||||
|
||||
int A_st = block_partitions[block_idx + 0];
|
||||
int A_ed = block_partitions[block_idx + 1];
|
||||
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md - A_st);
|
||||
int B_ed = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
|
||||
|
||||
if((block_idx % merge_tiles) == merge_tiles - 1) {
|
||||
A_ed = min(size_sorted_axis, sort_st + sort_sz/2);
|
||||
B_ed = min(size_sorted_axis, sort_st + sort_sz);
|
||||
}
|
||||
|
||||
int A_sz = A_ed - A_st;
|
||||
int B_sz = B_ed - B_st;
|
||||
|
||||
// Load from global memory
|
||||
thread val_t thread_vals[N_PER_THREAD];
|
||||
thread idx_t thread_idxs[N_PER_THREAD];
|
||||
for(int i = 0; i < N_PER_THREAD; i++) {
|
||||
int idx = BLOCK_THREADS * i + lid.x;
|
||||
if(idx < (A_sz + B_sz)) {
|
||||
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] : dev_vals_in[B_st + idx - A_sz];
|
||||
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx] : dev_idxs_in[B_st + idx - A_sz];
|
||||
} else {
|
||||
thread_vals[i] = CompareOp::init;
|
||||
thread_idxs[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Write to shared memory
|
||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for(int i = 0; i < N_PER_THREAD; i++) {
|
||||
int idx = BLOCK_THREADS * i + lid.x;
|
||||
tgp_vals[idx] = thread_vals[i];
|
||||
tgp_idxs[idx] = thread_idxs[i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Merge
|
||||
int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));
|
||||
|
||||
int A_st_local = block_sort_t::merge_partition(
|
||||
tgp_vals,
|
||||
tgp_vals + A_sz,
|
||||
A_sz,
|
||||
B_sz,
|
||||
sort_md_local);
|
||||
int A_ed_local = A_sz;
|
||||
|
||||
int B_st_local = sort_md_local - A_st_local;
|
||||
int B_ed_local = B_sz;
|
||||
|
||||
int A_sz_local = A_ed_local - A_st_local;
|
||||
int B_sz_local = B_ed_local - B_st_local;
|
||||
|
||||
// Do merge
|
||||
block_sort_t::merge_step(
|
||||
tgp_vals + A_st_local,
|
||||
tgp_vals + A_ed_local + B_st_local,
|
||||
tgp_idxs + A_st_local,
|
||||
tgp_idxs + A_ed_local + B_st_local,
|
||||
A_sz_local,
|
||||
B_sz_local,
|
||||
thread_vals,
|
||||
thread_idxs);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||
int idx = lid.x * N_PER_THREAD;
|
||||
tgp_vals[idx + i] = thread_vals[i];
|
||||
tgp_idxs[idx + i] = thread_idxs[i];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Write output
|
||||
int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
|
||||
for(int i = lid.x; i < sort_kernel::N_PER_BLOCK; i+= BLOCK_THREADS) {
|
||||
int idx = base_idx + i;
|
||||
if(idx < size_sorted_axis) {
|
||||
dev_vals_out[idx] = tgp_vals[i];
|
||||
dev_idxs_out[idx] = tgp_idxs[i];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#define instantiate_multi_block_sort(vtname, vtype, itname, itype, arg_sort, bn, tn) \
|
||||
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||
[[kernel]] void mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
|
||||
const device vtype* inp [[buffer(0)]], \
|
||||
device vtype* out_vals [[buffer(1)]], \
|
||||
device itype* out_idxs [[buffer(2)]], \
|
||||
const constant int& size_sorted_axis [[buffer(3)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(4)]], \
|
||||
const constant int& nc_dim [[buffer(5)]], \
|
||||
const device int* nc_shape [[buffer(6)]], \
|
||||
const device size_t* nc_strides [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||
template [[host_name("mb_block_partiton_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||
[[kernel]] void mb_block_partiton<vtype, itype, arg_sort, bn, tn>( \
|
||||
device itype* block_partitions [[buffer(0)]], \
|
||||
const device vtype* dev_vals [[buffer(1)]], \
|
||||
const device itype* dev_idxs [[buffer(2)]], \
|
||||
const constant int& size_sorted_axis [[buffer(3)]], \
|
||||
const constant int& merge_tiles [[buffer(4)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 tgp_dims [[threads_per_threadgroup]]); \
|
||||
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||
[[kernel]] void mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
|
||||
const device itype* block_partitions [[buffer(0)]], \
|
||||
const device vtype* dev_vals_in [[buffer(1)]], \
|
||||
const device itype* dev_idxs_in [[buffer(2)]], \
|
||||
device vtype* dev_vals_out [[buffer(3)]], \
|
||||
device itype* dev_idxs_out [[buffer(4)]], \
|
||||
const constant int& size_sorted_axis [[buffer(5)]], \
|
||||
const constant int& merge_tiles [[buffer(6)]], \
|
||||
const constant int& num_tiles [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_multi_block_sort_base(vtname, vtype) \
|
||||
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)
|
||||
|
||||
instantiate_multi_block_sort_base(uint8, uint8_t)
|
||||
instantiate_multi_block_sort_base(uint16, uint16_t)
|
||||
instantiate_multi_block_sort_base(uint32, uint32_t)
|
||||
instantiate_multi_block_sort_base(int8, int8_t)
|
||||
instantiate_multi_block_sort_base(int16, int16_t)
|
||||
instantiate_multi_block_sort_base(int32, int32_t)
|
||||
instantiate_multi_block_sort_base(float16, half)
|
||||
instantiate_multi_block_sort_base(float32, float)
|
||||
instantiate_multi_block_sort_base(bfloat16, bfloat16_t)
|
||||
|
||||
#define instantiate_multi_block_sort_long(vtname, vtype) \
|
||||
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)
|
||||
|
||||
instantiate_multi_block_sort_long(uint64, uint64_t)
|
||||
instantiate_multi_block_sort_long(int64, int64_t)
|
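Both merge_partition implementations above are a merge-path (co-rank) binary search: for a target position sort_md in the merged output of two sorted runs A and B, they return how many elements of that prefix come from A. A small CPU sketch with the same comparison convention (b < a moves the search left); the names are illustrative only:

#include <algorithm>
#include <cstdio>
#include <vector>

// Number of elements taken from A among the first `diag` merged outputs.
int merge_partition(const std::vector<int>& A, const std::vector<int>& B, int diag) {
  int lo = std::max(0, diag - static_cast<int>(B.size()));
  int hi = std::min(diag, static_cast<int>(A.size()));
  while (lo < hi) {
    int mid = lo + (hi - lo) / 2;
    if (B[diag - 1 - mid] < A[mid]) {
      hi = mid;     // too much of A before the diagonal; move left
    } else {
      lo = mid + 1; // A[mid] merges before the diagonal; move right
    }
  }
  return lo;
}

int main() {
  std::vector<int> A = {1, 4, 6, 9}, B = {2, 3, 5, 7};
  // The first 4 merged elements are {1, 2, 3, 4}: 2 of them come from A.
  std::printf("%d\n", merge_partition(A, B, 4));
}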
244
mlx/backend/metal/kernels/utils.h
Normal file
@@ -0,0 +1,244 @@
|
||||
#pragma once
|
||||
|
||||
#include <metal_math>
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/complex.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Type limits utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename U>
|
||||
struct Limits {
|
||||
static const constant U max;
|
||||
static const constant U min;
|
||||
static const constant U finite_max;
|
||||
static const constant U finite_min;
|
||||
};
|
||||
|
||||
#define instantiate_default_limit(type) \
|
||||
template <> \
|
||||
struct Limits<type> { \
|
||||
static constexpr constant type max = metal::numeric_limits<type>::max(); \
|
||||
static constexpr constant type min = metal::numeric_limits<type>::min(); \
|
||||
static constexpr constant type finite_max = \
|
||||
metal::numeric_limits<type>::max(); \
|
||||
static constexpr constant type finite_min = \
|
||||
metal::numeric_limits<type>::min(); \
|
||||
};
|
||||
|
||||
instantiate_default_limit(uint8_t);
|
||||
instantiate_default_limit(uint16_t);
|
||||
instantiate_default_limit(uint32_t);
|
||||
instantiate_default_limit(uint64_t);
|
||||
instantiate_default_limit(int8_t);
|
||||
instantiate_default_limit(int16_t);
|
||||
instantiate_default_limit(int32_t);
|
||||
instantiate_default_limit(int64_t);
|
||||
|
||||
#define instantiate_float_limit(type) \
|
||||
template <> \
|
||||
struct Limits<type> { \
|
||||
static constexpr constant type max = \
|
||||
metal::numeric_limits<type>::infinity(); \
|
||||
static constexpr constant type min = \
|
||||
-metal::numeric_limits<type>::infinity(); \
|
||||
static constexpr constant type finite_max = \
|
||||
metal::numeric_limits<type>::max(); \
|
||||
static constexpr constant type finite_min = \
|
||||
-metal::numeric_limits<type>::max(); \
|
||||
};
|
||||
|
||||
instantiate_float_limit(half);
|
||||
instantiate_float_limit(float);
|
||||
instantiate_float_limit(bfloat16_t);
|
||||
|
||||
template <>
|
||||
struct Limits<bool> {
|
||||
static constexpr constant bool max = true;
|
||||
static constexpr constant bool min = false;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Indexing utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline size_t elem_to_loc(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides,
|
||||
int ndim) {
|
||||
size_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
inline size_t elem_to_loc(
|
||||
uint elem,
|
||||
constant const int* shape,
|
||||
constant const size_t* strides,
|
||||
int ndim) {
|
||||
size_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
inline uint2 elem_to_loc_2_nd(
|
||||
uint3 elem,
|
||||
constant const int shape[NDIM],
|
||||
constant const size_t a_strides[NDIM],
|
||||
constant const size_t b_strides[NDIM]) {
|
||||
uint2 loc = {
|
||||
static_cast<uint>(
|
||||
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
|
||||
static_cast<uint>(
|
||||
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
|
||||
for (int d = NDIM - 3; d >= 0; --d) {
|
||||
uint l = elem.z % shape[d];
|
||||
loc.x += l * a_strides[d];
|
||||
loc.y += l * b_strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
inline size_t elem_to_loc_nd(
|
||||
uint3 elem,
|
||||
constant const int shape[NDIM],
|
||||
constant const size_t strides[NDIM]) {
|
||||
size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
|
||||
for (int d = NDIM - 3; d >= 0; --d) {
|
||||
loc += (elem.z % shape[d]) * strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
|
||||
return elem * stride;
|
||||
}
|
||||
|
||||
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
|
||||
return elem.x * strides[1] + elem.y * strides[0];
|
||||
}
|
||||
|
||||
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
|
||||
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
|
||||
}
|
||||
|
||||
// Non templated version to handle arbitrary dims
|
||||
inline size_t elem_to_loc(
|
||||
uint3 elem,
|
||||
constant const int* shape,
|
||||
constant const size_t* strides,
|
||||
int ndim) {
|
||||
size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
|
||||
for (int d = ndim - 3; d >= 0; --d) {
|
||||
loc += (elem.z % shape[d]) * strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
inline uint2 elem_to_loc_2_nd(
|
||||
uint3 elem,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
int ndim) {
|
||||
uint2 loc = {
|
||||
static_cast<uint>(
|
||||
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
|
||||
static_cast<uint>(
|
||||
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
|
||||
for (int d = ndim - 3; d >= 0; --d) {
|
||||
uint l = elem.z % shape[d];
|
||||
loc.x += l * a_strides[d];
|
||||
loc.y += l * b_strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
inline uint elem_to_loc_nd(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides);
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<1>(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
return (elem % shape[0]) * strides[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<2>(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
uint loc = (elem % shape[1]) * strides[1];
|
||||
elem /= shape[1];
|
||||
loc += (elem % shape[0]) * strides[0];
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<3>(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
uint loc = (elem % shape[2]) * strides[2];
|
||||
elem /= shape[2];
|
||||
loc += (elem % shape[1]) * strides[1];
|
||||
elem /= shape[1];
|
||||
loc += (elem % shape[0]) * strides[0];
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<4>(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
uint loc = (elem % shape[3]) * strides[3];
|
||||
elem /= shape[3];
|
||||
loc += (elem % shape[2]) * strides[2];
|
||||
elem /= shape[2];
|
||||
loc += (elem % shape[1]) * strides[1];
|
||||
elem /= shape[1];
|
||||
loc += (elem % shape[0]) * strides[0];
|
||||
return loc;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Calculation utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/** Compute ceil((float)N/(float)M) */
|
||||
inline size_t ceildiv(size_t N, size_t M) {
|
||||
return (N + M - 1) / M;
|
||||
}
|
||||
|
||||
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
|
||||
inline float log1p(float x) {
|
||||
float xp1 = 1.0f + x;
|
||||
return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f));
|
||||
}
|
||||
|
||||
inline bfloat16_t log1p(bfloat16_t x) {
|
||||
float xp1 = 1.0f + static_cast<float>(x);
|
||||
bfloat16_t ret =
|
||||
(xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
||||
return ret;
|
||||
}
|
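elem_to_loc converts a flat row-major element index into an offset through arbitrary strides by peeling dimensions from innermost to outermost. A host-side C++ equivalent with a small usage example (std::vector stands in for the kernel's raw pointers):

#include <cstddef>
#include <cstdio>
#include <vector>

// Same loop as the kernel helper: walk dims from innermost to outermost.
size_t elem_to_loc(size_t elem, const std::vector<int>& shape,
                   const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A 2x3 array stored transposed (column-major): strides {1, 2}.
  std::vector<int> shape = {2, 3};
  std::vector<size_t> strides = {1, 2};
  // Flat element 4 is (row 1, col 1) -> offset 1*1 + 1*2 = 3.
  std::printf("%zu\n", elem_to_loc(4, shape, strides));
}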
446
mlx/backend/metal/matmul.cpp
Normal file
@@ -0,0 +1,446 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/matmul.h"
|
||||
#include "mlx/backend/metal/mps/gemm.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
bool use_mps() {
|
||||
auto get_val = []() {
|
||||
if (const char* buff_str = std::getenv("MLX_USE_MPS")) {
|
||||
return std::string(buff_str) != "OFF";
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
static bool use_mps_ = get_val();
|
||||
return use_mps_;
|
||||
}
|
||||
|
||||
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
|
||||
|
||||
inline void mps_matmul(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int batch_size_out,
|
||||
int lda,
|
||||
int ldb,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
std::vector<array>& copies) {
|
||||
MPS::DataType mps_dtype = MPS::DataTypeFloat32;
|
||||
|
||||
if (out.dtype() == float16) {
|
||||
mps_dtype = MPS::DataTypeFloat16;
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
mps_dtype = MPS::DataTypeBFloat16;
|
||||
}
|
||||
|
||||
// Use batched MPSMatrixMultiplication if batch_size_out > 1
|
||||
// We only accept the following cases:
|
||||
// 1. Both a, b have batch_size_out matrices worth of data
|
||||
// 2. Only one of a or b has batch_size_out matrices worth of data and
|
||||
// the other has one matrix worth of data
|
||||
|
||||
// The matrix dimensions of a and b are guaranteed to be regularly strided
|
||||
if (batch_size_out > 1) {
|
||||
// Default strides assuming no broadcasting
|
||||
auto batch_size_a = a.data_size() / (M * K);
|
||||
auto batch_size_b = b.data_size() / (K * N);
|
||||
|
||||
auto matrix_stride_a = M * K;
|
||||
auto matrix_stride_b = K * N;
|
||||
auto matrix_stride_out = M * N;
|
||||
|
||||
// At this point, batch_size_a, batch_size_b show the number of matrices
|
||||
// in data, no broadcasted strides considered
|
||||
if (batch_size_out == std::max(batch_size_a, batch_size_b)) {
|
||||
// Handle simple broadcasting
|
||||
if (std::min(batch_size_a, batch_size_b) == 1) {
|
||||
matrix_stride_a = (batch_size_a == 1) ? 0 : matrix_stride_a;
|
||||
matrix_stride_b = (batch_size_b == 1) ? 0 : matrix_stride_b;
|
||||
|
||||
batch_size_a = batch_size_out;
|
||||
batch_size_b = batch_size_out;
|
||||
}
|
||||
|
||||
// Only proceed if broadcasting between a and b is simple
|
||||
// At this point, batch_size_a, batch_size_b show the number of matrices
|
||||
// after broadcasting
|
||||
if (batch_size_a == batch_size_b) {
|
||||
auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||
(M * K) / lda,
|
||||
lda,
|
||||
batch_size_a,
|
||||
lda * a.itemsize(),
|
||||
(matrix_stride_a * a.itemsize()),
|
||||
mps_dtype);
|
||||
|
||||
auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||
(K * N) / ldb,
|
||||
ldb,
|
||||
batch_size_b,
|
||||
ldb * b.itemsize(),
|
||||
(matrix_stride_b * b.itemsize()),
|
||||
mps_dtype);
|
||||
|
||||
auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||
M,
|
||||
N,
|
||||
batch_size_out,
|
||||
N * out.itemsize(),
|
||||
matrix_stride_out * out.itemsize(),
|
||||
mps_dtype);
|
||||
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
|
||||
|
||||
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
|
||||
auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
|
||||
|
||||
auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
|
||||
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
|
||||
|
||||
auto kernel = MPS::MatrixMultiplication::alloc()->init(
|
||||
d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0);
|
||||
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
kernel->setBatchSize(batch_size_out);
|
||||
kernel->setBatchStart(0);
|
||||
kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
|
||||
command_buffer->addCompletedHandler(
|
||||
[a_mat, b_mat, out_mat, kernel, copies](
|
||||
MTL::CommandBuffer*) mutable {
|
||||
a_mat->release();
|
||||
b_mat->release();
|
||||
out_mat->release();
|
||||
kernel->release();
|
||||
copies.clear();
|
||||
});
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, schedule as many calls to MPSMatrixMultiplication as needed
|
||||
auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||
a.data_size() / lda, lda, lda * a.itemsize(), mps_dtype);
|
||||
|
||||
auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||
b.data_size() / ldb, ldb, ldb * b.itemsize(), mps_dtype);
|
||||
|
||||
auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||
batch_size_out * M, N, N * out.itemsize(), mps_dtype);
|
||||
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
|
||||
|
||||
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
|
||||
auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
|
||||
|
||||
auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
|
||||
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
|
||||
|
||||
auto kernel = MPS::MatrixMultiplication::alloc()->init(
|
||||
d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0);
|
||||
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
for (int i = 0; i < batch_size_out; ++i) {
|
||||
auto a_row = elem_to_loc(M * K * i, a.shape(), a.strides()) / lda;
|
||||
auto b_row = elem_to_loc(K * N * i, b.shape(), b.strides()) / ldb;
|
||||
kernel->setLeftMatrixOrigin({a_row, 0, 0});
|
||||
kernel->setRightMatrixOrigin({b_row, 0, 0});
|
||||
kernel->setResultMatrixOrigin({i * static_cast<size_t>(M), 0, 0});
|
||||
kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
|
||||
}
|
||||
|
||||
command_buffer->addCompletedHandler(
|
||||
[a_mat, b_mat, out_mat, kernel, copies](MTL::CommandBuffer*) mutable {
|
||||
a_mat->release();
|
||||
b_mat->release();
|
||||
out_mat->release();
|
||||
kernel->release();
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlx_matmul(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int batch_size_out,
|
||||
int lda,
|
||||
int ldb,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
std::vector<array>& copies) {
|
||||
// Account for batch sizes and basic broadcasting
|
||||
int batch_size_a = a.data_size() / (M * K);
|
||||
int batch_size_b = b.data_size() / (K * N);
|
||||
|
||||
int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K;
|
||||
int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
|
||||
int matrix_stride_out = M * N;
|
||||
|
||||
// Determine dispatch kernel
|
||||
int bm = 32, bn = 32, bk = 16;
|
||||
int wm = 2, wn = 2;
|
||||
|
||||
if ((size_t)batch_size_out * M * N >= 2ul << 20) {
|
||||
if (!transpose_a && transpose_b) {
|
||||
bm = 64;
|
||||
bn = (out.dtype() == float32) ? 64 : 32;
|
||||
bk = (out.dtype() == float32) ? 16 : 32;
|
||||
} else {
|
||||
bm = 64;
|
||||
bn = 64;
|
||||
}
|
||||
}
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "gemm_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n')
|
||||
<< "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm
|
||||
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
|
||||
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Launch only 1 kernel in the case of simple batching / broadcasting
|
||||
if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
|
||||
(batch_size_a == batch_size_b ||
|
||||
std::min(batch_size_a, batch_size_b) == 1)) {
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims =
|
||||
MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, batch_size_out);
|
||||
|
||||
set_array_buffer(compute_encoder, a, 0);
|
||||
set_array_buffer(compute_encoder, b, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
|
||||
compute_encoder->setBytes(&M, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&N, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&K, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7);
|
||||
compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
} else { // Otherwise launch one kernel per batch with explicit offsets
|
||||
|
||||
for (int i = 0; i < batch_size_out; ++i) {
|
||||
auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
|
||||
auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, 1);
|
||||
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
|
||||
auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
|
||||
|
||||
compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0);
|
||||
compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
|
||||
compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 2);
|
||||
|
||||
compute_encoder->setBytes(&M, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&N, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&K, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7);
|
||||
compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
}
|
||||
|
||||
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
throw std::runtime_error(
|
||||
"[matmul] Does not yet support non-floating point types.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
|
||||
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||
// the arrays
|
||||
std::vector<array> copies;
|
||||
auto check_transpose = [&copies, &s](const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
|
||||
auto [a_transposed, a_cols, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, b_cols, b] = check_transpose(b_pre);
|
||||
|
||||
int M = a.shape(-2);
|
||||
int N = b.shape(-1);
|
||||
int K = a.shape(-1);
|
||||
|
||||
auto batch_size_out = out.size() / (M * N);
|
||||
|
||||
// Route to gemv if needed
|
||||
if (std::min(M, N) == 1) {
|
||||
// Collect problem info
|
||||
bool is_b_matrix = N != 1;
|
||||
|
||||
auto& mat = is_b_matrix ? b : a;
|
||||
auto& vec = is_b_matrix ? a : b;
|
||||
bool transpose_mat = is_b_matrix ? !b_transposed : a_transposed;
|
||||
int in_vector_len = K;
|
||||
int out_vector_len = is_b_matrix ? N : M;
|
||||
|
||||
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
|
||||
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
|
||||
|
||||
int batch_size_mat = mat.data_size() / (mat_cols * mat_rows);
|
||||
int stride_mat = batch_size_mat == batch_size_out ? mat_cols * mat_rows : 0;
|
||||
|
||||
int batch_size_vec = vec.data_size() / in_vector_len;
|
||||
int stride_vec = batch_size_vec == batch_size_out ? in_vector_len : 0;
|
||||
|
||||
// Determine dispatch kernel
|
||||
int tm = 4, tn = 4;
|
||||
int bm, bn, n_out_per_tgp;
|
||||
std::ostringstream kname;
|
||||
|
||||
if (transpose_mat) {
|
||||
bm = 8;
|
||||
bn = 8;
|
||||
if (out_vector_len >= 24576) {
|
||||
bn = 128;
|
||||
} else if (out_vector_len >= 16384) {
|
||||
bn = 64;
|
||||
} else if (out_vector_len >= 8192) {
|
||||
bn = 16;
|
||||
}
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tn = out_vector_len < tn ? 1 : tn;
|
||||
|
||||
n_out_per_tgp = bn * tn;
|
||||
kname << "gemv_t_" << type_to_name(out);
|
||||
|
||||
} else {
|
||||
bm = out_vector_len >= 4096 ? 8 : 4;
|
||||
bn = 32;
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tm = out_vector_len < tm ? 1 : tm;
|
||||
|
||||
n_out_per_tgp = bm * tm;
|
||||
kname << "gemv_" << type_to_name(out);
|
||||
}
|
||||
|
||||
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
||||
MTL::Size group_dims = MTL::Size(bn, bm, 1);
|
||||
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
||||
|
||||
set_array_buffer(compute_encoder, mat, 0);
|
||||
set_array_buffer(compute_encoder, vec, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
|
||||
compute_encoder->setBytes(&in_vector_len, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&out_vector_len, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&stride_vec, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&stride_mat, sizeof(int), 6);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
}
|
||||
|
||||
d.end_encoding(s.index);
|
||||
|
||||
if (use_mps()) {
|
||||
mps_matmul(
|
||||
s,
|
||||
d,
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
batch_size_out,
|
||||
a_cols,
|
||||
b_cols,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
copies);
|
||||
return;
|
||||
}
|
||||
|
||||
mlx_matmul(
|
||||
s,
|
||||
d,
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
batch_size_out,
|
||||
a_cols,
|
||||
b_cols,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
copies);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
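Both mps_matmul and mlx_matmul above express simple batch broadcasting by giving the single-matrix operand a batch stride of zero, so every output batch reads the same matrix. A tiny sketch of that convention (a hypothetical helper, not an MLX API):

#include <cstddef>
#include <cstdio>

// An operand that holds only one matrix is broadcast across the batch by
// giving it a zero stride between "consecutive" matrices.
struct BatchedOperand {
  size_t batch_size;
  size_t matrix_stride; // elements between consecutive matrices
};

BatchedOperand broadcast_to(size_t batch_size_op, size_t matrix_elems,
                            size_t batch_size_out) {
  size_t stride = (batch_size_op == 1) ? 0 : matrix_elems;
  return {batch_size_out, stride};
}

int main() {
  // b holds a single K x N matrix while the output has 8 batches:
  BatchedOperand b = broadcast_to(1, 64 * 32, 8);
  std::printf("%zu %zu\n", b.batch_size, b.matrix_stride); // prints: 8 0
}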
29
mlx/backend/metal/matmul.h
Normal file
@@ -0,0 +1,29 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/mps/gemm.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void mlx_matmul(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int batch_size_out,
|
||||
int lda,
|
||||
int ldb,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
std::vector<array>& copies);
|
||||
|
||||
} // namespace mlx::core
|
368
mlx/backend/metal/mps/gemm.h
Normal file
@@ -0,0 +1,368 @@
#pragma once

#include <Metal/Metal.hpp>

#define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol)
#define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor)

namespace MTL::Private::Class {
_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor);
_MTL_PRIVATE_DEF_CLS(MPSMatrix);
_MTL_PRIVATE_DEF_CLS(MPSVectorDescriptor);
_MTL_PRIVATE_DEF_CLS(MPSVector);
_MTL_PRIVATE_DEF_CLS(MPSKernel);
_MTL_PRIVATE_DEF_CLS(MPSMatrixMultiplication);
_MTL_PRIVATE_DEF_CLS(MPSMatrixVectorMultiplication);
} // namespace MTL::Private::Class

namespace MTL::Private::Selector {
_MTL_PRIVATE_DEF_SEL(
    matrixDescriptorWithRows_columns_rowBytes_dataType,
    "matrixDescriptorWithRows:columns:rowBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(
    matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType,
    "matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(rows, "rows");
_MTL_PRIVATE_DEF_SEL(initWithBuffer_descriptor, "initWithBuffer:descriptor:");
_MTL_PRIVATE_DEF_SEL(
    initWithDevice_,
    "initWithDevice:transposeLeft:transposeRight:"
    "resultRows:resultColumns:interiorColumns:alpha:beta:");
_MTL_PRIVATE_DEF_SEL(
    encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix,
    "encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:");
_MTL_PRIVATE_DEF_SEL(setLeftMatrixOrigin_, "setLeftMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setRightMatrixOrigin_, "setRightMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setResultMatrixOrigin_, "setResultMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setBatchStart_, "setBatchStart:");
_MTL_PRIVATE_DEF_SEL(setBatchSize_, "setBatchSize:");
_MTL_PRIVATE_DEF_SEL(
    vectorDescriptorWithLength_dataType,
    "vectorDescriptorWithLength:dataType:");
_MTL_PRIVATE_DEF_SEL(
    vectorDescriptorWithLength_vectors_vectorBytes_dataType,
    "vectorDescriptorWithLength:vectors:vectorBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(
    initWithDevice_transpose_rows_columns_alpha_beta,
    "initWithDevice:transpose:rows:columns:alpha:beta:");
_MTL_PRIVATE_DEF_SEL(
    encodeToCommandBuffer_inputMatrix_inputVector_resultVector,
    "encodeToCommandBuffer:inputMatrix:inputVector:resultVector:");
} // namespace MTL::Private::Selector

namespace MPS {

typedef enum DataType : uint32_t {
  DataTypeFloatBit = 0x10000000,
  DataTypeAlternateEncodingBit = 0x80000000,
  DataTypeFloat16 = DataTypeFloatBit | 16,
  DataTypeFloat32 = DataTypeFloatBit | 32,
  DataTypeBFloat16 = DataTypeAlternateEncodingBit | DataTypeFloat16
} DataType;

class MatrixDescriptor : public NS::Copying<MatrixDescriptor> {
 public:
  static class MatrixDescriptor* matrixDescriptor(
      NS::UInteger rows,
      NS::UInteger columns,
      NS::UInteger rowBytes,
      NS::UInteger dataType);
  static class MatrixDescriptor* matrixDescriptor(
      NS::UInteger rows,
      NS::UInteger columns,
      NS::UInteger matrices,
      NS::UInteger rowBytes,
      NS::UInteger matrixBytes,
      NS::UInteger dataType);
  NS::UInteger rows() const;
};

class Matrix : public NS::Referencing<Matrix> {
 public:
  static class Matrix* alloc();
  Matrix* init(MTL::Buffer* buffer, MatrixDescriptor* descriptor);
  Matrix* init(const MTL::Buffer* buffer, MatrixDescriptor* descriptor);
};

class Kernel : public NS::Referencing<Kernel> {
 public:
  NS::String* label() const;
  MTL::Device* device() const;
};

class MatrixMultiplication
    : public NS::Referencing<MatrixMultiplication, Kernel> {
 public:
  static class MatrixMultiplication* alloc();

  MatrixMultiplication* init(
      MTL::Device* device,
      bool transposeLeft,
      bool transposeRight,
      NS::UInteger resultRows,
      NS::UInteger resultColumns,
      NS::UInteger interiorColumns,
      double alpha,
      double beta);

  void encodeToCommandBuffer(
      MTL::CommandBuffer* commandBuffer,
      Matrix* leftMatrix,
      Matrix* rightMatrix,
      Matrix* resultMatrix);

  void setLeftMatrixOrigin(MTL::Origin origin);
  void setRightMatrixOrigin(MTL::Origin origin);
  void setResultMatrixOrigin(MTL::Origin origin);
  void setBatchStart(NS::UInteger batchStart);
  void setBatchSize(NS::UInteger batchSize);
};

class VectorDescriptor : public NS::Copying<VectorDescriptor> {
 public:
  static class VectorDescriptor* vectorDescriptor(
      NS::UInteger length,
      NS::UInteger dataType);
  static class VectorDescriptor* vectorDescriptor(
      NS::UInteger length,
      NS::UInteger vectors,
      NS::UInteger vectorBytes,
      NS::UInteger dataType);
};

class Vector : public NS::Referencing<Vector> {
 public:
  static class Vector* alloc();
  Vector* init(MTL::Buffer* buffer, VectorDescriptor* descriptor);
  Vector* init(const MTL::Buffer* buffer, VectorDescriptor* descriptor);
};

class MatrixVectorMultiplication
    : public NS::Referencing<MatrixVectorMultiplication, Kernel> {
 public:
  static class MatrixVectorMultiplication* alloc();

  MatrixVectorMultiplication* init(
      MTL::Device* device,
      bool transpose,
      NS::UInteger rows,
      NS::UInteger columns,
      double alpha,
      double beta);

  void encodeToCommandBuffer(
      MTL::CommandBuffer* commandBuffer,
      Matrix* inputMatrix,
      Vector* inputVector,
      Vector* resultVector);
};

_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
    NS::UInteger rows,
    NS::UInteger columns,
    NS::UInteger rowBytes,
    NS::UInteger dataType) {
  return Object::sendMessage<MatrixDescriptor*>(
      _MPS_PRIVATE_CLS(MPSMatrixDescriptor),
      _MPS_PRIVATE_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType),
      rows,
      columns,
      rowBytes,
      dataType);
}

_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
    NS::UInteger rows,
    NS::UInteger columns,
    NS::UInteger matrices,
    NS::UInteger rowBytes,
    NS::UInteger matrixBytes,
    NS::UInteger dataType) {
  return Object::sendMessage<MatrixDescriptor*>(
      _MPS_PRIVATE_CLS(MPSMatrixDescriptor),
      _MPS_PRIVATE_SEL(
          matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType),
      rows,
      columns,
      matrices,
      rowBytes,
      matrixBytes,
      dataType);
}

_MTL_INLINE NS::UInteger MatrixDescriptor::rows() const {
  return Object::sendMessage<NS::UInteger>(this, _MPS_PRIVATE_SEL(rows));
}

_MTL_INLINE Matrix* Matrix::alloc() {
  return NS::Object::alloc<Matrix>(_MPS_PRIVATE_CLS(MPSMatrix));
}

_MTL_INLINE Matrix* Matrix::init(
    MTL::Buffer* buffer,
    MatrixDescriptor* descriptor) {
  return Object::sendMessage<Matrix*>(
      this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
}

_MTL_INLINE Matrix* Matrix::init(
    const MTL::Buffer* buffer,
    MatrixDescriptor* descriptor) {
  return init(const_cast<MTL::Buffer*>(buffer), descriptor);
}

_MTL_INLINE NS::String* Kernel::label() const {
  return Object::sendMessage<NS::String*>(this, _MPS_PRIVATE_SEL(label));
}

_MTL_INLINE MTL::Device* Kernel::device() const {
  return Object::sendMessage<MTL::Device*>(this, _MPS_PRIVATE_SEL(device));
}

_MTL_INLINE MatrixMultiplication* MatrixMultiplication::alloc() {
  return NS::Object::alloc<MatrixMultiplication>(
      _MPS_PRIVATE_CLS(MPSMatrixMultiplication));
}

_MTL_INLINE MatrixMultiplication* MatrixMultiplication::init(
    MTL::Device* device,
    bool transposeLeft,
    bool transposeRight,
    NS::UInteger resultRows,
    NS::UInteger resultColumns,
    NS::UInteger interiorColumns,
    double alpha,
    double beta) {
  return Object::sendMessage<MatrixMultiplication*>(
      this,
      _MPS_PRIVATE_SEL(initWithDevice_),
      device,
      transposeLeft,
      transposeRight,
      resultRows,
      resultColumns,
      interiorColumns,
      alpha,
      beta);
}

_MTL_INLINE void MatrixMultiplication::encodeToCommandBuffer(
    MTL::CommandBuffer* commandBuffer,
    Matrix* leftMatrix,
    Matrix* rightMatrix,
    Matrix* resultMatrix) {
  return Object::sendMessage<void>(
      this,
      _MPS_PRIVATE_SEL(
          encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix),
      commandBuffer,
      leftMatrix,
      rightMatrix,
      resultMatrix);
}

_MTL_INLINE void MatrixMultiplication::setLeftMatrixOrigin(MTL::Origin origin) {
  Object::sendMessage<void>(
      this, _MPS_PRIVATE_SEL(setLeftMatrixOrigin_), origin);
}

_MTL_INLINE void MatrixMultiplication::setRightMatrixOrigin(
    MTL::Origin origin) {
  Object::sendMessage<void>(
      this, _MPS_PRIVATE_SEL(setRightMatrixOrigin_), origin);
}

_MTL_INLINE void MatrixMultiplication::setResultMatrixOrigin(
    MTL::Origin origin) {
  Object::sendMessage<void>(
      this, _MPS_PRIVATE_SEL(setResultMatrixOrigin_), origin);
}

_MTL_INLINE void MatrixMultiplication::setBatchStart(NS::UInteger batchStart) {
  Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchStart_), batchStart);
}

_MTL_INLINE void MatrixMultiplication::setBatchSize(NS::UInteger batchSize) {
  Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchSize_), batchSize);
}

_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
    NS::UInteger length,
    NS::UInteger dataType) {
  return Object::sendMessage<VectorDescriptor*>(
      _MPS_PRIVATE_CLS(MPSVectorDescriptor),
      _MPS_PRIVATE_SEL(vectorDescriptorWithLength_dataType),
      length,
      dataType);
}

_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
    NS::UInteger length,
    NS::UInteger vectors,
    NS::UInteger vectorBytes,
    NS::UInteger dataType) {
  return Object::sendMessage<VectorDescriptor*>(
      _MPS_PRIVATE_CLS(MPSVectorDescriptor),
      _MPS_PRIVATE_SEL(vectorDescriptorWithLength_vectors_vectorBytes_dataType),
      length,
      vectors,
      vectorBytes,
      dataType);
}

_MTL_INLINE Vector* Vector::alloc() {
  return NS::Object::alloc<Vector>(_MPS_PRIVATE_CLS(MPSVector));
}

_MTL_INLINE Vector* Vector::init(
    MTL::Buffer* buffer,
    VectorDescriptor* descriptor) {
  return Object::sendMessage<Vector*>(
      this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
}

_MTL_INLINE Vector* Vector::init(
    const MTL::Buffer* buffer,
    VectorDescriptor* descriptor) {
  return init(const_cast<MTL::Buffer*>(buffer), descriptor);
}

_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::alloc() {
  return NS::Object::alloc<MatrixVectorMultiplication>(
      _MPS_PRIVATE_CLS(MPSMatrixVectorMultiplication));
}

_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::init(
    MTL::Device* device,
    bool transpose,
    NS::UInteger rows,
    NS::UInteger columns,
    double alpha,
    double beta) {
  return Object::sendMessage<MatrixVectorMultiplication*>(
      this,
      _MPS_PRIVATE_SEL(initWithDevice_transpose_rows_columns_alpha_beta),
      device,
      transpose,
      rows,
      columns,
      alpha,
      beta);
}

_MTL_INLINE void MatrixVectorMultiplication::encodeToCommandBuffer(
    MTL::CommandBuffer* commandBuffer,
    Matrix* inputMatrix,
    Vector* inputVector,
    Vector* resultVector) {
  return Object::sendMessage<void>(
      this,
      _MPS_PRIVATE_SEL(
          encodeToCommandBuffer_inputMatrix_inputVector_resultVector),
      commandBuffer,
      inputMatrix,
      inputVector,
      resultVector);
}

} // namespace MPS
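To make the wrapper surface above concrete, here is an illustrative sketch of encoding a single float32 GEMM with these classes. It is not part of this commit; `dev`, `cbuf`, `a_buf`, `b_buf`, `out_buf`, and the dimensions `M`, `N`, `K` are assumed to exist, with row-major packing (rowBytes = columns * sizeof(float)).

// Build descriptors for the two operands and the result.
auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
    M, K, K * sizeof(float), MPS::DataTypeFloat32);
auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
    K, N, N * sizeof(float), MPS::DataTypeFloat32);
auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
    M, N, N * sizeof(float), MPS::DataTypeFloat32);

// Wrap the Metal buffers as MPS matrices.
auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);

// Configure out = 1.0 * (a @ b) + 0.0 * out and encode it on the command buffer.
auto gemm = MPS::MatrixMultiplication::alloc()->init(
    dev,
    /* transposeLeft = */ false,
    /* transposeRight = */ false,
    M,
    N,
    K,
    /* alpha = */ 1.0,
    /* beta = */ 0.0);
gemm->encodeToCommandBuffer(cbuf, a_mat, b_mat, out_mat);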
604
mlx/backend/metal/primitives.cpp
Normal file
@@ -0,0 +1,604 @@
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>

#include "mlx/backend/common/binary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

namespace {

static constexpr int METAL_MAX_INDEX_ARRAYS = 10;

void binary_op(
    const std::vector<array>& inputs,
    array& out,
    const std::string op) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);

  // Try to collapse contiguous dims
  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
  auto& strides_a = strides[0];
  auto& strides_b = strides[1];
  auto& strides_out = strides[2];

  std::ostringstream kname;
  switch (bopt) {
    case ScalarScalar:
      kname << "ss";
      break;
    case ScalarVector:
      kname << "sv";
      break;
    case VectorScalar:
      kname << "vs";
      break;
    case VectorVector:
      kname << "vv";
      break;
    case General:
      kname << "g";
      break;
  }
  kname << op << type_to_name(a);
  if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) {
    kname << "_" << shape.size();
  }

  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);
  auto kernel = d.get_kernel(kname.str());
  auto compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, a, 0);
  set_array_buffer(compute_encoder, b, 1);
  set_array_buffer(compute_encoder, out, 2);

  if (bopt == General) {
    auto ndim = shape.size();
    if (ndim > 3) {
      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
    } else {
      // The shape is implicit in the grid for <= 3D
      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
    }

    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
      compute_encoder->setBytes(&ndim, sizeof(int), 6);
    }

    // Launch up to 3D grid of threads
    int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
    int rest = out.size() / (dim0 * dim1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size != 1024) {
      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
    }
    auto group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  } else {
    // Launch a 1D grid of threads
    size_t nthreads = bopt == General ? out.size() : out.data_size();
    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size > nthreads) {
      thread_group_size = nthreads;
    }
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
}

void unary_op(
    const std::vector<array>& inputs,
    array& out,
    const std::string op) {
  auto& in = inputs[0];
  bool contig = in.flags().contiguous;
  if (contig) {
    out.set_data(
        allocator::malloc_or_wait(in.data_size() * out.itemsize()),
        in.data_size(),
        in.strides(),
        in.flags());
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }

  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);
  std::string tname = type_to_name(in);
  std::string opt_name = contig ? "v" : "g";
  auto kernel = d.get_kernel(opt_name + op + tname);

  size_t nthreads = contig ? in.data_size() : in.size();
  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  if (thread_group_size > nthreads) {
    thread_group_size = nthreads;
  }
  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);

  auto compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, in, 0);
  set_array_buffer(compute_encoder, out, 1);
  if (!contig) {
    compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
    compute_encoder->setBytes(
        in.strides().data(), in.ndim() * sizeof(size_t), 3);
    int ndim = in.ndim();
    compute_encoder->setBytes(&ndim, sizeof(int), 4);
  }
  compute_encoder->dispatchThreads(grid_dims, group_dims);
}

} // namespace

void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "abs");
}

void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "add");
}

template <typename T>
void arange_set_scalars(T start, T next, MTL::ComputeCommandEncoder* enc) {
  enc->setBytes(&start, sizeof(T), 0);
  T step = next - start;
  enc->setBytes(&step, sizeof(T), 1);
}

void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 0);
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  auto& s = stream();
  auto& d = metal::device(s.device);
  auto kernel = d.get_kernel("arange" + type_to_name(out));
  size_t nthreads = out.size();
  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
  MTL::Size group_dims = MTL::Size(
      std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
  auto compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);

  switch (out.dtype()) {
    case bool_: // unsupported
      throw std::runtime_error("[Arange::eval_gpu] Does not support bool");
    case uint8:
      arange_set_scalars<uint8_t>(start_, start_ + step_, compute_encoder);
      break;
    case uint16:
      arange_set_scalars<uint16_t>(start_, start_ + step_, compute_encoder);
      break;
    case uint32:
      arange_set_scalars<uint32_t>(start_, start_ + step_, compute_encoder);
      break;
    case uint64:
      arange_set_scalars<uint64_t>(start_, start_ + step_, compute_encoder);
      break;
    case int8:
      arange_set_scalars<int8_t>(start_, start_ + step_, compute_encoder);
      break;
    case int16:
      arange_set_scalars<int16_t>(start_, start_ + step_, compute_encoder);
      break;
    case int32:
      arange_set_scalars<int32_t>(start_, start_ + step_, compute_encoder);
      break;
    case int64:
      arange_set_scalars<int64_t>(start_, start_ + step_, compute_encoder);
      break;
    case float16:
      arange_set_scalars<float16_t>(start_, start_ + step_, compute_encoder);
      break;
    case float32:
      arange_set_scalars<float>(start_, start_ + step_, compute_encoder);
      break;
    case bfloat16:
      throw std::runtime_error("[Arange::eval_gpu] Does not support bfloat16");
    case complex64:
      throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
  }

  set_array_buffer(compute_encoder, out, 2);
  compute_encoder->dispatchThreads(grid_dims, group_dims);
}

void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "arccos");
}

void ArcCosh::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "arccosh");
}

void ArcSin::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "arcsin");
}

void ArcSinh::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "arcsinh");
}

void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "arctan");
}

void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "arctanh");
}

void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  auto& s = stream();
  auto& d = metal::device(s.device);
  std::string op_name;
  switch (reduce_type_) {
    case ArgReduce::ArgMin:
      op_name = "argmin_";
      break;
    case ArgReduce::ArgMax:
      op_name = "argmax_";
      break;
  }

  // Prepare the shapes, strides and axis arguments.
  std::vector<size_t> in_strides = in.strides();
  std::vector<int> shape = in.shape();
  std::vector<size_t> out_strides = out.strides();
  size_t axis_stride = in_strides[axis_];
  size_t axis_size = shape[axis_];
  if (out_strides.size() == in_strides.size()) {
    out_strides.erase(out_strides.begin() + axis_);
  }
  in_strides.erase(in_strides.begin() + axis_);
  shape.erase(shape.begin() + axis_);
  size_t ndim = shape.size();

  // ArgReduce
  int simd_size = 32;
  int n_reads = 4;
  auto compute_encoder = d.get_command_encoder(s.index);
  {
    auto kernel = d.get_kernel(op_name + type_to_name(in));
    NS::UInteger thread_group_size = std::min(
        (axis_size + n_reads - 1) / n_reads,
        kernel->maxTotalThreadsPerThreadgroup());
    // round up to the closest number divisible by simd_size
    thread_group_size =
        (thread_group_size + simd_size - 1) / simd_size * simd_size;
    assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());

    size_t n_threads = out.size() * thread_group_size;
    MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    compute_encoder->setComputePipelineState(kernel);
    set_array_buffer(compute_encoder, in, 0);
    set_array_buffer(compute_encoder, out, 1);
    compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
    compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3);
    compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
    compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
    compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
    compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
    compute_encoder->setThreadgroupMemoryLength(
        simd_size * (sizeof(uint32_t) + in.itemsize()), 0);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
}

void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
  CopyType ctype =
      inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
  copy_gpu(inputs[0], out, ctype);
}

void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}

void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}

void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
  std::vector<int> sizes;
  sizes.push_back(0);
  for (auto& p : inputs) {
    sizes.push_back(p.shape(axis_));
  }
  std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto strides = out.strides();
  auto flags = out.flags();
  flags.row_contiguous = false;
  flags.col_contiguous = false;
  flags.contiguous = false;
  for (int i = 0; i < inputs.size(); i++) {
    array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
    size_t data_offset = strides[axis_] * sizes[i];
    out_slice.copy_shared_buffer(
        out, strides, flags, out_slice.size(), data_offset);
    copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
  }
}

void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}

void Cos::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "cos");
}

void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "cosh");
}

void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "div");
}

void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
}

void Erf::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "erf");
}

void ErfInv::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "erfinv");
}

void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "exp");
}

void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto in = inputs[0];
  CopyType ctype;
  if (in.data_size() == 1) {
    ctype = CopyType::Scalar;
  } else if (in.flags().contiguous) {
    ctype = CopyType::Vector;
  } else {
    ctype = CopyType::General;
  }
  copy_gpu(in, out, ctype);
}

void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "ge");
}

void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "geq");
}

void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "le");
}

void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "leq");
}

void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}

void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
  switch (base_) {
    case Base::e:
      unary_op(inputs, out, "log");
      break;
    case Base::two:
      unary_op(inputs, out, "log2");
      break;
    case Base::ten:
      unary_op(inputs, out, "log10");
      break;
  }
}

void Log1p::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "log1p");
}

void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "lnot");
}

void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "lae");
}

void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "max");
}

void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "min");
}

void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "mul");
}

void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "neg");
}

void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "neq");
}

void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Inputs must be base input array and scalar val array
  assert(inputs.size() == 2);
  auto& in = inputs[0];
  auto& val = inputs[1];

  // Padding value must be a scalar
  assert(val.size() == 1);

  // Padding value, input and output must be of the same type
  assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());

  // Fill output with val
  copy_gpu(val, out, CopyType::Scalar, stream());

  // Find offset for start of input values
  size_t data_offset = 0;
  for (int i = 0; i < axes_.size(); i++) {
    auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
    data_offset += out.strides()[ax] * low_pad_size_[i];
  }

  // Extract slice from output where input will be pasted
  array out_slice(in.shape(), out.dtype(), nullptr, {});
  out_slice.copy_shared_buffer(
      out, out.strides(), out.flags(), out_slice.size(), data_offset);

  // Copy input values into the slice
  copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
}

void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "pow");
}

void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  // keys has shape (N1, ..., NK, 2)
  // out has shape (N1, ..., NK, M1, M2, ...)
  auto& keys = inputs[0];
  size_t num_keys = keys.size() / 2;

  size_t elems_per_key = out.size() / num_keys;
  size_t bytes_per_key = out.itemsize() * elems_per_key;
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
  size_t half_size = out_per_key / 2;
  bool odd = out_per_key % 2;

  auto& s = stream();
  auto& d = metal::device(s.device);
  std::string kname = keys.flags().row_contiguous ? "rbitsc" : "rbits";
  auto kernel = d.get_kernel(kname);

  // organize into grid nkeys x elem_per_key
  MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  auto nthreads = std::min(num_keys * (half_size + odd), thread_group_size);
  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
  auto compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, keys, 0);
  set_array_buffer(compute_encoder, out, 1);
  compute_encoder->setBytes(&odd, sizeof(bool), 2);
  compute_encoder->setBytes(&bytes_per_key, sizeof(size_t), 3);

  if (!keys.flags().row_contiguous) {
    int ndim = keys.ndim();
    compute_encoder->setBytes(&ndim, sizeof(int), 4);
    compute_encoder->setBytes(
        keys.shape().data(), keys.ndim() * sizeof(int), 5);
    compute_encoder->setBytes(
        keys.strides().data(), keys.ndim() * sizeof(size_t), 6);
  }

  compute_encoder->dispatchThreads(grid_dims, group_dims);
}

void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (in.flags().row_contiguous) {
    auto flags = in.flags();
    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
    out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
  } else {
    copy_gpu(in, out, CopyType::General);
  }
}

void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "sigmoid");
}

void Sign::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "sign");
}

void Sin::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "sin");
}

void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "sinh");
}

void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "square");
}

void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
  if (recip_) {
    unary_op(inputs, out, "rsqrt");
  } else {
    unary_op(inputs, out, "sqrt");
  }
}

void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}

void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}

void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "sub");
}

void Tan::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "tan");
}

void Tanh::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "tanh");
}

void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}

} // namespace mlx::core
130
mlx/backend/metal/scan.cpp
Normal file
@@ -0,0 +1,130 @@
#include <cassert>
#include <sstream>

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& s = stream();
  auto& d = metal::device(s.device);

  // Ensure contiguity
  std::vector<array> copies;
  auto in = inputs[0];
  if (!in.flags().row_contiguous) {
    array arr_copy(in.shape(), in.dtype(), nullptr, {});
    copy_gpu(in, arr_copy, CopyType::General, s);
    copies.push_back(arr_copy);
    in = arr_copy;
  }

  std::ostringstream kname;
  if (in.strides()[axis_] == 1) {
    kname << "contiguous_scan_";
    if (reverse_) {
      kname << "reverse_";
    }
    kname << ((inclusive_) ? "inclusive_" : "exclusive_");
    switch (reduce_type_) {
      case Scan::Sum:
        kname << "sum_";
        break;
      case Scan::Prod:
        kname << "prod_";
        break;
      case Scan::Max:
        kname << "max_";
        break;
      case Scan::Min:
        kname << "min_";
        break;
    }
    kname << type_to_name(in) << "_" << type_to_name(out);

    auto kernel = d.get_kernel(kname.str());
    auto compute_encoder = d.get_command_encoder(s.index);
    compute_encoder->setComputePipelineState(kernel);
    set_array_buffer(compute_encoder, in, 0);
    set_array_buffer(compute_encoder, out, 1);
    size_t size = in.shape(axis_);
    compute_encoder->setBytes(&size, sizeof(size_t), 2);

    // Compute the thread grid
    int n_reads = (in.itemsize() <= 4) ? 4 : 2;
    int elements_per_simd = n_reads * 32;
    int thread_groups = in.size() / size;
    int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (size < n_reads * 1024) {
      thread_group_size = ((size + elements_per_simd - 1) / elements_per_simd) *
          elements_per_simd;
    } else if (size < n_reads * 2048) {
      thread_group_size =
          ((size / 2 + elements_per_simd - 1) / elements_per_simd) *
          elements_per_simd;
    }
    thread_group_size = std::min(
        thread_group_size,
        static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));
    MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  } else {
    kname << "strided_scan_";
    if (reverse_) {
      kname << "reverse_";
    }
    kname << ((inclusive_) ? "inclusive_" : "exclusive_");
    switch (reduce_type_) {
      case Scan::Sum:
        kname << "sum_";
        break;
      case Scan::Prod:
        kname << "prod_";
        break;
      case Scan::Max:
        kname << "max_";
        break;
      case Scan::Min:
        kname << "min_";
        break;
    }
    kname << type_to_name(in) << "_" << type_to_name(out);

    auto kernel = d.get_kernel(kname.str());
    auto compute_encoder = d.get_command_encoder(s.index);
    compute_encoder->setComputePipelineState(kernel);
    set_array_buffer(compute_encoder, in, 0);
    set_array_buffer(compute_encoder, out, 1);
    size_t size = in.shape(axis_);
    size_t stride = in.strides()[axis_];
    compute_encoder->setBytes(&size, sizeof(size_t), 2);
    compute_encoder->setBytes(&stride, sizeof(size_t), 3);

    // Compute the thread grid
    int n_reads = (in.itemsize() <= 4) ? 4 : 2;
    int tile_x = 32;
    int tile_y = 32;
    int elements_per_tile_x = tile_x * n_reads;
    int grid_y = in.size() / size / stride;
    int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
    MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
    MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }

  if (copies.size() > 0) {
    auto command_buffer = d.get_command_buffer(s.index);
    command_buffer->addCompletedHandler(
        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
  }
}

} // namespace mlx::core
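As an illustration of the kernel-name and threadgroup arithmetic above (hypothetical input, not part of this commit): an inclusive cumulative sum along the contiguous last axis of a float32 array selects the kernel "contiguous_scan_inclusive_sum_float32_float32", and for a scanned axis of length 1000 the rounding proceeds as in the sketch below.

// Mirrors the contiguous branch of Scan::eval_gpu for size = 1000, float32.
int n_reads = 4;                       // itemsize (4 bytes) <= 4
int elements_per_simd = n_reads * 32;  // 128
int size = 1000;                       // size < n_reads * 1024
int thread_group_size =
    ((size + elements_per_simd - 1) / elements_per_simd) * elements_per_simd;
// thread_group_size == 1024; it is then clamped to the kernel's
// maxTotalThreadsPerThreadgroup before the dispatch.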
167
mlx/backend/metal/utils.h
Normal file
@@ -0,0 +1,167 @@
#pragma once

#include "mlx/array.h"
#include "mlx/backend/metal/device.h"

namespace mlx::core {

namespace {

void set_array_buffer(
    MTL::ComputeCommandEncoder* compute_encoder,
    MTL::ArgumentEncoder* enc,
    const array& a,
    int idx) {
  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
  auto offset = a.data<char>() -
      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
  enc->setBuffer(a_buf, offset, idx);
  // MTL::Resource usage through argument buffer needs to be explicitly
  // flagged to enable hazard tracking
  compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
}

void set_array_buffer(
    MTL::ComputeCommandEncoder* enc,
    const array& a,
    int idx) {
  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
  auto offset = a.data<char>() -
      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
  enc->setBuffer(a_buf, offset, idx);
}

std::string type_to_name(const array& a) {
  std::string tname;
  switch (a.dtype()) {
    case bool_:
      tname = "bool_";
      break;
    case uint8:
      tname = "uint8";
      break;
    case uint16:
      tname = "uint16";
      break;
    case uint32:
      tname = "uint32";
      break;
    case uint64:
      tname = "uint64";
      break;
    case int8:
      tname = "int8";
      break;
    case int16:
      tname = "int16";
      break;
    case int32:
      tname = "int32";
      break;
    case int64:
      tname = "int64";
      break;
    case float16:
      tname = "float16";
      break;
    case float32:
      tname = "float32";
      break;
    case bfloat16:
      tname = "bfloat16";
      break;
    case complex64:
      tname = "complex64";
      break;
  }
  return tname;
}

MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
  int pows[3] = {0, 0, 0};
  int sum = 0;
  while (true) {
    int presum = sum;
    // Check all the pows
    if (dim0 >= (1 << (pows[0] + 1))) {
      pows[0]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim1 >= (1 << (pows[1] + 1))) {
      pows[1]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim2 >= (1 << (pows[2] + 1))) {
      pows[2]++;
      sum++;
    }
    if (sum == presum || sum == 10) {
      break;
    }
  }
  return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}

// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
  // Make a vector that has axes separated with -1. Collapse all axes between
  // -1.
  std::vector<int> to_collapse;
  if (xs[0].ndim() > 0) {
    to_collapse.push_back(0);
    for (int i = 1; i < xs[0].ndim(); i++) {
      bool contiguous = true;
      for (auto& x : xs) {
        if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) {
          contiguous = false;
        }
        if (!contiguous) {
          break;
        }
      }
      if (!contiguous) {
        to_collapse.push_back(-1);
      }
      to_collapse.push_back(i);
    }
    to_collapse.push_back(-1);
  }

  std::vector<int> out_shape;
  std::vector<std::vector<size_t>> out_strides(xs.size());
  for (int i = 0; i < to_collapse.size(); i++) {
    int current_shape = xs[0].shape()[to_collapse[i]];
    while (to_collapse[++i] != -1) {
      current_shape *= xs[0].shape()[to_collapse[i]];
    }
    out_shape.push_back(current_shape);
    for (int j = 0; j < xs.size(); j++) {
      out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]);
    }
  }

  return std::make_tuple(out_shape, out_strides);
}

template <typename... Arrays>
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(Arrays... xs) {
  return collapse_contiguous_dims(
      std::vector<array>{std::forward<Arrays>(xs)...});
}

} // namespace

} // namespace mlx::core
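A small usage sketch of collapse_contiguous_dims, reproducing the example in the comment above (illustrative only; `x` is assumed to be the transposed array described there, with shape {2, 2, 2} and strides {1, 4, 2}):

auto [shape, strides] = collapse_contiguous_dims(x);
// The last two axes are contiguous with respect to each other, so they merge:
// shape == {2, 4} and strides[0] == {1, 2}, matching the comment above.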