mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
angelos's commit files
This commit is contained in:
200
mlx/backend/metal/allocator.cpp
Normal file
200
mlx/backend/metal/allocator.cpp
Normal file
@@ -0,0 +1,200 @@
|
||||
#include "mlx/backend/metal/allocator.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
|
||||
#include <mach/vm_page_size.h>
|
||||
#include <unistd.h>
|
||||
#include <cstdlib>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace allocator {
|
||||
|
||||
// Returns the allocator used by mlx::core::allocator; in this translation
// unit that is the Metal allocator singleton.
Allocator& allocator() {
  Allocator& metal_allocator = metal::allocator();
  return metal_allocator;
}
|
||||
|
||||
void* Buffer::raw_ptr() {
|
||||
return static_cast<MTL::Buffer*>(ptr_)->contents();
|
||||
}
|
||||
|
||||
} // namespace allocator
|
||||
|
||||
namespace metal {
|
||||
|
||||
namespace {
|
||||
|
||||
// Builds an empty cache for `device`. Garbage collection becomes eligible
// once the device's allocated size exceeds 95% of its recommended maximum
// working set size.
BufferCache::BufferCache(MTL::Device* device)
    : device_(device),
      head_(nullptr),
      tail_(nullptr),
      pool_size_(0),
      gc_limit_(0.95 * device->recommendedMaxWorkingSetSize()) {}
|
||||
|
||||
// Releases every buffer still held by the cache.
BufferCache::~BufferCache() {
  this->clear();
}
|
||||
|
||||
void BufferCache::clear() {
|
||||
std::lock_guard<std::mutex> lk(cache_mutex_);
|
||||
for (auto& [size, holder] : buffer_pool_) {
|
||||
if (holder->buf)
|
||||
holder->buf->release();
|
||||
delete holder;
|
||||
}
|
||||
buffer_pool_.clear();
|
||||
pool_size_ = 0;
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
}
|
||||
|
||||
// Looks for a cached buffer of at least `size` bytes and hands it out.
//
// Only pool entries whose key is in [size, 2 * size) are considered, so a
// reused buffer is always more than half used. Every entry visited is removed
// from the pool; entries whose buffer pointer is null are simply discarded
// and the search continues. Returns nullptr when nothing suitable is cached.
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
  std::lock_guard<std::mutex> guard(cache_mutex_);

  MTL::Buffer* found = nullptr;
  auto candidate = buffer_pool_.lower_bound(size);

  for (; found == nullptr && candidate != buffer_pool_.end() &&
       candidate->first < 2 * size;) {
    BufferHolder* holder = candidate->second;
    found = holder->buf;

    // Detach from the recency list, then drop the pool entry.
    remove_from_list(holder);
    delete holder;
    candidate = buffer_pool_.erase(candidate);
  }

  if (found == nullptr) {
    return nullptr;
  }

  pool_size_ -= found->length();
  return found;
}
|
||||
|
||||
// Puts `buf` into the cache: a new holder is linked at the head of the
// recency list and indexed in the pool by the buffer's length. A null `buf`
// is ignored.
void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
  std::lock_guard<std::mutex> guard(cache_mutex_);

  if (buf == nullptr) {
    return;
  }

  auto* holder = new BufferHolder(buf);
  add_at_head(holder);
  pool_size_ += buf->length();
  buffer_pool_.insert({buf->length(), holder});
}
|
||||
|
||||
size_t BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||
min_bytes_to_free += device_->currentAllocatedSize() - gc_limit_;
|
||||
|
||||
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
||||
size_t old_pool_size = pool_size_;
|
||||
clear();
|
||||
return old_pool_size;
|
||||
} else {
|
||||
std::lock_guard<std::mutex> lk(cache_mutex_);
|
||||
size_t total_bytes_freed = 0;
|
||||
|
||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||
if (tail_->buf) {
|
||||
total_bytes_freed += tail_->buf->length();
|
||||
tail_->buf->release();
|
||||
tail_->buf = nullptr;
|
||||
}
|
||||
remove_from_list(tail_);
|
||||
}
|
||||
|
||||
pool_size_ -= total_bytes_freed;
|
||||
return total_bytes_freed;
|
||||
}
|
||||
}
|
||||
|
||||
// Links `to_add` in front of the current head of the recency list. A null
// holder is ignored. When the list is empty the holder becomes both head and
// tail and its own links are left untouched.
void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
  if (to_add == nullptr) {
    return;
  }

  if (head_ != nullptr) {
    head_->prev = to_add;
    to_add->next = head_;
    head_ = to_add;
    return;
  }

  head_ = to_add;
  tail_ = to_add;
}
|
||||
|
||||
// Unlinks `to_remove` from the recency list and clears its links. A null
// holder is ignored. A holder that matches none of the four positions below
// (for example one that was already unlinked) only has its links reset.
void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
  if (to_remove == nullptr) {
    return;
  }

  BufferHolder* const before = to_remove->prev;
  BufferHolder* const after = to_remove->next;
  const bool is_head = (to_remove == head_);
  const bool is_tail = (to_remove == tail_);

  if (before && after) {
    // Interior node: bridge the neighbours.
    before->next = after;
    after->prev = before;
  } else if (before && is_tail) {
    // Last node: the previous one becomes the tail.
    tail_ = before;
    tail_->next = nullptr;
  } else if (is_head && after) {
    // First node: the next one becomes the head.
    head_ = after;
    head_->prev = nullptr;
  } else if (is_head && is_tail) {
    // Sole node: the list becomes empty.
    head_ = nullptr;
    tail_ = nullptr;
  }

  to_remove->prev = nullptr;
  to_remove->next = nullptr;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Binds the allocator to the GPU device's MTL::Device, creates the buffer
// cache on it, and sets the hard allocation limit to 1.5x the device's
// recommended maximum working set size.
MetalAllocator::MetalAllocator()
    : device_(device(mlx::core::Device::gpu).mtl_device()),
      buffer_cache_(device_),
      peak_allocated_size_(0),
      block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()) {}
|
||||
|
||||
// Allocates a buffer of at least `size` bytes.
//
// Requests larger than one VM page are rounded up to a whole number of
// pages. A cached buffer is reused when available; otherwise a new shared,
// hazard-tracked buffer is created, after trimming the cache if it is
// eligible for garbage collection. Returns Buffer{nullptr} when nothing can
// be reused and the device's allocated size has reached block_limit_.
Buffer MetalAllocator::malloc(size_t size) {
  if (size > vm_page_size) {
    const size_t n_pages = (size + vm_page_size - 1) / vm_page_size;
    size = vm_page_size * n_pages;
  }

  MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);

  if (buf == nullptr) {
    // Under very high memory pressure: refuse to allocate further.
    if (device_->currentAllocatedSize() >= block_limit_) {
      return Buffer{nullptr};
    }

    // Under moderate memory pressure: give cached memory back first.
    if (buffer_cache_.can_garbage_collect()) {
      buffer_cache_.release_cached_buffers(size);
    }

    size_t res_opt = MTL::ResourceStorageModeShared;
    res_opt |= MTL::ResourceHazardTrackingModeTracked;
    buf = device_->newBuffer(size, res_opt);
  }

  // Track the high-water mark of device memory.
  peak_allocated_size_ =
      std::max(peak_allocated_size_, device_->currentAllocatedSize());

  return Buffer{static_cast<void*>(buf)};
}
|
||||
|
||||
// Returns the buffer to the cache instead of releasing it right away.
void MetalAllocator::free(Buffer buffer) {
  buffer_cache_.recycle_to_cache(static_cast<MTL::Buffer*>(buffer.ptr()));
}
|
||||
|
||||
// Accessor for the single MetalAllocator instance, created on first use.
MetalAllocator& allocator() {
  static MetalAllocator instance;
  return instance;
}
|
||||
|
||||
} // namespace metal
|
||||
|
||||
} // namespace mlx::core
|
76
mlx/backend/metal/allocator.h
Normal file
76
mlx/backend/metal/allocator.h
Normal file
@@ -0,0 +1,76 @@
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
using allocator::Buffer;
|
||||
|
||||
namespace {
|
||||
|
||||
class BufferCache {
|
||||
public:
|
||||
BufferCache(MTL::Device* device);
|
||||
~BufferCache();
|
||||
void clear();
|
||||
|
||||
MTL::Buffer* reuse_from_cache(size_t size);
|
||||
void recycle_to_cache(MTL::Buffer* buf);
|
||||
size_t release_cached_buffers(size_t min_bytes_to_free);
|
||||
|
||||
bool can_garbage_collect() {
|
||||
return pool_size_ > 0 && device_->currentAllocatedSize() > gc_limit_;
|
||||
}
|
||||
|
||||
private:
|
||||
struct BufferHolder {
|
||||
public:
|
||||
BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
|
||||
|
||||
BufferHolder* prev;
|
||||
BufferHolder* next;
|
||||
MTL::Buffer* buf;
|
||||
};
|
||||
|
||||
void add_at_head(BufferHolder* to_add);
|
||||
void remove_from_list(BufferHolder* to_remove);
|
||||
|
||||
MTL::Device* device_;
|
||||
std::mutex cache_mutex_;
|
||||
|
||||
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
||||
BufferHolder* head_;
|
||||
BufferHolder* tail_;
|
||||
size_t pool_size_;
|
||||
size_t gc_limit_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
class MetalAllocator : public allocator::Allocator {
|
||||
/** Allocator for Metal GPUs. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
|
||||
private:
|
||||
MTL::Device* device_;
|
||||
MetalAllocator();
|
||||
friend MetalAllocator& allocator();
|
||||
|
||||
// Caching allocator
|
||||
BufferCache buffer_cache_;
|
||||
|
||||
// Allocation stats
|
||||
size_t peak_allocated_size_;
|
||||
size_t block_limit_;
|
||||
};
|
||||
|
||||
MetalAllocator& allocator();
|
||||
|
||||
} // namespace mlx::core::metal
|
555
mlx/backend/metal/conv.cpp
Normal file
555
mlx/backend/metal/conv.cpp
Normal file
@@ -0,0 +1,555 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/conv_params.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/matmul.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void explicit_gemm_conv_1D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<1>& conv_params) {
|
||||
// Pad input
|
||||
std::vector<int> padded_shape = {
|
||||
conv_params.N, conv_params.iS[0] + 2 * conv_params.pad[0], conv_params.C};
|
||||
array in_padded(padded_shape, in.dtype(), nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1];
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
in_padded.strides(),
|
||||
in_padded.flags(),
|
||||
in_padded_slice.size(),
|
||||
data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
|
||||
|
||||
// Make strided view
|
||||
std::vector<int> strided_shape = {
|
||||
conv_params.N, conv_params.oS[0], conv_params.wS[0], conv_params.C};
|
||||
|
||||
std::vector<size_t> strided_strides = {
|
||||
in_padded.strides()[0],
|
||||
in_padded.strides()[1] * conv_params.str[0],
|
||||
in_padded.strides()[1],
|
||||
in_padded.strides()[2]};
|
||||
auto flags = in_padded.flags();
|
||||
|
||||
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
|
||||
in_strided_view.copy_shared_buffer(
|
||||
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||
|
||||
// Materialize strided view
|
||||
std::vector<int> strided_reshape = {
|
||||
conv_params.N * conv_params.oS[0], conv_params.wS[0] * conv_params.C};
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
||||
|
||||
// Peform gemm
|
||||
std::vector<array> copies = {in_padded, in_strided};
|
||||
mlx_matmul(
|
||||
s,
|
||||
d,
|
||||
/*a = */ in_strided,
|
||||
/*b = */ wt,
|
||||
/*c = */ out,
|
||||
/*M = */ strided_reshape[0],
|
||||
/*N = */ conv_params.O,
|
||||
/*K = */ strided_reshape[1],
|
||||
/*batch_size_out = */ 1,
|
||||
/*a_cols = */ strided_reshape[1],
|
||||
/*b_cols = */ strided_reshape[1],
|
||||
/*a_transposed = */ false,
|
||||
/*b_transposed = */ true,
|
||||
/*copies = */ copies);
|
||||
}
|
||||
|
||||
void conv_1D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
// Make conv params
|
||||
MLXConvParams<1> conv_params{
|
||||
/* const int N = */ in.shape(0),
|
||||
/* const int C = */ in.shape(2),
|
||||
/* const int O = */ wt.shape(0),
|
||||
/* const int iS[NDIM] = */ {in.shape(1)},
|
||||
/* const int wS[NDIM] = */ {wt.shape(1)},
|
||||
/* const int oS[NDIM] = */ {out.shape(1)},
|
||||
/* const int str[NDIM] = */ {wt_strides[0]},
|
||||
/* const int pad[NDIM] = */ {padding[0]},
|
||||
/* const int dil[NDIM] = */ {wt_dilation[0]},
|
||||
/* const size_t in_strides[NDIM + 2] = */
|
||||
{in.strides()[0], in.strides()[1], in.strides()[2]},
|
||||
/* const size_t wt_strides[NDIM + 2] = */
|
||||
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
|
||||
/* const size_t out_strides[NDIM + 2] = */
|
||||
{out.strides()[0], out.strides()[1], out.strides()[2]},
|
||||
};
|
||||
|
||||
// Direct to explicit gemm conv
|
||||
if (wt_dilation[0] == 1) {
|
||||
explicit_gemm_conv_1D_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
// Direct to fallback conv
|
||||
else {
|
||||
throw std::invalid_argument("[conv_1D_gpu] Dilation needs to be 1.");
|
||||
}
|
||||
}
|
||||
|
||||
void slow_conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params) {
|
||||
int bm = 16, bn = 8;
|
||||
int tm = 4, tn = 4;
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn
|
||||
<< "_tm" << tm << "_tn" << tn;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
|
||||
|
||||
size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm);
|
||||
size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn);
|
||||
size_t grid_dim_z = conv_params.N;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(bm, bn, 1);
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
|
||||
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, wt, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void implicit_gemm_conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params) {
|
||||
int bm = 32, bn = 32, bk = 16;
|
||||
int wm = 2, wn = 2;
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
|
||||
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
|
||||
int implicit_N = conv_params.O;
|
||||
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
|
||||
|
||||
size_t grid_dim_x = (implicit_N + bn - 1) / bn;
|
||||
size_t grid_dim_y = (implicit_M + bm - 1) / bm;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
|
||||
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, wt, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void explicit_gemm_conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params) {
|
||||
// Pad input
|
||||
std::vector<int> padded_shape = {
|
||||
conv_params.N,
|
||||
conv_params.iS[0] + 2 * conv_params.pad[0],
|
||||
conv_params.iS[1] + 2 * conv_params.pad[1],
|
||||
conv_params.C};
|
||||
array in_padded(padded_shape, in.dtype(), nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
|
||||
conv_params.pad[1] * in_padded.strides()[2];
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
in_padded.strides(),
|
||||
in_padded.flags(),
|
||||
in_padded_slice.size(),
|
||||
data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
|
||||
|
||||
// Make strided view
|
||||
std::vector<int> strided_shape = {
|
||||
conv_params.N,
|
||||
conv_params.oS[0],
|
||||
conv_params.oS[1],
|
||||
conv_params.wS[0],
|
||||
conv_params.wS[1],
|
||||
conv_params.C};
|
||||
|
||||
std::vector<size_t> strided_strides = {
|
||||
in_padded.strides()[0],
|
||||
in_padded.strides()[1] * conv_params.str[0],
|
||||
in_padded.strides()[2] * conv_params.str[1],
|
||||
in_padded.strides()[1],
|
||||
in_padded.strides()[2],
|
||||
in_padded.strides()[3]};
|
||||
auto flags = in_padded.flags();
|
||||
|
||||
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
|
||||
in_strided_view.copy_shared_buffer(
|
||||
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||
|
||||
// Materialize strided view
|
||||
std::vector<int> strided_reshape = {
|
||||
conv_params.N * conv_params.oS[0] * conv_params.oS[1],
|
||||
conv_params.wS[0] * conv_params.wS[1] * conv_params.C};
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
||||
|
||||
// Peform gemm
|
||||
std::vector<array> copies = {in_padded, in_strided};
|
||||
mlx_matmul(
|
||||
s,
|
||||
d,
|
||||
/*a = */ in_strided,
|
||||
/*b = */ wt,
|
||||
/*c = */ out,
|
||||
/*M = */ strided_reshape[0],
|
||||
/*N = */ conv_params.O,
|
||||
/*K = */ strided_reshape[1],
|
||||
/*batch_size_out = */ 1,
|
||||
/*a_cols = */ strided_reshape[1],
|
||||
/*b_cols = */ strided_reshape[1],
|
||||
/*a_transposed = */ false,
|
||||
/*b_transposed = */ true,
|
||||
/*copies = */ copies);
|
||||
}
|
||||
|
||||
void winograd_conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params,
|
||||
std::vector<array>& copies_w) {
|
||||
std::vector<int> padded_shape = {
|
||||
conv_params.N,
|
||||
conv_params.iS[0] + 2 * conv_params.pad[0],
|
||||
conv_params.iS[1] + 2 * conv_params.pad[1],
|
||||
conv_params.C};
|
||||
|
||||
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
|
||||
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
|
||||
|
||||
array in_padded(padded_shape, in.dtype(), nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
array zero_arr = array(0, in.dtype());
|
||||
copy_gpu(zero_arr, in_padded, CopyType::Scalar, s);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
|
||||
conv_params.pad[1] * in_padded.strides()[2];
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
in_padded.strides(),
|
||||
in_padded.flags(),
|
||||
in_padded_slice.size(),
|
||||
data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
|
||||
|
||||
copies_w.push_back(in_padded_slice);
|
||||
copies_w.push_back(in_padded);
|
||||
copies_w.push_back(zero_arr);
|
||||
|
||||
MLXConvParams<2> conv_params_updated{
|
||||
/* const int N = */ in_padded.shape(0),
|
||||
/* const int C = */ in_padded.shape(3),
|
||||
/* const int O = */ wt.shape(0),
|
||||
/* const int iS[NDIM] = */ {in_padded.shape(1), in_padded.shape(2)},
|
||||
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
|
||||
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
|
||||
/* const int str[NDIM] = */ {1, 1},
|
||||
/* const int pad[NDIM] = */ {0, 0},
|
||||
/* const int dil[NDIM] = */ {1, 1},
|
||||
/* const size_t in_strides[NDIM + 2] = */
|
||||
{in_padded.strides()[0],
|
||||
in_padded.strides()[1],
|
||||
in_padded.strides()[2],
|
||||
in_padded.strides()[3]},
|
||||
/* const size_t wt_strides[NDIM + 2] = */
|
||||
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
|
||||
/* const size_t out_strides[NDIM + 2] = */
|
||||
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
|
||||
};
|
||||
|
||||
int O_c = conv_params.O;
|
||||
int C_c = conv_params.C;
|
||||
|
||||
int N_tiles_n = conv_params.N;
|
||||
int N_tiles_h = (conv_params.oS[0] + 5) / 6;
|
||||
int N_tiles_w = (conv_params.oS[1] + 5) / 6;
|
||||
int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
|
||||
|
||||
// Do filter transform
|
||||
std::vector<int> filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
|
||||
array filt_wg(filt_wg_shape, wt.dtype(), nullptr, {});
|
||||
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
|
||||
copies_w.push_back(filt_wg);
|
||||
{
|
||||
int bc = 32;
|
||||
int bo = 4;
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
|
||||
<< bc;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, wt, 0);
|
||||
set_array_buffer(compute_encoder, filt_wg, 1);
|
||||
|
||||
compute_encoder->setBytes(&C_c, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&O_c, sizeof(int), 3);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, bo, 1);
|
||||
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Do input transform
|
||||
std::vector<int> inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
|
||||
array inp_wg(inp_wg_shape, in.dtype(), nullptr, {});
|
||||
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
|
||||
copies_w.push_back(inp_wg);
|
||||
{
|
||||
int bc = 32;
|
||||
int wm = 2;
|
||||
int wn = 2;
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
|
||||
<< bc;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, in_padded, 0);
|
||||
set_array_buffer(compute_encoder, inp_wg, 1);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Do batched gemm
|
||||
std::vector<int> out_wg_shape = {8 * 8, N_tiles, conv_params.O};
|
||||
array out_wg(out_wg_shape, in.dtype(), nullptr, {});
|
||||
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
|
||||
copies_w.push_back(out_wg);
|
||||
{
|
||||
std::vector<array> empty_copies;
|
||||
mlx_matmul(
|
||||
s,
|
||||
d,
|
||||
/*a = */ inp_wg,
|
||||
/*b = */ filt_wg,
|
||||
/*c = */ out_wg,
|
||||
/*M = */ N_tiles,
|
||||
/*N = */ conv_params.O,
|
||||
/*K = */ conv_params.C,
|
||||
/*batch_size_out = */ 8 * 8,
|
||||
/*a_cols = */ conv_params.C,
|
||||
/*b_cols = */ conv_params.O,
|
||||
/*a_transposed = */ false,
|
||||
/*b_transposed = */ false,
|
||||
/*copies = */ empty_copies);
|
||||
}
|
||||
|
||||
// Do output transform
|
||||
{
|
||||
int bc = 32;
|
||||
int wm = 2;
|
||||
int wn = 2;
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
|
||||
<< bc;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, out_wg, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
void conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
std::vector<array>& copies) {
|
||||
// Make conv params
|
||||
MLXConvParams<2> conv_params{
|
||||
/* const int N = */ in.shape(0),
|
||||
/* const int C = */ in.shape(3),
|
||||
/* const int O = */ wt.shape(0),
|
||||
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2)},
|
||||
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
|
||||
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
|
||||
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
|
||||
/* const int pad[NDIM] = */ {padding[0], padding[1]},
|
||||
/* const int dil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
|
||||
/* const size_t in_strides[NDIM + 2] = */
|
||||
{in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
|
||||
/* const size_t wt_strides[NDIM + 2] = */
|
||||
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
|
||||
/* const size_t out_strides[NDIM + 2] = */
|
||||
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
|
||||
};
|
||||
|
||||
// Direct to winograd conv
|
||||
if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
|
||||
conv_params.C >= 64 && conv_params.O >= 64 && conv_params.wS[0] == 3 &&
|
||||
conv_params.wS[1] == 3 && conv_params.str[0] == 1 &&
|
||||
conv_params.str[1] == 1 && conv_params.dil[0] == 1 &&
|
||||
conv_params.dil[1] == 1) {
|
||||
winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
|
||||
}
|
||||
|
||||
// Direct to implicit gemm conv
|
||||
else if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0) {
|
||||
implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
// Direct to explicit gemm conv
|
||||
else if (wt_dilation[0] == 1 && wt_dilation[1] == 1) {
|
||||
explicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
// Direct to fallback conv
|
||||
else {
|
||||
slow_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Ensure contiguity
|
||||
std::vector<array> copies;
|
||||
auto in = inputs[0];
|
||||
auto wt = inputs[1];
|
||||
if (!in.flags().row_contiguous) {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
in = arr_copy;
|
||||
}
|
||||
if (!wt.flags().row_contiguous) {
|
||||
array arr_copy(wt.shape(), wt.dtype(), nullptr, {});
|
||||
copy_gpu(wt, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
wt = arr_copy;
|
||||
}
|
||||
|
||||
// 2D conv
|
||||
if (out.ndim() == 4) {
|
||||
conv_2D_gpu(
|
||||
s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_, copies);
|
||||
}
|
||||
// 1D conv
|
||||
else if (out.ndim() == 3) {
|
||||
conv_1D_gpu(s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_);
|
||||
}
|
||||
// Throw error
|
||||
else {
|
||||
throw std::invalid_argument(
|
||||
"[Convolution::eval_gpu] Only supports 1D or 2D convolutions.");
|
||||
}
|
||||
|
||||
// Clear copies
|
||||
if (copies.size() > 0) {
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
command_buffer->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
16
mlx/backend/metal/copy.h
Normal file
16
mlx/backend/metal/copy.h
Normal file
@@ -0,0 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Copies `src` into `out` on stream `s` using the given copy type.
// NOTE(review): callers in conv.cpp pass `out` arrays that have no data yet,
// so this presumably allocates the destination — confirm in the definition.
void copy_gpu(
    const array& src,
    array& out,
    CopyType ctype,
    const Stream& s);

// Same as above without an explicit stream.
// NOTE(review): which stream is used is not visible from this header.
void copy_gpu(const array& src, array& out, CopyType ctype);

// Copies `src` into `out` on stream `s`.
// NOTE(review): callers in conv.cpp pass an `out` that shares an existing
// buffer, so this presumably writes into it without allocating — confirm.
void copy_gpu_inplace(
    const array& src,
    array& out,
    CopyType ctype,
    const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
30
mlx/backend/metal/kernels/arange.metal
Normal file
30
mlx/backend/metal/kernels/arange.metal
Normal file
@@ -0,0 +1,30 @@
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
// Writes start + index * step to out[index]; each thread handles the
// element at its position in the grid.
template <typename T>
[[kernel]] void arange(
    constant const T& start,
    constant const T& step,
    device T* out,
    uint index [[thread_position_in_grid]]) {
  out[index] = start + index * step;
}
|
||||
|
||||
// Explicitly instantiates the arange kernel for `type` and exposes it under
// the host name "arange" followed by `tname`.
#define instantiate_arange(tname, type)                 \
  template [[host_name("arange" #tname)]]               \
  [[kernel]] void arange<type>(                         \
      constant const type& start,                       \
      constant const type& step,                        \
      device type* out,                                 \
      uint index [[thread_position_in_grid]]);

// Unsigned integer types
instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)
instantiate_arange(uint32, uint32_t)
instantiate_arange(uint64, uint64_t)
// Signed integer types
instantiate_arange(int8, int8_t)
instantiate_arange(int16, int16_t)
instantiate_arange(int32, int32_t)
instantiate_arange(int64, int64_t)
// Floating point types
instantiate_arange(float16, half)
instantiate_arange(float32, float)
instantiate_arange(bfloat16, bfloat16_t)
|
208
mlx/backend/metal/kernels/arg_reduce.metal
Normal file
208
mlx/backend/metal/kernels/arg_reduce.metal
Normal file
@@ -0,0 +1,208 @@
|
||||
#include <metal_atomic>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// A value of type U together with the index it was found at.
template <typename U>
struct IndexValPair {
  uint32_t index;
  U val;

  IndexValPair(uint32_t index_in, U val_in) : index(index_in), val(val_in) {}
};
|
||||
|
||||
// Reduction operator that tracks the smallest value and its index.
template <typename U>
struct ArgMin {
  // Starting value for the reduction.
  static constexpr constant U init = Limits<U>::max;

  // Combines two candidates: `current` wins when its value is smaller, or
  // when the values are equal and its index is smaller.
  IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
    const bool smaller_value = best.val > current.val;
    const bool tie_with_lower_index =
        best.val == current.val && best.index > current.index;
    return (smaller_value || tie_with_lower_index) ? current : best;
  }

  // Folds N consecutive values into `best`; vals[k] has index offset + k.
  // Only a strictly smaller value replaces the current best.
  template <int N>
  IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
    for (int k = 0; k < N; k++) {
      if (vals[k] < best.val) {
        best.val = vals[k];
        best.index = offset + k;
      }
    }
    return best;
  }
};
|
||||
|
||||
// Reduction operator that tracks the largest value and its index.
template <typename U>
struct ArgMax {
  // Starting value for the reduction.
  static constexpr constant U init = Limits<U>::min;

  // Combines two candidates: `current` wins when its value is larger, or
  // when the values are equal and its index is smaller.
  IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
    const bool larger_value = best.val < current.val;
    const bool tie_with_lower_index =
        best.val == current.val && best.index > current.index;
    return (larger_value || tie_with_lower_index) ? current : best;
  }

  // Folds N consecutive values into `best`; vals[k] has index offset + k.
  // Only a strictly larger value replaces the current best.
  template <int N>
  IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
    for (int k = 0; k < N; k++) {
      if (vals[k] > best.val) {
        best.val = vals[k];
        best.index = offset + k;
      }
    }
    return best;
  }
};
|
||||
|
||||
bool simd_shuffle_down(bool data, uint16_t delta) {
|
||||
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
|
||||
}
|
||||
|
||||
// uint64_t overload: reinterpret the 64-bit value as a uint2, shuffle that,
// and reinterpret the result back.
uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<uint64_t>(simd_shuffle_down(halves, delta));
}
|
||||
|
||||
// int64_t overload: reinterpret the 64-bit value as a uint2, shuffle that,
// and reinterpret the result back.
int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<int64_t>(simd_shuffle_down(halves, delta));
}
|
||||
|
||||
template <typename U>
|
||||
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
|
||||
return IndexValPair<U>(
|
||||
simd_shuffle_down(data.index, delta),
|
||||
simd_shuffle_down(data.val, delta)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
template <typename T, typename Op, int N_READS>
|
||||
[[kernel]] void arg_reduce_general(
|
||||
const device T *in [[buffer(0)]],
|
||||
device uint32_t *out [[buffer(1)]],
|
||||
const device int *shape [[buffer(2)]],
|
||||
const device size_t *in_strides [[buffer(3)]],
|
||||
const device size_t *out_strides [[buffer(4)]],
|
||||
const device size_t& ndim [[buffer(5)]],
|
||||
const device size_t& axis_stride [[buffer(6)]],
|
||||
const device size_t& axis_size [[buffer(7)]],
|
||||
threadgroup IndexValPair<T> *local_data [[threadgroup(0)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_size [[threads_per_simdgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
// Shapes and strides *do not* contain the reduction axis. The reduction size
|
||||
// and stride are provided in axis_stride and axis_size.
|
||||
//
|
||||
// Note: in shape == out shape with this convention.
|
||||
//
|
||||
// The sketch of the kernel is as follows.
|
||||
// 1. Launch prod(shape) * thread_group_size threads.
|
||||
// 2. Loop ceildiv(axis_size / lsize) times
|
||||
// 3. Read input values
|
||||
// 4. Reduce among them and go to 3
|
||||
// 4. Reduce in each simd_group
|
||||
// 6. Write in the thread local memory
|
||||
// 6. Reduce them accross thread group
|
||||
// 7. Write the output without need for atomic
|
||||
Op op;
|
||||
|
||||
// Compute the input/output index. There is one beginning and one output for
|
||||
// the whole threadgroup.
|
||||
auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
|
||||
auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);
|
||||
|
||||
IndexValPair<T> best(0, Op::init);
|
||||
|
||||
// Loop over the reduction axis in lsize*N_READS buckets
|
||||
for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
|
||||
// Read the current value
|
||||
uint32_t current_index = r*lsize*N_READS + lid*N_READS;
|
||||
uint32_t offset = current_index;
|
||||
const device T * current_in = in + in_idx + current_index * axis_stride;
|
||||
T vals[N_READS];
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
vals[i] = (current_index < axis_size) ? *current_in : T(Op::init);
|
||||
current_index++;
|
||||
current_in += axis_stride;
|
||||
}
|
||||
best = op.template reduce_many<N_READS>(best, vals, offset);
|
||||
}
|
||||
// At this point we have reduced the axis into thread group best values so we
|
||||
// need to reduce across the thread group.
|
||||
|
||||
// First per simd reduction.
|
||||
for (uint offset=simd_size/2; offset>0; offset/=2) {
|
||||
IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
|
||||
best = op.reduce(best, neighbor);
|
||||
}
|
||||
|
||||
// Write to the threadgroup memory
|
||||
if (simd_lane_id == 0) {
|
||||
local_data[simd_group_id] = best;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Read the appropriate value from local data and perform one simd reduction
|
||||
uint simd_groups = ceildiv(lsize, simd_size);
|
||||
if (simd_lane_id < simd_groups) {
|
||||
best = local_data[simd_lane_id];
|
||||
}
|
||||
for (uint offset=simd_size/2; offset>0; offset/=2) {
|
||||
IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
|
||||
best = op.reduce(best, neighbor);
|
||||
}
|
||||
|
||||
// Finally write the output
|
||||
if (lid == 0) {
|
||||
out[out_idx] = best.index;
|
||||
}
|
||||
}
|
||||
|
||||
// Explicitly instantiates arg_reduce_general<in_t, reduce_op<in_t>, 4> and
// exposes it to the host under the kernel name `kname`.
#define instantiate_arg_reduce_helper(kname, in_t, reduce_op) \
  template [[host_name(kname)]] \
  [[kernel]] void arg_reduce_general<in_t, reduce_op<in_t>, 4>( \
      const device in_t *in [[buffer(0)]], \
      device uint32_t * out [[buffer(1)]], \
      const device int *shape [[buffer(2)]], \
      const device size_t *in_strides [[buffer(3)]], \
      const device size_t *out_strides [[buffer(4)]], \
      const device size_t& ndim [[buffer(5)]], \
      const device size_t& axis_stride [[buffer(6)]], \
      const device size_t& axis_size [[buffer(7)]], \
      threadgroup IndexValPair<in_t> *local_data [[threadgroup(0)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint simd_size [[threads_per_simdgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_arg_reduce(name, itype) \
|
||||
instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
|
||||
instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
|
||||
|
||||
instantiate_arg_reduce(bool_, bool)
|
||||
instantiate_arg_reduce(uint8, uint8_t)
|
||||
instantiate_arg_reduce(uint16, uint16_t)
|
||||
instantiate_arg_reduce(uint32, uint32_t)
|
||||
instantiate_arg_reduce(uint64, uint64_t)
|
||||
instantiate_arg_reduce(int8, int8_t)
|
||||
instantiate_arg_reduce(int16, int16_t)
|
||||
instantiate_arg_reduce(int32, int32_t)
|
||||
instantiate_arg_reduce(int64, int64_t)
|
||||
instantiate_arg_reduce(float16, half)
|
||||
instantiate_arg_reduce(float32, float)
|
||||
instantiate_arg_reduce(bfloat16, bfloat16_t)
|
392
mlx/backend/metal/kernels/bf16_math.h
Normal file
392
mlx/backend/metal/kernels/bf16_math.h
Normal file
@@ -0,0 +1,392 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Metal math for bfloat16
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*

Following the Metal Shading Language Specification (Metal 3.1)

"bfloat is an extended floating point type that only allows implicit conversion
 to a type of greater floating point rank. While bfloat can be implicitly
 converted to float, it cannot be implicitly converted to half, and neither
 float nor half can be implicitly converted to bfloat."

Further, as far as I can tell, the stdlib math/simd functions are not defined
for bfloat, and calling them with an argument of type bfloat will result in that
argument getting implicitly converted to float. The call then returns an output
that is (likely) a float, which cannot be implicitly converted into a bfloat.

This leads to situations where
bfloat a = 5.0bf;
bfloat b = metal::abs(a); // this is an error since abs returns float
bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine

For the moment, I will be adding overloaded instantiations of the math
functions to automatically handle the casting accordingly.

*/
|
||||
|
||||
// Helper: defines `otype fname(itype x)`. The argument is cast to ctype, passed
// to `builtin` together with the math-mode flag `mfast`, and the result is cast
// to otype.
#define mlx_bf16_math_unary(itype, otype, ctype, mfast, fname, builtin) \
  METAL_FUNC otype fname(itype x) { \
    return static_cast<otype>(builtin(static_cast<ctype>(x), mfast)); \
  }

// Helper: two-argument variant of mlx_bf16_math_unary. Arguments are forwarded
// to `builtin` in declaration order.
#define mlx_bf16_math_binary(itype, otype, ctype, mfast, fname, builtin) \
  METAL_FUNC otype fname(itype x, itype y) { \
    return static_cast<otype>( \
        builtin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  }

// Helper: three-argument variant of mlx_bf16_math_unary.
#define mlx_bf16_math_ternary(itype, otype, ctype, mfast, fname, builtin) \
  METAL_FUNC otype fname(itype x, itype y, itype z) { \
    return static_cast<otype>(builtin( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  }

// Defines math-function overloads taking `itype` and returning `otype`. Each
// overload casts its arguments to `ctype`, calls the matching __metal_*
// builtin (with the math-mode flag `mfast` where the builtin takes one) and
// casts the result to `otype`. Functions that do not fit the uniform
// unary/binary/ternary shape (fdim, fma, frexp, ldexp, nextafter) are written
// out in full.
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, abs, __metal_fabs) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, acos, __metal_acos) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, acosh, __metal_acosh) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, asin, __metal_asin) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, asinh, __metal_asinh) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, atan, __metal_atan) \
  mlx_bf16_math_binary(itype, otype, ctype, mfast, atan2, __metal_atan2) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, atanh, __metal_atanh) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, ceil, __metal_ceil) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, cos, __metal_cos) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, cosh, __metal_cosh) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, cospi, __metal_cospi) \
  mlx_bf16_math_binary(itype, otype, ctype, mfast, divide, __metal_divide) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, exp, __metal_exp) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, exp10, __metal_exp10) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, exp2, __metal_exp2) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, fabs, __metal_fabs) \
  METAL_FUNC otype fdim(itype x, itype y) { \
    ctype t = static_cast<ctype>(x - y); \
    return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
  } \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, floor, __metal_floor) \
  METAL_FUNC otype fma(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fma( \
        static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
  } \
  mlx_bf16_math_binary(itype, otype, ctype, mfast, fmax, __metal_fmax) \
  mlx_bf16_math_ternary(itype, otype, ctype, mfast, fmax3, __metal_fmax3) \
  mlx_bf16_math_ternary(itype, otype, ctype, mfast, fmedian3, __metal_fmedian3) \
  mlx_bf16_math_binary(itype, otype, ctype, mfast, fmin, __metal_fmin) \
  mlx_bf16_math_ternary(itype, otype, ctype, mfast, fmin3, __metal_fmin3) \
  mlx_bf16_math_binary(itype, otype, ctype, mfast, fmod, __metal_fmod) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, fract, __metal_fract) \
  METAL_FUNC otype frexp(itype x, thread int& exp) { \
    return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
  } \
  METAL_FUNC otype ldexp(itype x, int k) { \
    return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
  } \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, log, __metal_log) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, log10, __metal_log10) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, log2, __metal_log2) \
  mlx_bf16_math_binary(itype, otype, ctype, mfast, max, __metal_fmax) \
  mlx_bf16_math_ternary(itype, otype, ctype, mfast, max3, __metal_fmax3) \
  mlx_bf16_math_ternary(itype, otype, ctype, mfast, median3, __metal_fmedian3) \
  mlx_bf16_math_binary(itype, otype, ctype, mfast, min, __metal_fmin) \
  mlx_bf16_math_ternary(itype, otype, ctype, mfast, min3, __metal_fmin3) \
  METAL_FUNC otype nextafter(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
  } \
  mlx_bf16_math_binary(itype, otype, ctype, mfast, pow, __metal_pow) \
  mlx_bf16_math_binary(itype, otype, ctype, mfast, powr, __metal_powr) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, rint, __metal_rint) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, round, __metal_round) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, rsqrt, __metal_rsqrt) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, sin, __metal_sin) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, sinh, __metal_sinh) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, sinpi, __metal_sinpi) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, sqrt, __metal_sqrt) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, tan, __metal_tan) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, tanh, __metal_tanh) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, tanpi, __metal_tanpi) \
  mlx_bf16_math_unary(itype, otype, ctype, mfast, trunc, __metal_trunc)
|
||||
|
||||
namespace metal {
|
||||
|
||||
instantiate_metal_math_funcs(
|
||||
bfloat16_t,
|
||||
bfloat16_t,
|
||||
float,
|
||||
__METAL_MAYBE_FAST_MATH__);
|
||||
|
||||
namespace fast {
|
||||
|
||||
instantiate_metal_math_funcs(
|
||||
bfloat16_t,
|
||||
bfloat16_t,
|
||||
float,
|
||||
__METAL_FAST_MATH__);
|
||||
|
||||
} // namespace fast
|
||||
|
||||
namespace precise {
|
||||
|
||||
instantiate_metal_math_funcs(
|
||||
bfloat16_t,
|
||||
bfloat16_t,
|
||||
float,
|
||||
__METAL_PRECISE_MATH__);
|
||||
|
||||
} // namespace precise
|
||||
|
||||
} // namespace metal
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Metal simd for bfloat16
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Helper: defines `otype fname(itype data, ushort lane_arg)`. The data is
// converted with `to_c`, passed to `builtin` with the lane argument, and the
// result is converted back with `to_o`.
#define mlx_bf16_simd_lane_func(itype, otype, to_c, to_o, fname, builtin) \
  METAL_FUNC otype fname(itype data, ushort lane_arg) { \
    return to_o(builtin(to_c(data), lane_arg)); \
  }

// Defines simdgroup communication overloads (broadcast and shuffles) for
// `itype`. Values are converted with `itype_to_ctype` before the
// __metal_simd_* builtin is called, and its result is converted with
// `ctype_to_otype`. The `ctype` parameter is not referenced by the body.
#define instantiate_metal_simd_comm_funcs( \
    itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
  \
  mlx_bf16_simd_lane_func( \
      itype, otype, itype_to_ctype, ctype_to_otype, \
      simd_broadcast, __metal_simd_broadcast) \
  \
  mlx_bf16_simd_lane_func( \
      itype, otype, itype_to_ctype, ctype_to_otype, \
      simd_shuffle, __metal_simd_shuffle) \
  \
  METAL_FUNC otype simd_shuffle_and_fill_down( \
      itype data, itype filling_data, ushort delta, ushort modulo) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
        itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_down( \
      itype data, itype filling_data, ushort delta) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
        itype_to_ctype(data), \
        itype_to_ctype(filling_data), \
        delta, \
        __metal_get_simdgroup_size(ushort()))); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_up( \
      itype data, itype filling_data, ushort delta, ushort modulo) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
        itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_up( \
      itype data, itype filling_data, ushort delta) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
        itype_to_ctype(data), \
        itype_to_ctype(filling_data), \
        delta, \
        __metal_get_simdgroup_size(ushort()))); \
  } \
  \
  mlx_bf16_simd_lane_func( \
      itype, otype, itype_to_ctype, ctype_to_otype, \
      simd_shuffle_down, __metal_simd_shuffle_down) \
  \
  mlx_bf16_simd_lane_func( \
      itype, otype, itype_to_ctype, ctype_to_otype, \
      simd_shuffle_rotate_down, __metal_simd_shuffle_rotate_down) \
  \
  mlx_bf16_simd_lane_func( \
      itype, otype, itype_to_ctype, ctype_to_otype, \
      simd_shuffle_rotate_up, __metal_simd_shuffle_rotate_up) \
  \
  mlx_bf16_simd_lane_func( \
      itype, otype, itype_to_ctype, ctype_to_otype, \
      simd_shuffle_up, __metal_simd_shuffle_up) \
  \
  mlx_bf16_simd_lane_func( \
      itype, otype, itype_to_ctype, ctype_to_otype, \
      simd_shuffle_xor, __metal_simd_shuffle_xor)
|
||||
|
||||
// Helper: defines `otype fname(itype data)`. The argument is cast to ctype,
// passed to `builtin`, and the result is cast to otype.
#define mlx_bf16_simd_reduce_func(itype, otype, ctype, fname, builtin) \
  METAL_FUNC otype fname(itype data) { \
    return static_cast<otype>(builtin(static_cast<ctype>(data))); \
  }

// Defines simdgroup reduction and prefix-scan overloads for `itype`, each
// computed in `ctype` by the matching __metal_simd_* builtin.
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
  \
  mlx_bf16_simd_reduce_func(itype, otype, ctype, simd_max, __metal_simd_max) \
  \
  mlx_bf16_simd_reduce_func(itype, otype, ctype, simd_min, __metal_simd_min) \
  \
  mlx_bf16_simd_reduce_func( \
      itype, otype, ctype, \
      simd_prefix_exclusive_product, __metal_simd_prefix_exclusive_product) \
  \
  mlx_bf16_simd_reduce_func( \
      itype, otype, ctype, \
      simd_prefix_exclusive_sum, __metal_simd_prefix_exclusive_sum) \
  \
  mlx_bf16_simd_reduce_func( \
      itype, otype, ctype, \
      simd_prefix_inclusive_product, __metal_simd_prefix_inclusive_product) \
  \
  mlx_bf16_simd_reduce_func( \
      itype, otype, ctype, \
      simd_prefix_inclusive_sum, __metal_simd_prefix_inclusive_sum) \
  \
  mlx_bf16_simd_reduce_func( \
      itype, otype, ctype, simd_product, __metal_simd_product) \
  \
  mlx_bf16_simd_reduce_func(itype, otype, ctype, simd_sum, __metal_simd_sum) \
  \
  mlx_bf16_simd_reduce_func(itype, otype, ctype, simd_xor, __metal_simd_xor)
|
||||
|
||||
// Bit-level conversions between bfloat16_t and uint16_t.
// With __HAVE_BFLOAT__ defined they are as_type reinterpretations; otherwise
// they go through the bits_ member and the bits_to_bfloat constructor tag of
// _MLX_BFloat16.
#if defined(__HAVE_BFLOAT__)

#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)

#else

// The argument is parenthesized so that a compound expression (for example a
// conditional expression) selects the member of the whole expression rather
// than of its last operand.
#define bfloat16_to_uint16(x) (x).bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())

#endif
|
||||
|
||||
namespace metal {
|
||||
|
||||
instantiate_metal_simd_comm_funcs(
|
||||
bfloat16_t,
|
||||
bfloat16_t,
|
||||
uint16_t,
|
||||
bfloat16_to_uint16,
|
||||
uint16_to_bfloat16);
|
||||
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
|
||||
|
||||
} // namespace metal
|
553
mlx/backend/metal/kernels/conv.metal
Normal file
553
mlx/backend/metal/kernels/conv.metal
Normal file
@@ -0,0 +1,553 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/conv_params.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/gemm/conv.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Slow and naive kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const int BC = 16>
|
||||
[[kernel]] void naive_conv_2d(
|
||||
const device T* in [[buffer(0)]],
|
||||
const device T* wt [[buffer(1)]],
|
||||
device T* out [[buffer(2)]],
|
||||
const constant MLXConvParams<2>& params [[buffer(3)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
(void)simd_gid;
|
||||
(void)simd_lid;
|
||||
|
||||
out += tid.z * params.out_strides[0];
|
||||
in += tid.z * params.in_strides[0];
|
||||
|
||||
int out_o = tid.y * BN * TN + lid.y * TN;
|
||||
int out_hw = tid.x * BM * TM + lid.x * TM;
|
||||
|
||||
int out_h[TM];
|
||||
int out_w[TN];
|
||||
|
||||
for(int m = 0; m < TM; ++m) {
|
||||
int mm = (out_hw + m);
|
||||
out_h[m] = mm / params.oS[1];
|
||||
out_w[m] = mm % params.oS[1];
|
||||
}
|
||||
|
||||
|
||||
T in_local[TM];
|
||||
T wt_local[TN];
|
||||
T out_local[TM * TN] = {T(0)};
|
||||
|
||||
for(int h = 0; h < params.wS[0]; ++h) {
|
||||
for(int w = 0; w < params.wS[1]; ++w) {
|
||||
for(int c = 0; c < params.C; ++c) {
|
||||
|
||||
// Local in
|
||||
for(int m = 0; m < TM; m++) {
|
||||
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.dil[0];
|
||||
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.dil[1];
|
||||
|
||||
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
|
||||
in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
|
||||
}
|
||||
|
||||
// Load weight
|
||||
for (int n = 0; n < TN; ++n) {
|
||||
int o = out_o + n;
|
||||
wt_local[n] = o < params.O ? wt[o * params.wt_strides[0] +
|
||||
h * params.wt_strides[1] +
|
||||
w * params.wt_strides[2] + c] : T(0);
|
||||
}
|
||||
|
||||
// Accumulate
|
||||
for(int m = 0; m < TM; ++m) {
|
||||
for(int n = 0; n < TN; ++n) {
|
||||
out_local[m * TN + n] += in_local[m] * wt_local[n];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(int m = 0; m < TM; ++m) {
|
||||
for(int n = 0; n < TN; ++n) {
|
||||
if(out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && (out_o + n) < params.O)
|
||||
out[out_h[m] * params.out_strides[1] +
|
||||
out_w[m] * params.out_strides[2] + out_o + n] = out_local[m * TN + n];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Instantiations
|
||||
|
||||
// Explicitly instantiates naive_conv_2d<elem_t, bm, bn, tm, tn> under the host
// name "naive_conv_2d_<tname>_bm<bm>_bn<bn>_tm<tm>_tn<tn>".
#define instantiate_naive_conv_2d(tname, elem_t, bm, bn, tm, tn) \
  template [[host_name("naive_conv_2d_" #tname "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
  [[kernel]] void naive_conv_2d<elem_t, bm, bn, tm, tn>( \
      const device elem_t* in [[buffer(0)]], \
      const device elem_t* wt [[buffer(1)]], \
      device elem_t* out [[buffer(2)]], \
      const constant MLXConvParams<2>& params [[buffer(3)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_naive_conv_2d_blocks(name, itype) \
|
||||
instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
|
||||
instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
|
||||
|
||||
instantiate_naive_conv_2d_blocks(float32, float);
|
||||
instantiate_naive_conv_2d_blocks(float16, half);
|
||||
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Implicit gemm kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
|
||||
const device T* in [[buffer(0)]],
|
||||
const device T* wt [[buffer(1)]],
|
||||
device T* out [[buffer(2)]],
|
||||
const constant MLXConvParams<2>& params [[buffer(3)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using gemm_kernel = Conv2DImplicitGEMMKernel<T, BM, BN, BK, WM, WN, /*transpose_a*/ false, /*transpose_b*/ true>;
|
||||
|
||||
threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];
|
||||
|
||||
gemm_kernel::run(
|
||||
in, wt, out,
|
||||
params, tgp_memory,
|
||||
tid, lid, simd_gid, simd_lid
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
// Explicitly instantiates implicit_gemm_conv_2d<elem_t, bm, bn, bk, wm, wn>
// under the host name
// "implicit_gemm_conv_2d_<tname>_bm<bm>_bn<bn>_bk<bk>_wm<wm>_wn<wn>".
#define instantiate_implicit_conv_2d(tname, elem_t, bm, bn, bk, wm, wn) \
  template [[host_name("implicit_gemm_conv_2d_" #tname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
  [[kernel]] void implicit_gemm_conv_2d<elem_t, bm, bn, bk, wm, wn>( \
      const device elem_t* in [[buffer(0)]], \
      const device elem_t* wt [[buffer(1)]], \
      device elem_t* out [[buffer(2)]], \
      const constant MLXConvParams<2>& params [[buffer(3)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_implicit_2d_blocks(name, itype) \
|
||||
instantiate_implicit_conv_2d(name, itype, 32, 32, 32, 2, 2) \
|
||||
instantiate_implicit_conv_2d(name, itype, 32, 32, 16, 2, 2) \
|
||||
instantiate_implicit_conv_2d(name, itype, 64, 64, 16, 2, 2)
|
||||
|
||||
instantiate_implicit_2d_blocks(float32, float);
|
||||
instantiate_implicit_2d_blocks(float16, half);
|
||||
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Winograd kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Primary template: intentionally empty. The transform tables are supplied by
// explicit specializations.
template <int M, int R, int S>
struct WinogradTransforms {};
|
||||
|
||||
template <>
|
||||
struct WinogradTransforms<6, 3, 8> {
|
||||
MLX_MTL_CONST int OUT_TILE_SIZE = 6;
|
||||
MLX_MTL_CONST int FILTER_SIZE = 3;
|
||||
MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1;
|
||||
MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8;
|
||||
MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
|
||||
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
|
||||
{ 0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
|
||||
{-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
|
||||
{ 0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
|
||||
{ 5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
|
||||
{ 0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
|
||||
{-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
|
||||
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
|
||||
};
|
||||
|
||||
MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
|
||||
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
|
||||
{ 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
|
||||
{ 1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
|
||||
{ 1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
|
||||
{ 1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
|
||||
{ 1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
|
||||
{ 1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
|
||||
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
|
||||
};
|
||||
|
||||
MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
|
||||
{ 1.00, 0.00, 0.00},
|
||||
{ -2.0/9.00, -2.0/9.00, -2.0/9.00},
|
||||
{ -2.0/9.00, 2.0/9.00, -2.0/9.00},
|
||||
{ 1.0/90.0, 1.0/45.0, 2.0/45.0},
|
||||
{ 1.0/90.0, -1.0/45.0, 2.0/45.0},
|
||||
{ 32.0/45.0, 16.0/45.0, 8.0/45.0},
|
||||
{ 32.0/45.0, -16.0/45.0, 8.0/45.0},
|
||||
{ 0.00, 0.00, 1.00},
|
||||
};
|
||||
};
|
||||
|
||||
// Out-of-class definitions for the static tables of WinogradTransforms<6, 3, 8>.
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
|
||||
|
||||
/**
 * Winograd weight transform.
 *
 * Each simdgroup owns one output filter, ko = BO * tid + simd_group_id, and
 * for every input channel computes G * g * G^T using 8x8 simdgroup matrices:
 *   - G is WinogradTransforms<M, R, 8>::wt_transform,
 *   - g is the R x R filter slice for that channel, zero padded to 8x8.
 *
 * Memory layout, as indexed by this kernel:
 *   wt_in  : element (ko, kh, kw, c) at ko*R*R*C + kh*R*C + kw*C + c
 *   wt_out : element (sm, sn) of the transformed tile for (c, ko) at
 *            ((sm*8 + sn)*C + c)*O + ko, i.e. stored as (8 x 8 x C x O)
 *
 * Template parameters:
 *   T  - element type; BC - channels staged per iteration;
 *   BO - simdgroups (output filters) per threadgroup; M, R - Winograd sizes.
 *
 * NOTE(review): there are no bounds checks. The loads assume C is a multiple
 * of BC and that ko < O for every launched simdgroup - confirm against the
 * host-side dispatch.
 */
template <typename T,
          int BC = 32,
          int BO = 4,
          int M = 6,
          int R = 3>
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform(
    const device T* wt_in [[buffer(0)]],
    device T* wt_out [[buffer(1)]],
    const constant int& C [[buffer(2)]],
    const constant int& O [[buffer(3)]],
    uint tid [[threadgroup_position_in_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
  using WGT = WinogradTransforms<M, R, 8>;

  // Get lane position in simdgroup: this lane holds elements (sm, sn) and
  // (sm, sn + 1) of every 8x8 simdgroup matrix used below.
  const short qid = simd_lane_id / 4;
  const short sm = (qid & 4) + (simd_lane_id / 2) % 4;
  const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;

  // Initialize G matrix
  simdgroup_matrix<T, 8, 8> G;
  G.thread_elements()[0] = WGT::wt_transform[sm][sn];
  G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1];

  // Initialize Gt matrix (same table, read with swapped indices)
  simdgroup_matrix<T, 8, 8> Gt;
  Gt.thread_elements()[0] = WGT::wt_transform[sn][sm];
  Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm];

  // Move to the correct output filter
  size_t ko = BO * tid + simd_group_id;
  wt_in += ko * R * R * C;

  // wt_out is stored transposed (8 x 8 x C x O).
  // The flat tile index is kept in size_t so that ohw * C * O is evaluated in
  // 64-bit arithmetic; with a `short` index the product was computed in int
  // and could overflow for large C * O. This matches the input/output
  // transform kernels, which already use size_t here.
  const size_t ohw_0 = sm * 8 + sn;
  const size_t ohw_1 = sm * 8 + sn + 1;
  device T* wt_out_0 = wt_out + ohw_0 * C * O + ko;
  device T* wt_out_1 = wt_out + ohw_1 * C * O + ko;

  // Prepare shared memory: one R x R x BC staging slab per simdgroup
  threadgroup T Ws[BO][R][R][BC];

  // Loop over C in chunks of BC channels
  for (int bc = 0; bc < C; bc += BC) {
    // Make sure the previous chunk has been fully consumed before overwriting
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Read into shared memory
    for (int kh = 0; kh < R; ++kh) {
      for (int kw = 0; kw < R; ++kw) {
        for (int kc = simd_lane_id; kc < BC; kc += 32) {
          Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
        }
      }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Do transform and store the result
    for (int c = 0; c < BC; ++c) {
      // Zero pad the R x R filter slice up to 8x8
      simdgroup_matrix<T, 8, 8> g;
      g.thread_elements()[0] = sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
      g.thread_elements()[1] = sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);

      simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
      wt_out_0[c * O] = g_out.thread_elements()[0];
      wt_out_1[c * O] = g_out.thread_elements()[1];
    }

    // Advance to the next chunk of channels
    wt_in += BC;
    wt_out_0 += BC * O;
    wt_out_1 += BC * O;
  }
}
|
||||
|
||||
// Explicitly instantiates winograd_conv_2d_weight_transform<itype, bc> and
// exports it under the host-visible name
// "winograd_conv_2d_weight_transform_<name>_bc<bc>".
#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
  template [[host_name("winograd_conv_2d_weight_transform_" #name "_bc" #bc)]] \
  [[kernel]] void winograd_conv_2d_weight_transform<itype, bc>( \
      const device itype* wt_in [[buffer(0)]], \
      device itype* wt_out [[buffer(1)]], \
      const constant int& C [[buffer(2)]], \
      const constant int& O [[buffer(3)]], \
      uint tid [[threadgroup_position_in_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);
|
||||
|
||||
/**
 * Winograd input transform.
 *
 * One threadgroup handles one input tile (tid.x, tid.y) of image tid.z. The
 * WM * WN simdgroups cooperatively stage an IN_TILE x IN_TILE x BC slab in
 * threadgroup memory, then each simdgroup transforms whole channels with
 * 8x8 simdgroup matrices as Bt * I * B, where B is
 * WinogradTransforms<M, R, 8>::in_transform.
 *
 * Output layout, as indexed here: element (row, col) of the transformed tile
 * for channel c lands at ((row*8 + col) * N_TILES + tile_id) * C + c.
 *
 * NOTE(review): loads are not bounds checked; this presumably relies on the
 * caller padding the input and on C being a multiple of BC - confirm.
 */
template <typename T,
          int BC,
          int WM,
          int WN,
          int M = 6,
          int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_input_transform(
    const device T* inp_in [[buffer(0)]],
    device T* inp_out [[buffer(1)]],
    const constant MLXConvParams<2>& params [[buffer(2)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_per_grid [[threadgroups_per_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
  (void)lid;

  using Transforms = WinogradTransforms<M, R, 8>;
  constexpr int IN_TILE = Transforms::IN_TILE_SIZE;
  constexpr int SIMD_GROUPS = WM * WN;

  // Rows / columns of the tile that each simdgroup stages
  constexpr int ROWS_PER_GROUP = (IN_TILE / WM);
  constexpr int COLS_PER_GROUP = (IN_TILE / WN);

  // This lane holds elements (row, col) and (row, col + 1) of each 8x8 matrix
  const short quad = simd_lane_id / 4;
  const short row = (quad & 4) + (simd_lane_id / 2) % 4;
  const short col = (quad & 2) * 2 + (simd_lane_id % 2) * 2;

  // Transform matrix B and its transpose (same table, swapped indices)
  simdgroup_matrix<T, 8, 8> in_xf;
  simdgroup_matrix<T, 8, 8> in_xf_t;
  in_xf.thread_elements()[0] = Transforms::in_transform[row][col];
  in_xf.thread_elements()[1] = Transforms::in_transform[row][col + 1];
  in_xf_t.thread_elements()[0] = Transforms::in_transform[col][row];
  in_xf_t.thread_elements()[1] = Transforms::in_transform[col + 1][row];

  // Sub-tile staged by this simdgroup and its origin in the image
  int group_h = ROWS_PER_GROUP * (simd_group_id / WN);
  int group_w = COLS_PER_GROUP * (simd_group_id % WN);
  int img_h = M * tid.y + group_h;
  int img_w = M * tid.x + group_w;

  // Move to the first pixel this simdgroup loads
  inp_in += tid.z * params.in_strides[0]
      + img_h * params.in_strides[1]
      + img_w * params.in_strides[2];

  // Pre-compute the offset of every pixel of the sub-tile
  int pixel_offset[ROWS_PER_GROUP][COLS_PER_GROUP];
  for (int dh = 0; dh < ROWS_PER_GROUP; ++dh) {
    for (int dw = 0; dw < COLS_PER_GROUP; ++dw) {
      pixel_offset[dh][dw] = dh * params.in_strides[1] + dw * params.in_strides[2];
    }
  }

  // inp_out is stored interleaved (8 x 8 x tiles x C)
  size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
  size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
  size_t flat_0 = row * 8 + col;
  size_t flat_1 = row * 8 + col + 1;
  device T* dst_0 = inp_out + flat_0 * N_TILES * params.C + tile_id * params.C;
  device T* dst_1 = inp_out + flat_1 * N_TILES * params.C + tile_id * params.C;

  // Staging buffer shared by the whole threadgroup
  threadgroup T staged[IN_TILE][IN_TILE][BC];

  // Walk the channels in chunks of BC
  for (int bc = 0; bc < params.C; bc += BC) {
    // Previous chunk must be fully consumed before it is overwritten
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Stage this simdgroup's sub-tile, lanes striding over channels
    for (int dh = 0; dh < ROWS_PER_GROUP; ++dh) {
      for (int dw = 0; dw < COLS_PER_GROUP; ++dw) {
        const device T* pixel = inp_in + pixel_offset[dh][dw];
        for (int c = simd_lane_id; c < BC; c += 32) {
          staged[group_h + dh][group_w + dw][c] = pixel[c];
        }
      }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Each simdgroup transforms every SIMD_GROUPS-th channel of the chunk
    for (int c = simd_group_id; c < BC; c += SIMD_GROUPS) {
      simdgroup_matrix<T, 8, 8> tile;
      tile.thread_elements()[0] = staged[row][col][c];
      tile.thread_elements()[1] = staged[row][col + 1][c];

      simdgroup_matrix<T, 8, 8> tile_out = (in_xf_t * tile) * in_xf;
      dst_0[c] = tile_out.thread_elements()[0];
      dst_1[c] = tile_out.thread_elements()[1];
    }

    // Advance to the next chunk of channels
    inp_in += BC;
    dst_0 += BC;
    dst_1 += BC;
  }
}
|
||||
|
||||
// Explicitly instantiates winograd_conv_2d_input_transform<itype, bc, 2, 2>
// (WM = WN = 2) and exports it under the host-visible name
// "winograd_conv_2d_input_transform_<name>_bc<bc>".
#define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \
  template [[host_name("winograd_conv_2d_input_transform_" #name "_bc" #bc)]] \
  [[kernel]] void winograd_conv_2d_input_transform<itype, bc, 2, 2>( \
      const device itype* inp_in [[buffer(0)]], \
      device itype* inp_out [[buffer(1)]], \
      const constant MLXConvParams<2>& params [[buffer(2)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 tgp_per_grid [[threadgroups_per_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);
|
||||
|
||||
/**
 * Winograd output transform.
 *
 * One threadgroup handles one output tile (tid.x, tid.y) of image tid.z.
 * For each chunk of BO output channels, every simdgroup transforms whole
 * channels as At * O * A with 8x8 simdgroup matrices (A is
 * WinogradTransforms<M, R, 8>::out_transform), keeps the top-left M x M part
 * in threadgroup memory, and then the simdgroups cooperatively write the
 * tile to out_out, skipping pixels that fall outside params.oS.
 *
 * Input layout, as indexed here: element (row, col) of the tile for channel
 * o is read from ((row*8 + col) * N_TILES + tile_id) * O + o.
 *
 * NOTE(review): reads of out_in are not bounds checked; this presumably
 * relies on O being a multiple of BO - confirm against the dispatch code.
 */
template <typename T,
          int BO,
          int WM,
          int WN,
          int M = 6,
          int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_output_transform(
    const device T* out_in [[buffer(0)]],
    device T* out_out [[buffer(1)]],
    const constant MLXConvParams<2>& params [[buffer(2)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_per_grid [[threadgroups_per_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
  (void)lid;

  using Transforms = WinogradTransforms<M, R, 8>;
  constexpr int SIMD_GROUPS = WM * WN;

  // Rows / columns of the M x M output tile written by each simdgroup
  constexpr int ROWS_PER_GROUP = (M / WM);
  constexpr int COLS_PER_GROUP = (M / WN);

  // This lane holds elements (row, col) and (row, col + 1) of each 8x8 matrix
  const short quad = simd_lane_id / 4;
  const short row = (quad & 4) + (simd_lane_id / 2) % 4;
  const short col = (quad & 2) * 2 + (simd_lane_id % 2) * 2;

  // Output transform matrix A and its transpose (same table, swapped indices)
  simdgroup_matrix<T, 8, 8> out_xf;
  simdgroup_matrix<T, 8, 8> out_xf_t;
  out_xf.thread_elements()[0] = Transforms::out_transform[row][col];
  out_xf.thread_elements()[1] = Transforms::out_transform[row][col + 1];
  out_xf_t.thread_elements()[0] = Transforms::out_transform[col][row];
  out_xf_t.thread_elements()[1] = Transforms::out_transform[col + 1][row];

  // out_in comes in shape (8 x 8 x tiles x O); the result is written to
  // out_out using params.out_strides for the (N, H, W) axes.

  // Sub-tile written by this simdgroup and its origin in the output image
  int group_h = ROWS_PER_GROUP * (simd_group_id / WN);
  int group_w = COLS_PER_GROUP * (simd_group_id % WN);
  int img_h = M * tid.y + group_h;
  int img_w = M * tid.x + group_w;

  // Move to the first output pixel this simdgroup writes
  out_out += tid.z * params.out_strides[0]
      + img_h * params.out_strides[1]
      + img_w * params.out_strides[2];

  // Pre-compute per-pixel offsets; -1 marks pixels outside the output image
  int pixel_offset[ROWS_PER_GROUP][COLS_PER_GROUP];
  for (int dh = 0; dh < ROWS_PER_GROUP; ++dh) {
    for (int dw = 0; dw < COLS_PER_GROUP; ++dw) {
      bool in_bounds = ((img_h + dh) < params.oS[0]) && ((img_w + dw) < params.oS[1]);
      pixel_offset[dh][dw] = in_bounds ? dh * params.out_strides[1] + dw * params.out_strides[2] : -1;
    }
  }

  // out_in is stored interleaved (8 x 8 x tiles x O)
  size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
  size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
  size_t flat_0 = row * 8 + col;
  size_t flat_1 = row * 8 + col + 1;
  const device T* src_0 = out_in + flat_0 * N_TILES * params.O + tile_id * params.O;
  const device T* src_1 = out_in + flat_1 * N_TILES * params.O + tile_id * params.O;

  // Staging buffer for the transformed M x M tile, shared by the threadgroup
  threadgroup T staged[M][M][BO];

  // Walk the output channels in chunks of BO
  for (int bo = 0; bo < params.O; bo += BO) {
    // Previous chunk must be fully written out before it is overwritten
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Each simdgroup transforms every SIMD_GROUPS-th channel of the chunk
    for (int c = simd_group_id; c < BO; c += SIMD_GROUPS) {
      simdgroup_matrix<T, 8, 8> tile;
      tile.thread_elements()[0] = src_0[c];
      tile.thread_elements()[1] = src_1[c];

      simdgroup_matrix<T, 8, 8> tile_out = (out_xf_t * (tile * out_xf));

      // Only the top-left M x M corner of the 8x8 result is kept
      if (row < M) {
        if (col < M) {
          staged[row][col][c] = tile_out.thread_elements()[0];
        }
        if ((col + 1) < M) {
          staged[row][col + 1][c] = tile_out.thread_elements()[1];
        }
      }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Write this simdgroup's sub-tile, lanes striding over channels
    for (int dh = 0; dh < ROWS_PER_GROUP; ++dh) {
      for (int dw = 0; dw < COLS_PER_GROUP; ++dw) {
        if (pixel_offset[dh][dw] < 0) {
          continue; // pixel lies outside the output image
        }
        device T* pixel = out_out + pixel_offset[dh][dw];
        for (int c = simd_lane_id; c < BO; c += 32) {
          pixel[c] = staged[group_h + dh][group_w + dw][c];
        }
      }
    }

    // Advance to the next chunk of output channels
    out_out += BO;
    src_0 += BO;
    src_1 += BO;
  }
}
|
||||
|
||||
// Explicitly instantiates winograd_conv_2d_output_transform<itype, bo, 2, 2>
// (WM = WN = 2) and exports it under the host-visible name
// "winograd_conv_2d_output_transform_<name>_bo<bo>".
#define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \
  template [[host_name("winograd_conv_2d_output_transform_" #name "_bo" #bo)]] \
  [[kernel]] void winograd_conv_2d_output_transform<itype, bo, 2, 2>( \
      const device itype* out_in [[buffer(0)]], \
      device itype* out_out [[buffer(1)]], \
      const constant MLXConvParams<2>& params [[buffer(2)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 tgp_per_grid [[threadgroups_per_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);
|
||||
|
||||
// Instantiates the weight, input and output Winograd transform kernels for
// one element type, each with a block size of 32.
#define instantiate_winograd_conv_2d(name, itype) \
  instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \
  instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
  instantiate_winograd_conv_2d_output_transform(name, itype, 32)

// Element types for which the Winograd kernels are built.
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(float16, half);
|
269
mlx/backend/metal/kernels/copy.metal
Normal file
269
mlx/backend/metal/kernels/copy.metal
Normal file
@@ -0,0 +1,269 @@
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
// Scalar broadcast copy: every thread writes src[0], converted to U, into
// its own slot dst[index].
template <typename T, typename U>
[[kernel]] void copy_s(
    device const T* src,
    device U* dst,
    uint index [[thread_position_in_grid]]) {
  device U* out = dst + index;
  *out = static_cast<U>(*src);
}
|
||||
|
||||
// Element-wise copy: thread `index` converts src[index] to U and stores it
// in dst[index].
template <typename T, typename U>
[[kernel]] void copy_v(
    device const T* src,
    device U* dst,
    uint index [[thread_position_in_grid]]) {
  device const T* in = src + index;
  device U* out = dst + index;
  *out = static_cast<U>(*in);
}
|
||||
|
||||
// 1-D strided-source copy into a densely indexed destination. The source
// offset comes from elem_to_loc_1(index, src_stride) (helper defined outside
// this chunk); the destination is written at dst[index].
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
    device const T* src,
    device U* dst,
    constant const size_t& src_stride,
    uint index [[thread_position_in_grid]]) {
  device const T* in = src + elem_to_loc_1(index, src_stride);
  dst[index] = static_cast<U>(*in);
}
|
||||
|
||||
// 2-D strided-source copy. The source offset comes from
// elem_to_loc_2(index, src_strides); the destination offset is the
// row-major linearisation of the 2-D grid position with x fastest:
// index.x + grid_dim.x * index.y.
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
    device const T* src,
    device U* dst,
    constant const size_t src_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Widen before multiplying so the product is formed in size_t
  const size_t row_width = grid_dim.x;
  const size_t out_offset = row_width * index.y + index.x;

  device const T* in = src + elem_to_loc_2(index, src_strides);
  dst[out_offset] = static_cast<U>(*in);
}
|
||||
|
||||
// 3-D strided-source copy. The source offset comes from
// elem_to_loc_3(index, src_strides); the destination offset linearises the
// 3-D grid position with x fastest, then y, then z.
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
    device const T* src,
    device U* dst,
    constant const size_t src_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // Widen before multiplying so the products are formed in size_t
  const size_t width = grid_dim.x;
  const size_t height = grid_dim.y;
  const size_t out_offset = width * (height * index.z + index.y) + index.x;

  device const T* in = src + elem_to_loc_3(index, src_strides);
  dst[out_offset] = static_cast<U>(*in);
}
|
||||
|
||||
// DIM-dimensional strided-source copy with the rank fixed at compile time.
// The source offset comes from elem_to_loc_nd<DIM>(index, src_shape,
// src_strides); the destination offset linearises the 3-D grid position
// with x fastest, then y, then z.
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
    device const T* src,
    device U* dst,
    constant const int src_shape[DIM],
    constant const size_t src_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // Widen before multiplying so the products are formed in size_t
  const size_t width = grid_dim.x;
  const size_t height = grid_dim.y;
  const size_t out_offset = width * (height * index.z + index.y) + index.x;

  device const T* in = src + elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  dst[out_offset] = static_cast<U>(*in);
}
|
||||
|
||||
// General strided-source copy with the rank supplied at run time (ndim).
// The source offset comes from elem_to_loc(index, src_shape, src_strides,
// ndim); the destination offset linearises the 3-D grid position with x
// fastest, then y, then z.
template <typename T, typename U>
[[kernel]] void copy_g(
    device const T* src,
    device U* dst,
    constant const int* src_shape,
    constant const size_t* src_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // Widen before multiplying so the products are formed in size_t
  const size_t width = grid_dim.x;
  const size_t height = grid_dim.y;
  const size_t out_offset = width * (height * index.z + index.y) + index.x;

  device const T* in = src + elem_to_loc(index, src_shape, src_strides, ndim);
  dst[out_offset] = static_cast<U>(*in);
}
|
||||
|
||||
// 1-D copy where both sides are strided: source and destination offsets are
// each obtained from elem_to_loc_1 with their own stride.
template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
    device const T* src,
    device U* dst,
    constant const size_t& src_stride,
    constant const size_t& dst_stride,
    uint index [[thread_position_in_grid]]) {
  device const T* in = src + elem_to_loc_1(index, src_stride);
  device U* out = dst + elem_to_loc_1(index, dst_stride);
  *out = static_cast<U>(*in);
}
|
||||
|
||||
// 2-D copy where both sides are strided: source and destination offsets are
// each obtained from elem_to_loc_2 with their own strides.
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
    device const T* src,
    device U* dst,
    constant const size_t src_strides[2],
    constant const size_t dst_strides[2],
    uint2 index [[thread_position_in_grid]]) {
  device const T* in = src + elem_to_loc_2(index, src_strides);
  device U* out = dst + elem_to_loc_2(index, dst_strides);
  *out = static_cast<U>(*in);
}
|
||||
|
||||
// 3-D copy where both sides are strided: source and destination offsets are
// each obtained from elem_to_loc_3 with their own strides.
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
    device const T* src,
    device U* dst,
    constant const size_t src_strides[3],
    constant const size_t dst_strides[3],
    uint3 index [[thread_position_in_grid]]) {
  device const T* in = src + elem_to_loc_3(index, src_strides);
  device U* out = dst + elem_to_loc_3(index, dst_strides);
  *out = static_cast<U>(*in);
}
|
||||
|
||||
// DIM-dimensional copy where both sides are strided, rank fixed at compile
// time. Both offsets are computed with elem_to_loc_nd<DIM> over the same
// shape (src_shape), using src_strides and dst_strides respectively.
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
    device const T* src,
    device U* dst,
    constant const int src_shape[DIM],
    constant const size_t src_strides[DIM],
    constant const size_t dst_strides[DIM],
    uint3 index [[thread_position_in_grid]]) {
  device const T* in = src + elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  device U* out = dst + elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
  *out = static_cast<U>(*in);
}
|
||||
|
||||
// General copy where both sides are strided, rank supplied at run time.
// Both offsets are computed with elem_to_loc over the same shape
// (src_shape), using src_strides and dst_strides respectively.
template <typename T, typename U>
[[kernel]] void copy_gg(
    device const T* src,
    device U* dst,
    constant const int* src_shape,
    constant const size_t* src_strides,
    constant const size_t* dst_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]]) {
  device const T* in = src + elem_to_loc(index, src_shape, src_strides, ndim);
  device U* out = dst + elem_to_loc(index, src_shape, dst_strides, ndim);
  *out = static_cast<U>(*in);
}
|
||||
|
||||
// Instantiates copy_<ctype><itype, otype> (ctype is `s` or `v`) under the
// host-visible kernel name `name`.
#define instantiate_copy(name, itype, otype, ctype) \
  template [[host_name(name)]] \
  [[kernel]] void copy_##ctype<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      uint index [[thread_position_in_grid]]);
|
||||
|
||||
// Instantiates the fixed-rank strided copies for rank `dims`:
//   copy_g_nd  as  name "_" dims        (strided source)
//   copy_gg_nd as  "g" name "_" dims    (strided source and destination)
#define instantiate_copy_g_dim(name, itype, otype, dims) \
  template [[host_name(name "_" #dims)]] \
  [[kernel]] void copy_g_nd<itype, otype, dims>( \
      device const itype* src, \
      device otype* dst, \
      constant const int src_shape[dims], \
      constant const size_t src_strides[dims], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("g" name "_" #dims)]] \
  [[kernel]] void copy_gg_nd<itype, otype, dims>( \
      device const itype* src, \
      device otype* dst, \
      constant const int src_shape[dims], \
      constant const size_t src_strides[dims], \
      constant const size_t dst_strides[dims], \
      uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
|
||||
// Instantiates every specialised-rank strided copy for one type pair:
//   ranks 1-3: copy_g_nd{1,2,3}  as name "_N",  copy_gg_nd{1,2,3} as "g" name "_N"
//   ranks 4-5: via instantiate_copy_g_dim
#define instantiate_copy_g_nd(name, itype, otype) \
  template [[host_name(name "_1")]] \
  [[kernel]] void copy_g_nd1<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t& src_stride, \
      uint index [[thread_position_in_grid]]); \
  template [[host_name(name "_2")]] \
  [[kernel]] void copy_g_nd2<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t src_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
  template [[host_name(name "_3")]] \
  [[kernel]] void copy_g_nd3<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t src_strides[3], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("g" name "_1")]] \
  [[kernel]] void copy_gg_nd1<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t& src_stride, \
      constant const size_t& dst_stride, \
      uint index [[thread_position_in_grid]]); \
  template [[host_name("g" name "_2")]] \
  [[kernel]] void copy_gg_nd2<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t src_strides[2], \
      constant const size_t dst_strides[2], \
      uint2 index [[thread_position_in_grid]]); \
  template [[host_name("g" name "_3")]] \
  [[kernel]] void copy_gg_nd3<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t src_strides[3], \
      constant const size_t dst_strides[3], \
      uint3 index [[thread_position_in_grid]]); \
  instantiate_copy_g_dim(name, itype, otype, 4) \
  instantiate_copy_g_dim(name, itype, otype, 5)
|
||||
|
||||
|
||||
// Instantiates the run-time-rank strided copies for one type pair:
//   copy_g  as  name        (strided source)
//   copy_gg as  "g" name    (strided source and destination)
#define instantiate_copy_g(name, itype, otype) \
  template [[host_name(name)]] \
  [[kernel]] void copy_g<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const int* src_shape, \
      constant const size_t* src_strides, \
      constant const int& ndim, \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("g" name)]] \
  [[kernel]] void copy_gg<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const int* src_shape, \
      constant const size_t* src_strides, \
      constant const size_t* dst_strides, \
      constant const int& ndim, \
      uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
// Instantiates every copy kernel for one (itype, otype) pair. Host names are
// built from the prefix "scopy" / "vcopy" / "gcopy" followed by `tname`.
#define instantiate_copy_all(tname, itype, otype) \
  instantiate_copy("scopy" #tname, itype, otype, s) \
  instantiate_copy("vcopy" #tname, itype, otype, v) \
  instantiate_copy_g("gcopy" #tname, itype, otype) \
  instantiate_copy_g_nd("gcopy" #tname, itype, otype)
|
||||
|
||||
// For one input type, instantiates the copy kernels towards every supported
// output type. The type-name suffix passed on is `itname` pasted together
// with the output type name (e.g. float32 + int8 -> float32int8).
#define instantiate_copy_itype(itname, itype) \
  instantiate_copy_all(itname ##bool_, itype, bool) \
  instantiate_copy_all(itname ##uint8, itype, uint8_t) \
  instantiate_copy_all(itname ##uint16, itype, uint16_t) \
  instantiate_copy_all(itname ##uint32, itype, uint32_t) \
  instantiate_copy_all(itname ##uint64, itype, uint64_t) \
  instantiate_copy_all(itname ##int8, itype, int8_t) \
  instantiate_copy_all(itname ##int16, itype, int16_t) \
  instantiate_copy_all(itname ##int32, itype, int32_t) \
  instantiate_copy_all(itname ##int64, itype, int64_t) \
  instantiate_copy_all(itname ##float16, itype, half) \
  instantiate_copy_all(itname ##float32, itype, float) \
  instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
  instantiate_copy_all(itname ##complex64, itype, complex64_t)

// Full cross product: every supported input type to every output type.
instantiate_copy_itype(bool_, bool)
instantiate_copy_itype(uint8, uint8_t)
instantiate_copy_itype(uint16, uint16_t)
instantiate_copy_itype(uint32, uint32_t)
instantiate_copy_itype(uint64, uint64_t)
instantiate_copy_itype(int8, int8_t)
instantiate_copy_itype(int16, int16_t)
instantiate_copy_itype(int32, int32_t)
instantiate_copy_itype(int64, int64_t)
instantiate_copy_itype(float16, half)
instantiate_copy_itype(float32, float)
instantiate_copy_itype(bfloat16, bfloat16_t)
instantiate_copy_itype(complex64, complex64_t)
|
68
mlx/backend/metal/kernels/erf.h
Normal file
68
mlx/backend/metal/kernels/erf.h
Normal file
@@ -0,0 +1,68 @@
|
||||
#pragma once
|
||||
|
||||
#include <metal_math>
|
||||
|
||||
/*
 * Single-precision approximation to the error function.
 * Based on code from:
 * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199
 *
 * Two polynomial branches are selected on |a|:
 *   |a| >  0.927734375 : sign(a) * (1 - exp(p(|a|)))
 *   otherwise          : a + a * q(a * a)
 * The fma evaluation order is part of the accuracy guarantee; do not
 * reorder it.
 */
float erf(float a) {
  const float mag = metal::abs(a);
  const float sq = a * a;

  if (mag > 0.927734375f) {
    // maximum error 0.99527 ulp
    const float hi = metal::fma(
        -1.72853470e-5f, mag, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
    const float lo = metal::fma(
        -3.88396438e-3f, mag, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
    float acc = metal::fma(hi, sq, lo);
    acc = metal::fma(acc, mag, -1.06777877e-1f); // -0x1.b55cb8p-4
    acc = metal::fma(acc, mag, -6.34846687e-1f); // -0x1.450aa0p-1
    acc = metal::fma(acc, mag, -1.28717512e-1f); // -0x1.079d0cp-3
    acc = metal::fma(acc, mag, -mag);
    // TODO, replace with expm1 when implemented
    const float magnitude = 1.0f - metal::exp(acc);
    return metal::copysign(magnitude, a);
  }

  // maximum error 0.98929 ulp
  float acc = -5.96761703e-4f; // -0x1.38e000p-11
  acc = metal::fma(acc, sq, 4.99119423e-3f); // 0x1.471a58p-8
  acc = metal::fma(acc, sq, -2.67681349e-2f); // -0x1.b691b2p-6
  acc = metal::fma(acc, sq, 1.12819925e-1f); // 0x1.ce1c44p-4
  acc = metal::fma(acc, sq, -3.76125336e-1f); // -0x1.812700p-2
  acc = metal::fma(acc, sq, 1.28379166e-1f); // 0x1.06eba8p-3
  return metal::fma(acc, a, a);
}
|
||||
|
||||
/*
 * Single-precision approximation to the inverse error function.
 *
 * Computes t = log(1 - a*a) and evaluates one of two polynomials in t,
 * selected on |t| > 6.125; the result is a * p(t). The fma evaluation order
 * is kept exactly as written; do not reorder it.
 */
float erfinv(float a) {
  // 1 - a*a, formed with a single fma
  const float t = metal::log(metal::fma(a, 0.0f - a, 1.0f));

  float poly;
  if (metal::abs(t) > 6.125f) { // maximum ulp error = 2.35793
    poly = 3.03697567e-10f; // 0x1.4deb44p-32
    poly = metal::fma(poly, t, 2.93243101e-8f); // 0x1.f7c9aep-26
    poly = metal::fma(poly, t, 1.22150334e-6f); // 0x1.47e512p-20
    poly = metal::fma(poly, t, 2.84108955e-5f); // 0x1.dca7dep-16
    poly = metal::fma(poly, t, 3.93552968e-4f); // 0x1.9cab92p-12
    poly = metal::fma(poly, t, 3.02698812e-3f); // 0x1.8cc0dep-9
    poly = metal::fma(poly, t, 4.83185798e-3f); // 0x1.3ca920p-8
    poly = metal::fma(poly, t, -2.64646143e-1f); // -0x1.0eff66p-2
    poly = metal::fma(poly, t, 8.40016484e-1f); // 0x1.ae16a4p-1
  } else { // maximum ulp error = 2.35002
    poly = 5.43877832e-9f; // 0x1.75c000p-28
    poly = metal::fma(poly, t, 1.43285448e-7f); // 0x1.33b402p-23
    poly = metal::fma(poly, t, 1.22774793e-6f); // 0x1.499232p-20
    poly = metal::fma(poly, t, 1.12963626e-7f); // 0x1.e52cd2p-24
    poly = metal::fma(poly, t, -5.61530760e-5f); // -0x1.d70bd0p-15
    poly = metal::fma(poly, t, -1.47697632e-4f); // -0x1.35be90p-13
    poly = metal::fma(poly, t, 2.31468678e-3f); // 0x1.2f6400p-9
    poly = metal::fma(poly, t, 1.15392581e-2f); // 0x1.7a1e50p-7
    poly = metal::fma(poly, t, -2.32015476e-1f); // -0x1.db2aeep-3
    poly = metal::fma(poly, t, 8.86226892e-1f); // 0x1.c5bf88p-1
  }
  return a * poly;
}
|
91
mlx/backend/metal/kernels/gemm.metal
Normal file
91
mlx/backend/metal/kernels/gemm.metal
Normal file
@@ -0,0 +1,91 @@
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/gemm/gemm.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GEMM entry point: selects the GEMMKernel specialisation for the given
// tile sizes (BM, BN, BK), simdgroup grid (WM x WN), transpose flags and
// alignment flags, allocates its threadgroup scratch memory and forwards
// every argument to GEMMKernel::run.
template <typename T,
          int BM,
          int BN,
          int BK,
          int WM,
          int WN,
          bool transpose_a,
          bool transpose_b,
          bool MN_aligned,
          bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
    const device T *A [[buffer(0)]],
    const device T *B [[buffer(1)]],
    device T *C [[buffer(2)]],
    const constant int &M [[buffer(3)]],
    const constant int &N [[buffer(4)]],
    const constant int &K [[buffer(5)]],
    const constant int &batch_stride_a [[buffer(6)]],
    const constant int &batch_stride_b [[buffer(7)]],
    const constant int &batch_stride_c [[buffer(8)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using Kernel = GEMMKernel<
      T, BM, BN, BK, WM, WN,
      transpose_a, transpose_b,
      MN_aligned, K_aligned>;

  // Scratch space whose size is dictated by the selected kernel
  threadgroup T scratch[Kernel::tgp_mem_size];

  Kernel::run(
      A, B, C,
      M, N, K,
      batch_stride_a, batch_stride_b, batch_stride_c,
      scratch,
      simd_lane_id, simd_group_id, tid, lid);
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Explicitly instantiates one gemm specialisation. The host-visible name
// encodes the transpose tag, input/output type names, tile sizes, simdgroup
// grid and the two alignment tags:
//   gemm_<tname>_<iname>_<oname>_bm<bm>_bn<bn>_bk<bk>_wm<wm>_wn<wn>_MN_<aname>_K_<kname>
// `otype` is accepted but not used by the instantiated signature.
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
  [[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
      const device itype *A [[buffer(0)]], \
      const device itype *B [[buffer(1)]], \
      device itype *C [[buffer(2)]], \
      const constant int &M [[buffer(3)]], \
      const constant int &N [[buffer(4)]], \
      const constant int &K [[buffer(5)]], \
      const constant int &batch_stride_a [[buffer(6)]], \
      const constant int &batch_stride_b [[buffer(7)]], \
      const constant int &batch_stride_c [[buffer(8)]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
// All four (MN_aligned, K_aligned) combinations for one configuration.
// The tags `taligned` / `naligned` go with the flags true / false.
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)

// All four transpose combinations (nn, nt, tn, tt) for one tile shape.
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

// The tile shapes (bm, bn, bk) built for every type, all on a 2 x 2
// simdgroup grid.
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2)

// Element types for which the GEMM kernels are built.
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
|
||||
// TODO: Accumulation in different type
|
99
mlx/backend/metal/kernels/random.metal
Normal file
99
mlx/backend/metal/kernels/random.metal
Normal file
@@ -0,0 +1,99 @@
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
// Left-rotation amounts used by threefry2x32_hash: row (round % 2) supplies
// the four rotations applied, in order, during that round.
static constexpr constant uint32_t rotations[2][4] = {
    {13, 15, 26, 6}, {17, 29, 16, 24}};
|
||||
|
||||
// Output of one hash evaluation: two 32-bit words that can also be read
// byte by byte (bytes[0] overlays val.x, bytes[1] overlays val.y).
union rbits {
|
||||
// The two hashed 32-bit words.
uint2 val;
|
||||
// The same storage viewed as 2 x 4 bytes.
uchar4 bytes[2];
|
||||
};
|
||||
|
||||
// Hashes the 2 x 32-bit counter `count` under the 2 x 32-bit `key`.
//
// The key schedule has three words: key.x, key.y and
// key.x ^ key.y ^ 0x1BD11BDA. The state starts as count + first two key
// words and then goes through 5 rounds; each round applies four
// add / rotate-left / xor mixes (rotation amounts from `rotations`) and then
// injects two key-schedule words plus the round number.
rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
  uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};

  // Work on plain locals and pack the result into the union at the end
  uint x = count.x + ks[0];
  uint y = count.y + ks[1];

  for (int round = 0; round < 5; ++round) {
    for (int step = 0; step < 4; ++step) {
      const uint32_t rot = rotations[round % 2][step];
      x += y;
      y = (y << rot) | (y >> (32 - rot)); // rotate left by `rot`
      y ^= x;
    }
    // Key injection
    x += ks[(round + 1) % 3];
    y += ks[(round + 2) % 3] + round + 1;
  }

  rbits result;
  result.val.x = x;
  result.val.y = y;
  return result;
}
|
||||
|
||||
// Random bytes for keys stored contiguously: key k is the word pair
// keys[2k], keys[2k + 1].
//
// Thread (x, y) hashes key x with a counter built from y and writes up to two
// 4-byte words into that key's `bytes_per_key`-sized slice of `out`. When
// `odd` is set, the thread with index.y == grid_dim.y - 1 writes only its
// first word.
[[kernel]] void rbitsc(
    device const uint32_t* keys,
    device char* out,
    device const bool& odd,
    device const uint& bytes_per_key,
    uint2 grid_dim [[threads_per_grid]],
    uint2 index [[thread_position_in_grid]]) {
  const auto key_offset = 2 * index.x;
  auto key = uint2(keys[key_offset], keys[key_offset + 1]);

  const auto half_size = grid_dim.y - odd;
  const bool drop_last = odd && (index.y == half_size);

  // Move to the start of this key's output slice
  out += index.x * bytes_per_key;

  auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y);
  auto bits = threefry2x32_hash(key, count);

  // The first word is always written in full at word position count.x
  for (int b = 0; b < 4; ++b) {
    out[4 * count.x + b] = bits.bytes[0][b];
  }
  if (drop_last) {
    return;
  }

  // The second word is written in full, except on the thread that owns the
  // trailing partial word, which writes only bytes_per_key % 4 bytes
  const uint tail = bytes_per_key % 4;
  const bool partial = (index.y + 1) == half_size && tail > 0;
  const int n_bytes = partial ? static_cast<int>(tail) : 4;
  for (int b = 0; b < n_bytes; ++b) {
    out[4 * count.y + b] = bits.bytes[1][b];
  }
}
|
||||
|
||||
// Random bytes for keys addressed through a shape / stride description.
//
// Same scheme as the contiguous kernel, except that the two words of key x
// are located with elem_to_loc(2x, ...) and elem_to_loc(2x + 1, ...).
// Thread (x, y) hashes key x with a counter built from y and writes up to two
// 4-byte words into that key's `bytes_per_key`-sized slice of `out`. When
// `odd` is set, the thread with index.y == grid_dim.y - 1 writes only its
// first word.
[[kernel]] void rbits(
    device const uint32_t* keys,
    device char* out,
    device const bool& odd,
    device const uint& bytes_per_key,
    device const int& ndim,
    device const int* key_shape,
    device const size_t* key_strides,
    uint2 grid_dim [[threads_per_grid]],
    uint2 index [[thread_position_in_grid]]) {
  const auto key_offset = 2 * index.x;
  auto first_loc = elem_to_loc(key_offset, key_shape, key_strides, ndim);
  auto second_loc = elem_to_loc(key_offset + 1, key_shape, key_strides, ndim);
  auto key = uint2(keys[first_loc], keys[second_loc]);

  const auto half_size = grid_dim.y - odd;
  const bool drop_last = odd && (index.y == half_size);

  // Move to the start of this key's output slice
  out += index.x * bytes_per_key;

  auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y);
  auto bits = threefry2x32_hash(key, count);

  // The first word is always written in full at word position count.x
  for (int b = 0; b < 4; ++b) {
    out[4 * count.x + b] = bits.bytes[0][b];
  }
  if (drop_last) {
    return;
  }

  // The second word is written in full, except on the thread that owns the
  // trailing partial word, which writes only bytes_per_key % 4 bytes
  const uint tail = bytes_per_key % 4;
  const bool partial = (index.y + 1) == half_size && tail > 0;
  const int n_bytes = partial ? static_cast<int>(tail) : 4;
  for (int b = 0; b < n_bytes; ++b) {
    out[4 * count.y + b] = bits.bytes[1][b];
  }
}
|
536
mlx/backend/metal/kernels/reduce.metal
Normal file
536
mlx/backend/metal/kernels/reduce.metal
Normal file
@@ -0,0 +1,536 @@
|
||||
#include <metal_atomic>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/reduce.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// SIMD-group width assumed by the reduction kernels in this file; it also
// sizes their per-threadgroup scratch arrays.
static constant uint8_t simd_size = 32;
|
||||
|
||||
// Set every output element to Op::init; each thread writes the element at
// its own grid position.
template <typename T, typename Op>
[[kernel]] void init_reduce(
    device T* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {
  out[tid] = Op::init;
}
|
||||
|
||||
// Explicitly instantiate init_reduce<otype, op> under the host-visible name
// "i" + name. The argument attributes repeat the template definition, which
// places `out` at buffer index 0.
#define instantiate_init_reduce(name, otype, op) \
  template [[host_name("i" #name)]] \
  [[kernel]] void init_reduce<otype, op>( \
      device otype *out [[buffer(0)]], \
      uint tid [[thread_position_in_grid]]);
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// All reduce
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Reduce the whole input to a single value.
//
// Each thread accumulates N_READS-sized chunks spaced one grid apart, the
// partials are combined per SIMD group and then per threadgroup, and thread 0
// of each threadgroup merges its result into `out` with op.atomic_update.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // NB: this kernel assumes threads_per_threadgroup is at most
  // 1024. This way with a simd_size of 32, we are guaranteed to
  // complete the reduction in two steps of simd-level reductions.

  Op op;
  threadgroup U local_vals[simd_size];

  U total_val = Op::init;

  in += gid * N_READS;

  // Every round except the last reads a full chunk without bounds checks
  const int full_rounds =
      static_cast<int>(ceildiv(in_size, grid_size * N_READS)) - 1;

  int r = 0;
  while (r < full_rounds) {
    U chunk[N_READS];
    for (int i = 0; i < N_READS; i++) {
      chunk[i] = static_cast<U>(in[i]);
    }
    for (int i = 0; i < N_READS; i++) {
      total_val = op(chunk[i], total_val);
    }

    in += grid_size * N_READS;
    r++;
  }

  // Separate case for the last round, where the chunk may run past the end
  const size_t curr_idx = (gid + r * static_cast<size_t>(grid_size)) * N_READS;
  if (curr_idx < in_size) {
    const int max_reads = in_size - curr_idx;
    T tail[N_READS];

    // Clamp the read position so out-of-range slots re-read the last valid
    // element; those slots are replaced by Op::init below
    for (int i = 0; i < N_READS; i++) {
      const int idx = min(i, max_reads - 1);
      tail[i] = in[idx];
    }
    for (int i = 0; i < N_READS; i++) {
      U val = i < max_reads ? tail[i] : Op::init;
      total_val = op(static_cast<U>(val), total_val);
    }
  }

  // Reduction within simd group
  total_val = op.simd_reduce(total_val);
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }

  // Reduction within thread group
  threadgroup_barrier(mem_flags::mem_threadgroup);
  total_val = lid < simd_per_group ? local_vals[lid] : op.init;
  total_val = op.simd_reduce(total_val);

  // Reduction across threadgroups
  if (lid == 0) {
    op.atomic_update(out, total_val);
  }
}
|
||||
|
||||
// Explicitly instantiate all_reduce<itype, otype, op> under the host-visible
// name "all_reduce_" + name.
#define instantiate_all_reduce(name, itype, otype, op)       \
  template [[host_name("all_reduce_" #name)]]                \
  [[kernel]] void all_reduce<itype, otype, op>(              \
      const device itype *in [[buffer(0)]],                  \
      device mlx_atomic<otype> *out [[buffer(1)]],           \
      const device size_t& in_size [[buffer(2)]],            \
      uint gid [[thread_position_in_grid]],                  \
      uint lid [[thread_position_in_threadgroup]],           \
      uint grid_size [[threads_per_grid]],                   \
      uint simd_per_group [[simdgroups_per_threadgroup]],    \
      uint simd_lane_id [[thread_index_in_simdgroup]],       \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// General reduce
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// One thread per input element: map the element index to its input location
// (in_strides) and to its output location (out_strides), then merge the value
// into the output with op.atomic_update. The number of dimensions is read
// from a buffer at run time.
template <typename T, typename U, typename Op>
[[kernel]] void general_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const device int *in_shape [[buffer(2)]],
    const device size_t *in_strides [[buffer(3)]],
    const device size_t *out_strides [[buffer(4)]],
    const device size_t& ndim [[buffer(5)]],
    uint gid [[thread_position_in_grid]]) {
  Op op;

  auto src = elem_to_loc(gid, in_shape, in_strides, ndim);
  auto dst = elem_to_loc(gid, in_shape, out_strides, ndim);

  op.atomic_update(out, static_cast<U>(in[src]), dst);
}
|
||||
|
||||
// Variant of general_reduce with the number of dimensions fixed at compile
// time (NDIM), so the index mapping goes through elem_to_loc_nd<NDIM>.
template <typename T, typename U, typename Op, int NDIM>
[[kernel]] void general_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const device int *in_shape [[buffer(2)]],
    const device size_t *in_strides [[buffer(3)]],
    const device size_t *out_strides [[buffer(4)]],
    uint gid [[thread_position_in_grid]]) {
  Op op;

  auto src = elem_to_loc_nd<NDIM>(gid, in_shape, in_strides);
  auto dst = elem_to_loc_nd<NDIM>(gid, in_shape, out_strides);

  op.atomic_update(out, static_cast<U>(in[src]), dst);
}
|
||||
|
||||
// Explicitly instantiate the run-time-ndim general_reduce under the
// host-visible name "general_reduce_" + name.
#define instantiate_general_reduce_helper(name, itype, otype, op) \
  template [[host_name("general_reduce_" #name)]]                 \
  [[kernel]] void general_reduce<itype, otype, op>(               \
      const device itype *in [[buffer(0)]],                       \
      device mlx_atomic<otype> *out [[buffer(1)]],                \
      const device int *in_shape [[buffer(2)]],                   \
      const device size_t *in_strides [[buffer(3)]],              \
      const device size_t *out_strides [[buffer(4)]],             \
      const device size_t& ndim [[buffer(5)]],                    \
      uint gid [[thread_position_in_grid]]);
|
||||
|
||||
// Explicitly instantiate the fixed-ndim general_reduce for n dimensions under
// the host-visible name "general_reduce_" + name + "_dim_" + n.
#define instantiate_general_reduce_helper_nd(name, itype, otype, op, n) \
  template [[host_name("general_reduce_" #name "_dim_" #n)]]            \
  [[kernel]] void general_reduce<itype, otype, op, n>(                  \
      const device itype *in [[buffer(0)]],                             \
      device mlx_atomic<otype> *out [[buffer(1)]],                      \
      const device int *in_shape [[buffer(2)]],                         \
      const device size_t *in_strides [[buffer(3)]],                    \
      const device size_t *out_strides [[buffer(4)]],                   \
      uint gid [[thread_position_in_grid]]);
|
||||
|
||||
// Instantiate general_reduce in its run-time-ndim form plus the fixed-ndim
// forms for 1 through 4 dimensions.
#define instantiate_general_reduce(name, itype, otype, op)        \
  instantiate_general_reduce_helper(name, itype, otype, op)       \
  instantiate_general_reduce_helper_nd(name, itype, otype, op, 1) \
  instantiate_general_reduce_helper_nd(name, itype, otype, op, 2) \
  instantiate_general_reduce_helper_nd(name, itype, otype, op, 3) \
  instantiate_general_reduce_helper_nd(name, itype, otype, op, 4)
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Row reduce
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Reduce contiguous rows: threadgroup `tid` reduces the `reduction_size`
// consecutive inputs starting at in[tid * reduction_size] and writes the
// result to out[tid] with a plain store.
//
// Each thread accumulates N_READS-sized chunks spaced one threadgroup apart;
// partials are combined per SIMD group and, when the row is longer than one
// SIMD group, once more across SIMD groups through threadgroup memory.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const device size_t& reduction_size [[buffer(2)]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {

  Op op;

  // Each threadgroup handles 1 reduction
  in += tid * reduction_size + lid * N_READS;

  // The reduction is accumulated here
  U total_val = Op::init;
  threadgroup U local_vals[simd_size];

  // Loop over the reduction size within thread group
  int r = 0;
  for (; r < (int)ceildiv(reduction_size, N_READS*lsize) - 1; r++) {
    T vals[N_READS];
    for(int i = 0; i < N_READS; i++) {
      vals[i] = in[i];
    }
    for(int i = 0; i < N_READS; i++) {
      total_val = op(static_cast<U>(vals[i]), total_val);
    }

    in += lsize * N_READS;
  }

  // Separate case for the last set as we close the reduction size
  size_t reduction_index = (lid + (size_t)lsize * r) * N_READS;
  if(reduction_index < reduction_size) {
    int max_reads = reduction_size - reduction_index;

    // Stage the tail in the accumulation type U so that neither the
    // converted inputs nor the Op::init padding are narrowed through T
    U vals[N_READS];
    for(int i = 0; i < N_READS; i++) {
      // Clamp so out-of-range slots re-read the last valid element; they
      // are replaced by Op::init below
      int idx = min(i, max_reads - 1);
      vals[i] = static_cast<U>(in[idx]);
    }
    for(int i = 0; i < N_READS; i++) {
      U val = i < max_reads ? vals[i] : Op::init;
      total_val = op(val, total_val);
    }
  }

  total_val = op.simd_reduce(total_val);

  // Prepare next level
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Reduction within thread group
  // Only needed if multiple simd groups
  if(reduction_size > simd_size) {
    total_val = lid < simd_per_group ? local_vals[lid] : op.init;
    total_val = op.simd_reduce(total_val);
  }
  // Update output
  if (lid == 0) {
    out[tid] = total_val;
  }
}
|
||||
|
||||
// Explicitly instantiate row_reduce<itype, otype, op> under the host-visible
// name "row_reduce_" + name.
#define instantiate_row_reduce(name, itype, otype, op)       \
  template [[host_name("row_reduce_" #name)]]                \
  [[kernel]] void row_reduce<itype, otype, op>(              \
      const device itype *in [[buffer(0)]],                  \
      device otype *out [[buffer(1)]],                       \
      const device size_t& reduction_size [[buffer(2)]],     \
      uint lid [[thread_position_in_threadgroup]],           \
      uint lsize [[threads_per_threadgroup]],                \
      uint tid [[threadgroup_position_in_grid]],             \
      uint simd_lane_id [[thread_index_in_simdgroup]],       \
      uint simd_per_group [[simdgroups_per_threadgroup]],    \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Column reduce
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Shared body of the column / strided reduction kernels.
//
// Each thread reads up to N_READS inputs spaced `reduction_stride` apart,
// starting from in[in_idx], folds them with Op and stores its partial in
// `local_data`. The thread with lid.y == 0 then combines the lsize.y partials
// that share its lid.x and merges the result into `out` at `out_idx` with
// op.atomic_update.
//
// `local_data` is indexed up to lsize.y * lid.x + lid.y, so it must hold at
// least lsize.x * lsize.y elements.
//
// NOTE(review): sizes, strides and indices are 32-bit here while the calling
// kernels hold size_t values; presumably inputs stay below 2^32 elements -
// confirm.
// NOTE(review): the calling kernels invoke this only when
// out_idx < out_size, so threads failing that test never reach the barrier
// below - confirm the launch geometry keeps that safe.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
inline void _contiguous_strided_reduce(
    const device T *in,
    device mlx_atomic<U> *out,
    threadgroup U *local_data,
    uint in_idx,
    uint out_idx,
    uint reduction_size,
    uint reduction_stride,
    uint2 tid,
    uint2 lid,
    uint2 lsize) {

  Op op;
  T local_vals[N_READS];

  uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;

  // Clamp the offset so every read stays inside the reduction; the clamped
  // duplicates are skipped by the bounded loop below
  for(uint r = 0; r < N_READS; r++) {
    uint offset = base_offset + r;
    offset = offset < reduction_size ? offset : reduction_size - 1;
    local_vals[r] = in[in_idx + offset * reduction_stride];
  }

  U total_val = Op::init;
  for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
    // Convert the input element (type T) to the accumulation type; the
    // accumulator is already a U
    total_val = op(total_val, static_cast<U>(local_vals[r]));
  }
  local_data[lsize.y * lid.x + lid.y] = total_val;

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if(lid.y == 0) {
    U val = op.init;

    for(uint i = 0; i < lsize.y; i++) {
      val = op(val, local_data[lsize.y * lid.x + i]);
    }

    op.atomic_update(out, val, out_idx);
  }
}
|
||||
|
||||
// Column reduction: the output index tid.x * lsize.x + lid.x is used both as
// the starting input offset and as the output location. Threads whose output
// index is not below out_size do nothing.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint2 tid [[threadgroup_position_in_grid]],
    uint2 lid [[thread_position_in_threadgroup]],
    uint2 lsize [[threads_per_threadgroup]]) {
  auto out_idx = tid.x * lsize.x + lid.x;

  if (!(out_idx < out_size)) {
    return;
  }

  _contiguous_strided_reduce<T, U, Op, N_READS>(
      in,
      out,
      local_data,
      out_idx,
      out_idx,
      reduction_size,
      reduction_stride,
      tid,
      lid,
      lsize);
}
|
||||
|
||||
// Explicitly instantiate col_reduce<itype, otype, op> under the host-visible
// name "col_reduce_" + name.
#define instantiate_col_reduce(name, itype, otype, op)        \
  template [[host_name("col_reduce_" #name)]]                 \
  [[kernel]] void col_reduce<itype, otype, op>(               \
      const device itype *in [[buffer(0)]],                   \
      device mlx_atomic<otype> *out [[buffer(1)]],            \
      const constant size_t& reduction_size [[buffer(2)]],    \
      const constant size_t& reduction_stride [[buffer(3)]],  \
      const constant size_t& out_size [[buffer(4)]],          \
      threadgroup otype *local_data [[threadgroup(0)]],       \
      uint2 tid [[threadgroup_position_in_grid]],             \
      uint2 lid [[thread_position_in_threadgroup]],           \
      uint2 lsize [[threads_per_threadgroup]]);
|
||||
|
||||
// Strided reduction with the number of input dimensions fixed at compile
// time (NDIM).
//
// The output index is tid.x * lsize.x + lid.x; the matching input start is
// found with elem_to_loc_nd<NDIM> over in_shape / in_strides. Threads whose
// output index is not below out_size do nothing.
//
// N_READS defaults to REDUCE_N_READS, the same constant the other reduction
// kernels in this file use, instead of a separate literal.
template <typename T, typename U, typename Op, int NDIM, int N_READS = REDUCE_N_READS>
[[kernel]] void contiguous_strided_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const device int* in_shape [[buffer(5)]],
    const device size_t* in_strides [[buffer(6)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint2 tid [[threadgroup_position_in_grid]],
    uint2 lid [[thread_position_in_threadgroup]],
    uint2 lsize [[threads_per_threadgroup]]) {

  auto out_idx = tid.x * lsize.x + lid.x;
  auto in_idx = elem_to_loc_nd<NDIM>(out_idx, in_shape, in_strides);

  if(out_idx < out_size) {
    _contiguous_strided_reduce<T, U, Op, N_READS>(
        in,
        out,
        local_data,
        in_idx,
        out_idx,
        reduction_size,
        reduction_stride,
        tid,
        lid,
        lsize);
  }
}
|
||||
|
||||
// Strided reduction with the number of input dimensions read at run time
// (in_dim). The output index is tid.x * lsize.x + lid.x and the matching
// input start comes from elem_to_loc over in_shape / in_strides. Threads
// whose output index is not below out_size do nothing.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void contiguous_strided_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const device int* in_shape [[buffer(5)]],
    const device size_t* in_strides [[buffer(6)]],
    const device size_t& in_dim [[buffer(7)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint2 tid [[threadgroup_position_in_grid]],
    uint2 lid [[thread_position_in_threadgroup]],
    uint2 lsize [[threads_per_threadgroup]]) {

  auto out_idx = tid.x * lsize.x + lid.x;
  auto in_idx = elem_to_loc(out_idx, in_shape, in_strides, in_dim);

  if (!(out_idx < out_size)) {
    return;
  }

  _contiguous_strided_reduce<T, U, Op, N_READS>(
      in,
      out,
      local_data,
      in_idx,
      out_idx,
      reduction_size,
      reduction_stride,
      tid,
      lid,
      lsize);
}
|
||||
|
||||
// Explicitly instantiate the run-time-ndim contiguous_strided_reduce under
// the host-visible name "contiguous_strided_reduce_" + name.
#define instantiate_contiguous_strided_helper(name, itype, otype, op) \
  template [[host_name("contiguous_strided_reduce_" #name)]]          \
  [[kernel]] void contiguous_strided_reduce<itype, otype, op>(        \
      const device itype *in [[buffer(0)]],                           \
      device mlx_atomic<otype> *out [[buffer(1)]],                    \
      const constant size_t& reduction_size [[buffer(2)]],            \
      const constant size_t& reduction_stride [[buffer(3)]],          \
      const constant size_t& out_size [[buffer(4)]],                  \
      const device int* in_shape [[buffer(5)]],                       \
      const device size_t* in_strides [[buffer(6)]],                  \
      const device size_t& in_dim [[buffer(7)]],                      \
      threadgroup otype *local_data [[threadgroup(0)]],               \
      uint2 tid [[threadgroup_position_in_grid]],                     \
      uint2 lid [[thread_position_in_threadgroup]],                   \
      uint2 lsize [[threads_per_threadgroup]]);
|
||||
|
||||
// Explicitly instantiate the fixed-ndim contiguous_strided_reduce for n
// dimensions under the host-visible name
// "contiguous_strided_reduce_" + name + "_dim_" + n.
#define instantiate_contiguous_strided_helper_nd(name, itype, otype, op, n) \
  template [[host_name("contiguous_strided_reduce_" #name "_dim_" #n)]]     \
  [[kernel]] void contiguous_strided_reduce<itype, otype, op, n>(           \
      const device itype *in [[buffer(0)]],                                 \
      device mlx_atomic<otype> *out [[buffer(1)]],                          \
      const constant size_t& reduction_size [[buffer(2)]],                  \
      const constant size_t& reduction_stride [[buffer(3)]],                \
      const constant size_t& out_size [[buffer(4)]],                        \
      const device int* in_shape [[buffer(5)]],                             \
      const device size_t* in_strides [[buffer(6)]],                        \
      threadgroup otype *local_data [[threadgroup(0)]],                     \
      uint2 tid [[threadgroup_position_in_grid]],                           \
      uint2 lid [[thread_position_in_threadgroup]],                         \
      uint2 lsize [[threads_per_threadgroup]]);
|
||||
|
||||
// Instantiate contiguous_strided_reduce in its run-time-ndim form plus the
// fixed-ndim forms for 1 through 4 dimensions.
#define instantiate_contiguous_strided(name, itype, otype, op)        \
  instantiate_contiguous_strided_helper(name, itype, otype, op)       \
  instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 1) \
  instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 2) \
  instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 3) \
  instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 4)
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Instantiations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Instantiate every reduction kernel family (all, row, column, strided and
// general) for one (input type, output type, op) combination.
#define instantiate_reduce(name, itype, otype, op)       \
  instantiate_all_reduce(name, itype, otype, op)         \
  instantiate_row_reduce(name, itype, otype, op)         \
  instantiate_col_reduce(name, itype, otype, op)         \
  instantiate_contiguous_strided(name, itype, otype, op) \
  instantiate_general_reduce(name, itype, otype, op)
|
||||
|
||||
// Instantiate the init kernel and all reduction kernels for an op whose
// input and output types are the same; kernel names use name ## tname.
#define instantiate_same_reduce(name, tname, type, op)  \
  instantiate_init_reduce(name ##tname, type, op<type>) \
  instantiate_reduce(name ##tname, type, type, op<type>)
|
||||
|
||||
// Paste name and tname into a single kernel-name suffix and forward to
// instantiate_reduce.
#define instantiate_reduce_from_types_helper(name, tname, itype, otype, op) \
  instantiate_reduce(name ##tname, itype, otype, op)
|
||||
|
||||
// Instantiate the reduction kernels of `op` producing `otype` for every
// supported input type. uint64 is included alongside int64 so that unsigned
// 64-bit inputs get kernels as well.
#define instantiate_reduce_from_types(name, otype, op)                        \
  instantiate_reduce_from_types_helper(name, bool_, bool, otype, op)          \
  instantiate_reduce_from_types_helper(name, uint8, uint8_t, otype, op)       \
  instantiate_reduce_from_types_helper(name, uint16, uint16_t, otype, op)     \
  instantiate_reduce_from_types_helper(name, uint32, uint32_t, otype, op)     \
  instantiate_reduce_from_types_helper(name, uint64, uint64_t, otype, op)     \
  instantiate_reduce_from_types_helper(name, int8, int8_t, otype, op)         \
  instantiate_reduce_from_types_helper(name, int16, int16_t, otype, op)       \
  instantiate_reduce_from_types_helper(name, int32, int32_t, otype, op)       \
  instantiate_reduce_from_types_helper(name, int64, int64_t, otype, op)       \
  instantiate_reduce_from_types_helper(name, float16, half, otype, op)        \
  instantiate_reduce_from_types_helper(name, float32, float, otype, op)       \
  instantiate_reduce_from_types_helper(name, bfloat16, bfloat16_t, otype, op)
|
||||
|
||||
// special case bool with larger output type
instantiate_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)

// Sum
instantiate_same_reduce(sum, uint8, uint8_t, Sum)
instantiate_same_reduce(sum, uint16, uint16_t, Sum)
instantiate_same_reduce(sum, uint32, uint32_t, Sum)
instantiate_same_reduce(sum, int8, int8_t, Sum)
instantiate_same_reduce(sum, int16, int16_t, Sum)
instantiate_same_reduce(sum, int32, int32_t, Sum)
instantiate_same_reduce(sum, float16, half, Sum)
instantiate_same_reduce(sum, float32, float, Sum)
instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum)

// Prod
instantiate_same_reduce(prod, uint8, uint8_t, Prod)
instantiate_same_reduce(prod, uint16, uint16_t, Prod)
instantiate_same_reduce(prod, uint32, uint32_t, Prod)
instantiate_same_reduce(prod, int8, int8_t, Prod)
instantiate_same_reduce(prod, int16, int16_t, Prod)
instantiate_same_reduce(prod, int32, int32_t, Prod)
instantiate_same_reduce(prod, float16, half, Prod)
instantiate_same_reduce(prod, float32, float, Prod)
instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod)

// And / Or: every input type reduces to bool
instantiate_init_reduce(andbool_, bool, And)
instantiate_reduce_from_types(and, bool, And)
instantiate_init_reduce(orbool_, bool, Or)
instantiate_reduce_from_types(or, bool, Or)

// Compiler segfaulted with the names "min" or "max" ...
instantiate_same_reduce(min_, uint8, uint8_t, Min)
instantiate_same_reduce(min_, uint16, uint16_t, Min)
instantiate_same_reduce(min_, uint32, uint32_t, Min)
instantiate_same_reduce(min_, int8, int8_t, Min)
instantiate_same_reduce(min_, int16, int16_t, Min)
instantiate_same_reduce(min_, int32, int32_t, Min)
instantiate_same_reduce(min_, float16, half, Min)
instantiate_same_reduce(min_, float32, float, Min)
instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)

instantiate_same_reduce(max_, uint8, uint8_t, Max)
instantiate_same_reduce(max_, uint16, uint16_t, Max)
instantiate_same_reduce(max_, uint32, uint32_t, Max)
instantiate_same_reduce(max_, int8, int8_t, Max)
instantiate_same_reduce(max_, int16, int16_t, Max)
instantiate_same_reduce(max_, int32, int32_t, Max)
instantiate_same_reduce(max_, float16, half, Max)
instantiate_same_reduce(max_, float32, float, Max)
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)
|
284
mlx/backend/metal/kernels/unary.metal
Normal file
284
mlx/backend/metal/kernels/unary.metal
Normal file
@@ -0,0 +1,284 @@
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/erf.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
// Absolute value functor.
struct Abs {
  // Generic case defers to metal::abs
  template <typename T>
  T operator()(T x) {
    return metal::abs(x);
  }

  // Unsigned integers and bool are returned unchanged
  template <>
  uint8_t operator()(uint8_t x) {
    return x;
  }

  template <>
  uint16_t operator()(uint16_t x) {
    return x;
  }

  template <>
  uint32_t operator()(uint32_t x) {
    return x;
  }

  template <>
  uint64_t operator()(uint64_t x) {
    return x;
  }

  template <>
  bool operator()(bool x) {
    return x;
  }

  // Complex input: the magnitude goes in the real part, the imaginary part
  // is zero
  template <>
  complex64_t operator()(complex64_t x) {
    return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
  }
};
|
||||
|
||||
// Inverse trigonometric / hyperbolic functors; each forwards to the matching
// metal::precise function.
struct ArcCos {
  template <typename T>
  T operator()(T x) {
    return metal::precise::acos(x);
  }
};

struct ArcCosh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::acosh(x);
  }
};

struct ArcSin {
  template <typename T>
  T operator()(T x) {
    return metal::precise::asin(x);
  }
};

struct ArcSinh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::asinh(x);
  }
};

struct ArcTan {
  template <typename T>
  T operator()(T x) {
    return metal::precise::atan(x);
  }
};

struct ArcTanh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::atanh(x);
  }
};
|
||||
|
||||
// Cosine functor.
struct Cos {
  template <typename T>
  T operator()(T x) {
    return metal::precise::cos(x);
  }

  // cos(a + bi) = cos(a) cosh(b) - i sin(a) sinh(b)
  template <>
  complex64_t operator()(complex64_t x) {
    auto re = metal::precise::cos(x.real) * metal::precise::cosh(x.imag);
    auto im = -metal::precise::sin(x.real) * metal::precise::sinh(x.imag);
    return {re, im};
  }
};
|
||||
|
||||
// Hyperbolic cosine functor.
struct Cosh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::cosh(x);
  }

  // cosh(a + bi) = cosh(a) cos(b) + i sinh(a) sin(b)
  template <>
  complex64_t operator()(complex64_t x) {
    auto re = metal::precise::cosh(x.real) * metal::precise::cos(x.imag);
    auto im = metal::precise::sinh(x.real) * metal::precise::sin(x.imag);
    return {re, im};
  }
};
|
||||
|
||||
// Error function functor: evaluated on a float copy of the input and cast
// back to T.
struct Erf {
  template <typename T>
  T operator()(T x) {
    const float xf = static_cast<float>(x);
    return static_cast<T>(erf(xf));
  }
};

// Inverse error function functor: evaluated on a float copy of the input and
// cast back to T.
struct ErfInv {
  template <typename T>
  T operator()(T x) {
    const float xf = static_cast<float>(x);
    return static_cast<T>(erfinv(xf));
  }
};
|
||||
|
||||
// Exponential functor.
struct Exp {
  template <typename T>
  T operator()(T x) {
    return metal::precise::exp(x);
  }

  // exp(a + bi) = exp(a) * (cos(b) + i sin(b))
  template <>
  complex64_t operator()(complex64_t x) {
    auto magnitude = metal::precise::exp(x.real);
    return {
        magnitude * metal::precise::cos(x.imag),
        magnitude * metal::precise::sin(x.imag)};
  }
};
|
||||
|
||||
// Logarithm functors. Log, Log2 and Log10 forward to metal::precise; Log1p
// calls the unqualified log1p visible in this file.
struct Log {
  template <typename T>
  T operator()(T x) {
    return metal::precise::log(x);
  }
};

struct Log2 {
  template <typename T>
  T operator()(T x) {
    return metal::precise::log2(x);
  }
};

struct Log10 {
  template <typename T>
  T operator()(T x) {
    return metal::precise::log10(x);
  }
};

struct Log1p {
  template <typename T>
  T operator()(T x) {
    return log1p(x);
  }
};
|
||||
|
||||
// Logical negation functor: returns !x converted to T.
struct LogicalNot {
  template <typename T>
  T operator()(T x) {
    return !x;
  }
};

// Arithmetic negation functor: returns -x.
struct Negative {
  template <typename T>
  T operator()(T x) {
    return -x;
  }
};
|
||||
|
||||
// Logistic sigmoid functor.
struct Sigmoid {
  template <typename T>
  T operator()(T x) {
    // Evaluate at -|x| so the exponent is never positive, then reflect the
    // result for negative inputs
    auto y = 1 / (1 + metal::exp(-metal::abs(x)));
    if (x < 0) {
      return 1 - y;
    }
    return y;
  }
};
|
||||
|
||||
// Sign functor: (x > 0) - (x < 0) in the generic case.
struct Sign {
  template <typename T>
  T operator()(T x) {
    return (x > T(0)) - (x < T(0));
  }

  // uint32_t can never be negative, so only the non-zero test remains
  template <>
  uint32_t operator()(uint32_t x) {
    return x != 0;
  }
};
|
||||
|
||||
// Sine functor.
struct Sin {
  template <typename T>
  T operator()(T x) {
    return metal::precise::sin(x);
  }

  // sin(a + bi) = sin(a) cosh(b) + i cos(a) sinh(b)
  template <>
  complex64_t operator()(complex64_t x) {
    auto re = metal::precise::sin(x.real) * metal::precise::cosh(x.imag);
    auto im = metal::precise::cos(x.real) * metal::precise::sinh(x.imag);
    return {re, im};
  }
};
|
||||
|
||||
// Hyperbolic sine functor.
struct Sinh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::sinh(x);
  }

  // sinh(a + bi) = sinh(a) cos(b) + i cosh(a) sin(b)
  template <>
  complex64_t operator()(complex64_t x) {
    auto re = metal::precise::sinh(x.real) * metal::precise::cos(x.imag);
    auto im = metal::precise::cosh(x.real) * metal::precise::sin(x.imag);
    return {re, im};
  }
};
|
||||
|
||||
// Square functor: x * x.
struct Square {
  template <typename T>
  T operator()(T x) {
    return x * x;
  }
};

// Square root functor, via metal::precise::sqrt.
struct Sqrt {
  template <typename T>
  T operator()(T x) {
    return metal::precise::sqrt(x);
  }
};

// Reciprocal square root functor, via metal::precise::rsqrt.
struct Rsqrt {
  template <typename T>
  T operator()(T x) {
    return metal::precise::rsqrt(x);
  }
};
|
||||
|
||||
// Tangent functor.
struct Tan {
  template <typename T>
  T operator()(T x) {
    return metal::precise::tan(x);
  }

  // tan(a + bi) = (tan(a) + i tanh(b)) / (1 - i tan(a) tanh(b)), expanded by
  // multiplying through by the conjugate of the denominator
  template <>
  complex64_t operator()(complex64_t x) {
    float tan_a = metal::precise::tan(x.real);
    float tanh_b = metal::precise::tanh(x.imag);
    float cross = tan_a * tanh_b;
    float denom = 1. + cross * cross;
    float re = (tan_a - tanh_b * cross) / denom;
    float im = (tanh_b + tan_a * cross) / denom;
    return {re, im};
  }
};
|
||||
|
||||
// Hyperbolic tangent functor.
struct Tanh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::tanh(x);
  }

  // tanh(a + bi) = (tanh(a) + i tan(b)) / (1 + i tanh(a) tan(b)), expanded by
  // multiplying through by the conjugate of the denominator
  template <>
  complex64_t operator()(complex64_t x) {
    float tanh_a = metal::precise::tanh(x.real);
    float tan_b = metal::precise::tan(x.imag);
    float cross = tanh_a * tan_b;
    float denom = 1. + cross * cross;
    float re = (tanh_a + tan_b * cross) / denom;
    float im = (tan_b - tanh_a * cross) / denom;
    return {re, im};
  }
};
|
||||
|
||||
// Elementwise unary kernel where input and output share the same linear
// index: each thread applies Op to in[index] and stores it at out[index].
template <typename T, typename Op>
[[kernel]] void unary_op_v(
    device const T* in,
    device T* out,
    uint index [[thread_position_in_grid]]) {
  Op apply;
  out[index] = apply(in[index]);
}
|
||||
|
||||
// Elementwise unary kernel for inputs addressed through shape / strides:
// the input location is found with elem_to_loc, while the output is written
// at the linear thread index.
template <typename T, typename Op>
[[kernel]] void unary_op_g(
    device const T* in,
    device T* out,
    device const int* in_shape,
    device const size_t* in_strides,
    device const int& ndim,
    uint index [[thread_position_in_grid]]) {
  auto src = elem_to_loc(index, in_shape, in_strides, ndim);
  Op apply;
  out[index] = apply(in[src]);
}
|
||||
|
||||
// Explicitly instantiate unary_op_v<type, op>; `name` is the complete
// host-visible kernel name, passed as a string literal.
#define instantiate_unary_v(name, type, op) \
  template [[host_name(name)]]              \
  [[kernel]] void unary_op_v<type, op>(     \
      device const type* in,                \
      device type* out,                     \
      uint index [[thread_position_in_grid]]);
|
||||
|
||||
// Explicitly instantiate unary_op_g<type, op>; `name` is the complete
// host-visible kernel name, passed as a string literal.
#define instantiate_unary_g(name, type, op) \
  template [[host_name(name)]]              \
  [[kernel]] void unary_op_g<type, op>(     \
      device const type* in,                \
      device type* out,                     \
      device const int* in_shape,           \
      device const size_t* in_strides,      \
      device const int& ndim,               \
      uint index [[thread_position_in_grid]]);
|
||||
|
||||
// Instantiate both unary kernels for one (op, type) pair. Kernel names are
// "v" + name + tname and "g" + name + tname.
#define instantiate_unary_all(name, tname, type, op) \
  instantiate_unary_v("v" #name #tname, type, op)    \
  instantiate_unary_g("g" #name #tname, type, op)
|
||||
|
||||
// Instantiates `op` for the floating point types: half, float, bfloat16_t.
#define instantiate_unary_float(name, op)                     \
  instantiate_unary_all(name, float16, half, op)              \
  instantiate_unary_all(name, float32, float, op)             \
  instantiate_unary_all(name, bfloat16, bfloat16_t, op)
|
||||
|
||||
// Instantiates `op` for bool, every integer width (signed and unsigned) and,
// through instantiate_unary_float, the floating point types.
#define instantiate_unary_types(name, op)                     \
  instantiate_unary_all(name, bool_, bool, op)                \
  instantiate_unary_all(name, uint8, uint8_t, op)             \
  instantiate_unary_all(name, uint16, uint16_t, op)           \
  instantiate_unary_all(name, uint32, uint32_t, op)           \
  instantiate_unary_all(name, uint64, uint64_t, op)           \
  instantiate_unary_all(name, int8, int8_t, op)               \
  instantiate_unary_all(name, int16, int16_t, op)             \
  instantiate_unary_all(name, int32, int32_t, op)             \
  instantiate_unary_all(name, int64, int64_t, op)             \
  instantiate_unary_float(name, op)
|
||||
|
||||
// Kernels for bool, integer and floating point element types
// (instantiate_unary_types) or floating point element types only
// (instantiate_unary_float).
instantiate_unary_types(abs, Abs)
instantiate_unary_float(arccos, ArcCos)
instantiate_unary_float(arccosh, ArcCosh)
instantiate_unary_float(arcsin, ArcSin)
instantiate_unary_float(arcsinh, ArcSinh)
instantiate_unary_float(arctan, ArcTan)
instantiate_unary_float(arctanh, ArcTanh)
instantiate_unary_float(cos, Cos)
instantiate_unary_float(cosh, Cosh)
instantiate_unary_float(exp, Exp)
instantiate_unary_float(log, Log)
instantiate_unary_float(log2, Log2)
instantiate_unary_float(log10, Log10)
instantiate_unary_float(log1p, Log1p)
instantiate_unary_types(neg, Negative)
instantiate_unary_float(sigmoid, Sigmoid)
instantiate_unary_float(erf, Erf)
instantiate_unary_float(erfinv, ErfInv)
instantiate_unary_types(sign, Sign)
instantiate_unary_float(sin, Sin)
instantiate_unary_float(sinh, Sinh)
instantiate_unary_types(square, Square)
instantiate_unary_float(sqrt, Sqrt)
instantiate_unary_float(rsqrt, Rsqrt)
instantiate_unary_float(tan, Tan)
instantiate_unary_float(tanh, Tanh)

// complex64 kernels, instantiated one op at a time.
instantiate_unary_all(abs, complex64, complex64_t, Abs)
instantiate_unary_all(cos, complex64, complex64_t, Cos)
instantiate_unary_all(cosh, complex64, complex64_t, Cosh)
instantiate_unary_all(exp, complex64, complex64_t, Exp)
instantiate_unary_all(neg, complex64, complex64_t, Negative)
instantiate_unary_all(sin, complex64, complex64_t, Sin)
instantiate_unary_all(sinh, complex64, complex64_t, Sinh)
instantiate_unary_all(tan, complex64, complex64_t, Tan)
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)

// Logical not is only instantiated for bool.
instantiate_unary_all(lnot, bool_, bool, LogicalNot)
|
369
mlx/backend/metal/reduce.cpp
Normal file
369
mlx/backend/metal/reduce.cpp
Normal file
@@ -0,0 +1,369 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
// Case wise reduce dispatch
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace {
|
||||
|
||||
// All Reduce
|
||||
void all_reduce_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
metal::Device& d) {
|
||||
// Get kernel and encode buffers
|
||||
size_t in_size = in.size();
|
||||
auto kernel = d.get_kernel("all_reduce_" + op_name + type_to_name(in));
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
|
||||
|
||||
// Set grid dimensions
|
||||
|
||||
// We make sure each thread has enough to do by making it read in
|
||||
// atleast n_reads inputs
|
||||
int n_reads = REDUCE_N_READS;
|
||||
|
||||
// mod_in_size gives us the groups of n_reads needed to go over the entire
|
||||
// input
|
||||
uint mod_in_size = (in_size + n_reads - 1) / n_reads;
|
||||
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
thread_group_size =
|
||||
mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
|
||||
|
||||
// If the number of thread groups needed exceeds 1024, we reuse threads groups
|
||||
uint n_thread_groups =
|
||||
(mod_in_size + thread_group_size - 1) / thread_group_size;
|
||||
n_thread_groups = std::min(n_thread_groups, 1024u);
|
||||
uint nthreads = n_thread_groups * thread_group_size;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void row_reduce_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
const std::vector<int>& axes_,
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
metal::Device& d) {
|
||||
auto kernel = d.get_kernel("row_reduce_" + op_name + type_to_name(in));
|
||||
|
||||
int n_reads = REDUCE_N_READS;
|
||||
size_t reduction_size = in.size() / out.size();
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
|
||||
// Each thread group is responsible for 1 output
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
thread_group_size =
|
||||
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
|
||||
|
||||
// Align thread group size with simd_size
|
||||
uint simd_size = kernel->threadExecutionWidth();
|
||||
thread_group_size =
|
||||
(thread_group_size + simd_size - 1) / simd_size * simd_size;
|
||||
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
|
||||
// Launch enough thread groups for each output
|
||||
size_t n_threads = out.size() * thread_group_size;
|
||||
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void col_reduce_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
const std::vector<int>& axes_,
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
metal::Device& d) {
|
||||
std::ostringstream kernel_name;
|
||||
|
||||
bool encode_in_shape = false;
|
||||
bool encode_ndim = false;
|
||||
|
||||
// If the slowest moving axis can be merged into the reductions,
|
||||
// we call the column reduce kernel
|
||||
// In this case, a linear index in the output corresponds to the
|
||||
// linear index in the input where the reduction starts
|
||||
if (axes_[axes_.size() - 1] == (axes_.size() - 1)) {
|
||||
kernel_name << "col_reduce_" << op_name << type_to_name(in);
|
||||
}
|
||||
// Otherwise, while all the reduction axes can be merged, the mapping between
|
||||
// indices in the output and input require resolving using shapes and strides
|
||||
else {
|
||||
kernel_name << "contiguous_strided_reduce_" << op_name << type_to_name(in);
|
||||
encode_in_shape = true;
|
||||
|
||||
// We check for a viable template with the required number of dimensions
|
||||
// we only care about encoding non-reduced shapes and strides in the input
|
||||
size_t non_reducing_dims = in.ndim() - axes_.size();
|
||||
if (non_reducing_dims >= 1 &&
|
||||
non_reducing_dims <= MAX_REDUCE_SPECIALIZED_DIMS) {
|
||||
kernel_name << "_dim_" << non_reducing_dims;
|
||||
} else {
|
||||
encode_ndim = true;
|
||||
}
|
||||
}
|
||||
|
||||
auto kernel = d.get_kernel(kernel_name.str());
|
||||
size_t in_size = in.size();
|
||||
size_t out_size = out.size();
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
|
||||
// Calculate the number of inputs to reduce and the stride b/w them
|
||||
size_t reduction_size = 1;
|
||||
size_t in_ndim = in.ndim();
|
||||
size_t reduction_stride = in_size;
|
||||
|
||||
for (int i : axes_) {
|
||||
reduction_size *= in.shape(i);
|
||||
reduction_stride = std::min(reduction_stride, in.strides()[i]);
|
||||
}
|
||||
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
|
||||
if (encode_in_shape) {
|
||||
// Obtain the non-reducing shape and strides of the input to encode
|
||||
std::vector<int> inp_shape_mod;
|
||||
std::vector<size_t> inp_strides_mod;
|
||||
|
||||
for (size_t i = 0, j = 0; i < in.ndim(); i++) {
|
||||
if (j < axes_.size() && axes_[j] == i) {
|
||||
j++;
|
||||
} else {
|
||||
inp_shape_mod.push_back(in.shape(i));
|
||||
inp_strides_mod.push_back(in.strides()[i]);
|
||||
}
|
||||
}
|
||||
|
||||
size_t ndim = inp_shape_mod.size();
|
||||
|
||||
compute_encoder->setBytes(inp_shape_mod.data(), ndim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(inp_strides_mod.data(), ndim * sizeof(size_t), 6);
|
||||
|
||||
if (encode_ndim) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(size_t), 7);
|
||||
}
|
||||
}
|
||||
|
||||
// Select block dimensions
|
||||
|
||||
// Each thread reads 16 inputs to give it more work
|
||||
uint n_inputs_per_thread = REDUCE_N_READS;
|
||||
uint n_threads_per_output =
|
||||
(reduction_size + n_inputs_per_thread - 1) / n_inputs_per_thread;
|
||||
|
||||
// We spread outputs over the x dimension and inputs over the y dimension
|
||||
// Threads with the same lid.x in a given threadgroup work on the same
|
||||
// output and each thread in the y dimension accumlates for that output
|
||||
uint threadgroup_dim_x = std::min(out_size, 128ul);
|
||||
uint threadgroup_dim_y =
|
||||
kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;
|
||||
threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y);
|
||||
|
||||
uint n_threadgroups_x =
|
||||
(out_size + threadgroup_dim_x - 1) / threadgroup_dim_x;
|
||||
|
||||
uint n_threadgroups_y =
|
||||
(n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;
|
||||
|
||||
// Launch enough thread groups for each output
|
||||
MTL::Size grid_dims = MTL::Size(n_threadgroups_x, n_threadgroups_y, 1);
|
||||
MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);
|
||||
|
||||
// We set shared memory to be exploited here for reductions within a
|
||||
// threadgroup - each thread must be able to update its accumulated output
|
||||
// Note: Each threadgroup should have 32kB of data in threadgroup memory
|
||||
// and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
|
||||
// This should be fine for floats, but we might need to revisit
|
||||
// if we ever come to doubles. In that case, we should also cut
|
||||
// down the number of threads we launch in a threadgroup
|
||||
compute_encoder->setThreadgroupMemoryLength(
|
||||
threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 0);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void general_reduce_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
const std::vector<int>& axes_,
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
metal::Device& d) {
|
||||
bool encode_ndim = true;
|
||||
std::ostringstream kernel_name;
|
||||
kernel_name << "general_reduce_" << op_name << type_to_name(in);
|
||||
|
||||
// Check for specialzed kernels for input ndim
|
||||
if (in.ndim() >= 1 && in.ndim() <= MAX_REDUCE_SPECIALIZED_DIMS) {
|
||||
kernel_name << "_dim_" << in.ndim();
|
||||
encode_ndim = false;
|
||||
}
|
||||
auto kernel = d.get_kernel(kernel_name.str());
|
||||
size_t in_size = in.size();
|
||||
size_t ndim = in.ndim();
|
||||
|
||||
// We set the reducing strides to 0 to induce collisions for the reduction
|
||||
std::vector<size_t> out_strides(ndim);
|
||||
size_t stride = 1;
|
||||
for (int i = ndim - 1, j = axes_.size() - 1; i >= 0; --i) {
|
||||
if (j >= 0 && axes_[j] == i) {
|
||||
out_strides[i] = 0;
|
||||
--j;
|
||||
} else {
|
||||
out_strides[i] = stride;
|
||||
stride *= in.shape(i);
|
||||
}
|
||||
}
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(in.shape().data(), ndim * sizeof(int), 2);
|
||||
compute_encoder->setBytes(in.strides().data(), ndim * sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
|
||||
if (encode_ndim) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
|
||||
}
|
||||
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > in_size) {
|
||||
thread_group_size = in_size;
|
||||
}
|
||||
size_t nthreads = in_size;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
// Main reduce dispatch
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
// TODO: Allow specific row and column reductions with types disabled
|
||||
// due to atomics ?
|
||||
if (size_of(in.dtype()) == 8) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Reduce::eval_gpu] Does not support " << in.dtype();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Make sure no identity reductions trickle down here
|
||||
assert(!axes_.empty());
|
||||
|
||||
// Continue with reduction operation
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
std::string op_name;
|
||||
switch (reduce_type_) {
|
||||
case Reduce::And:
|
||||
op_name = "and";
|
||||
break;
|
||||
case Reduce::Or:
|
||||
op_name = "or";
|
||||
break;
|
||||
case Reduce::Sum:
|
||||
op_name = "sum";
|
||||
break;
|
||||
case Reduce::Prod:
|
||||
op_name = out.dtype() == bool_ ? "and" : "prod";
|
||||
break;
|
||||
case Reduce::Min:
|
||||
op_name = out.dtype() == bool_ ? "and" : "min_";
|
||||
break;
|
||||
case Reduce::Max:
|
||||
op_name = out.dtype() == bool_ ? "or" : "max_";
|
||||
break;
|
||||
}
|
||||
|
||||
// Initialize output
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel("i" + op_name + type_to_name(out));
|
||||
size_t nthreads = out.size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, out, 0);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Reduce
|
||||
{
|
||||
// Check for contiguous data
|
||||
if (in.size() == in.data_size() &&
|
||||
(in.flags().row_contiguous || in.flags().col_contiguous)) {
|
||||
// Go to all reduce if reducing over all axes
|
||||
if (axes_.size() == in.ndim()) {
|
||||
all_reduce_dispatch(in, out, op_name, compute_encoder, d);
|
||||
return;
|
||||
}
|
||||
// Use specialized kernels if the input is row contiguous and
|
||||
// the reducing axes can be merged into one
|
||||
else if (
|
||||
in.flags().row_contiguous && in.strides().back() == 1 &&
|
||||
(axes_.back() - axes_.front()) == axes_.size() - 1) {
|
||||
// If the fastest moving axis is being reduced, go to row reduce
|
||||
if (axes_[0] == (in.ndim() - axes_.size())) {
|
||||
row_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
|
||||
return;
|
||||
}
|
||||
// Otherwise go to to generalized strided reduce
|
||||
// Note: bool isn't support here yet due to the use of atomics
|
||||
// once that is updated, this should be the else condition of this
|
||||
// branch
|
||||
else if (in.dtype() != bool_) {
|
||||
col_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fall back to the general case
|
||||
general_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
Reference in New Issue
Block a user