Compare commits

...

2 Commits

Author SHA1 Message Date
Cheng
db5a7c6192 Add memory cache to CUDA backend (#2221)
* Move BufferCache out of allocator

* Add memory cache to cuda backend allocator

* Simplify BufferCache assuming buf can not be null
2025-05-30 12:12:54 -07:00
Awni Hannun
6ef2f67e7f 5bit quants (#2226)
* 5bit quants

* 5bit quants
2025-05-30 12:12:10 -07:00
12 changed files with 507 additions and 275 deletions

View File

@@ -0,0 +1,157 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cassert>
#include <functional>
#include <map>
namespace mlx::core {
template <typename T>
class BufferCache {
public:
BufferCache(
size_t page_size,
std::function<size_t(T*)> get_size,
std::function<void(T*)> free)
: page_size_(page_size),
get_size_(std::move(get_size)),
free_(std::move(free)) {}
~BufferCache() {
clear();
}
BufferCache(const BufferCache&) = delete;
BufferCache& operator=(const BufferCache&) = delete;
T* reuse_from_cache(size_t size) {
// Find the closest buffer in pool.
auto it = buffer_pool_.lower_bound(size);
if (it == buffer_pool_.end() ||
it->first >= std::min(2 * size, size + 2 * page_size_)) {
return nullptr;
}
// Collect from the cache.
T* buf = it->second->buf;
pool_size_ -= it->first;
// Remove from record.
remove_from_list(it->second);
buffer_pool_.erase(it);
return buf;
}
void recycle_to_cache(T* buf) {
assert(buf);
// Add to cache.
BufferHolder* bh = new BufferHolder(buf);
add_at_head(bh);
size_t size = get_size_(buf);
pool_size_ += size;
buffer_pool_.emplace(size, bh);
}
int release_cached_buffers(size_t min_bytes_to_free) {
if (min_bytes_to_free >= 0.9 * pool_size_) {
return clear();
} else {
int n_release = 0;
size_t total_bytes_freed = 0;
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
// Release buffer.
size_t size = get_size_(tail_->buf);
total_bytes_freed += size;
free_(tail_->buf);
n_release++;
// Remove from record.
auto its = buffer_pool_.equal_range(size);
auto it = std::find_if(its.first, its.second, [this](const auto& el) {
return el.second == tail_;
});
assert(it != buffer_pool_.end());
buffer_pool_.erase(it);
remove_from_list(tail_);
}
pool_size_ -= total_bytes_freed;
return n_release;
}
}
int clear() {
int n_release = 0;
for (auto& [size, holder] : buffer_pool_) {
free_(holder->buf);
n_release++;
delete holder;
}
buffer_pool_.clear();
pool_size_ = 0;
head_ = nullptr;
tail_ = nullptr;
return n_release;
}
size_t cache_size() const {
return pool_size_;
}
size_t page_size() const {
return page_size_;
}
private:
struct BufferHolder {
public:
explicit BufferHolder(T* buf_) : buf(buf_) {}
BufferHolder* prev{nullptr};
BufferHolder* next{nullptr};
T* buf;
};
void add_at_head(BufferHolder* to_add) {
if (!head_) {
head_ = to_add;
tail_ = to_add;
} else {
head_->prev = to_add;
to_add->next = head_;
head_ = to_add;
}
}
void remove_from_list(BufferHolder* to_remove) {
if (to_remove->prev && to_remove->next) { // if middle
to_remove->prev->next = to_remove->next;
to_remove->next->prev = to_remove->prev;
} else if (to_remove->prev && to_remove == tail_) { // if tail
tail_ = to_remove->prev;
tail_->next = nullptr;
} else if (to_remove == head_ && to_remove->next) { // if head
head_ = to_remove->next;
head_->prev = nullptr;
} else if (to_remove == head_ && to_remove == tail_) { // if only element
head_ = nullptr;
tail_ = nullptr;
}
delete to_remove;
}
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_{nullptr};
BufferHolder* tail_{nullptr};
size_t pool_size_{0};
const size_t page_size_;
std::function<size_t(T*)> get_size_;
std::function<void(T*)> free_;
};
} // namespace mlx::core

View File

@@ -13,9 +13,18 @@ namespace mlx::core {
namespace {
inline constexpr short get_pack_factor(int bits, int wsize = 8) {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {
auto power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
template <typename T, int bits>
void extract_bits(const uint8_t* w_in, T* w_out) {
assert(bits == 3 || bits == 6);
static_assert(bits == 3 || bits == 5 || bits == 6);
if (bits == 3) {
w_out[0] = static_cast<T>(w_in[0] & 0x7);
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
@@ -25,6 +34,16 @@ void extract_bits(const uint8_t* w_in, T* w_out) {
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
} else if (bits == 5) {
w_out[0] = static_cast<T>(w_in[0] & 0x1f);
w_out[1] = static_cast<T>(((w_in[0] & 0xe0) >> 5) + ((w_in[1] & 0x3) << 3));
w_out[2] = static_cast<T>((w_in[1] & 0x7c) >> 2);
w_out[3] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0xf) << 1));
w_out[4] = static_cast<T>(((w_in[2] & 0xf0) >> 4) + ((w_in[3] & 0x1) << 4));
w_out[5] = static_cast<T>((w_in[3] & 0x3e) >> 1);
w_out[6] = static_cast<T>(((w_in[3] & 0xc0) >> 6) + ((w_in[4] & 0x7) << 2));
w_out[7] = static_cast<T>((w_in[4] & 0xf8) >> 3);
} else if (bits == 6) {
w_out[0] = static_cast<T>(w_in[0] & 0x3f);
w_out[1] =
@@ -46,8 +65,8 @@ void _qmm(
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int pack_factor = get_pack_factor(bits, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(bits);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
@@ -65,7 +84,7 @@ void _qmm(
T scale = *scales_local++;
T bias = *biases_local++;
for (int ng = 0; ng < packs_in_group; ng++) {
if (bits == 3 || bits == 6) {
if constexpr (bits == 3 || bits == 5 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
@@ -104,8 +123,9 @@ void _qmm_t(
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int pack_factor = get_pack_factor(bits, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(bits);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
@@ -121,7 +141,7 @@ void _qmm_t(
T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw++) {
if (bits == 3 || bits == 6) {
if constexpr (bits == 3 || bits == 5 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
@@ -304,6 +324,10 @@ void _qmm_dispatch_typed(
_qmm_dispatch_group<T, 4>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 5:
_qmm_dispatch_group<T, 5>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 6:
_qmm_dispatch_group<T, 6>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
@@ -613,9 +637,8 @@ void quantize(
float eps = 1e-7;
bool power_of_2_bits = is_power_of_2(bits);
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
int bytes_per_pack = power_of_2_bits ? 1 : 3;
int el_per_int = get_pack_factor(bits, 32);
int bytes_per_pack = get_bytes_per_pack(bits);
int int_per_group = group_size * bytes_per_pack / el_per_int;
size_t n_groups = w_size / group_size;
@@ -640,15 +663,21 @@ void quantize(
}
size_t out_idx = i * int_per_group;
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
uint32_t out_el = 0;
uint64_t out_el = 0;
for (int k = 0; k < el_per_int; ++k) {
float w_el = w[w_idx + j * el_per_int + k];
w_el = std::rint((w_el - bias) / scale);
w_el = std::min(std::max(w_el, 0.0f), n_bins);
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
out_el |= static_cast<uint64_t>(w_el) << (k * bits);
}
if (power_of_2_bits) {
out[out_idx + j] = out_el;
} else if (bits == 5) {
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
out[out_idx + bytes_per_pack * j + 3] = (out_el & 0xff000000) >> 24;
out[out_idx + bytes_per_pack * j + 4] = (out_el & 0xff00000000) >> 32;
} else {
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;

View File

@@ -6,6 +6,7 @@
#include <cuda_runtime.h>
#include <fmt/format.h>
#include <unistd.h>
#include <cassert>
@@ -13,24 +14,47 @@ namespace mlx::core {
namespace cu {
CudaAllocator::CudaAllocator() {
CudaAllocator::CudaAllocator()
: buffer_cache_(
getpagesize(),
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.8;
max_pool_size_ = memory_limit_;
}
Buffer CudaAllocator::malloc(size_t size) {
// TODO: Check memory limit.
auto* buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(
fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
// Find available buffer from cache.
std::unique_lock lock(mutex_);
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) {
// If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache.
size_t mem_required = get_active_memory() + get_cache_memory() + size;
if (mem_required >= memory_limit_) {
buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
}
lock.unlock();
buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
lock.lock();
}
std::lock_guard lock(mutex_);
active_memory_ += size;
peak_memory_ = std::max(active_memory_, peak_memory_);
// Maintain the cache below the requested limit.
if (get_cache_memory() > max_pool_size_) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
return Buffer{buf};
}
@@ -40,26 +64,14 @@ void CudaAllocator::free(Buffer buffer) {
return;
}
// If free() is called from a unregistered thread, reschedule the call to
// worker.
{
std::lock_guard lock(worker_mutex_);
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
if (!worker_) {
worker_.reset(new Worker);
}
worker_->add_task([buffer]() { allocator().free(buffer); });
worker_->end_batch();
worker_->commit();
return;
}
std::unique_lock lock(mutex_);
active_memory_ -= buf->size;
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
} else {
lock.unlock();
cuda_free(buf);
}
size_t size = buf->size;
cudaFree(buf->data);
delete buf;
std::lock_guard lock(mutex_);
active_memory_ -= size;
}
size_t CudaAllocator::size(Buffer buffer) const {
@@ -98,6 +110,41 @@ size_t CudaAllocator::set_memory_limit(size_t limit) {
return limit;
}
size_t CudaAllocator::get_cache_memory() const {
return buffer_cache_.cache_size();
}
size_t CudaAllocator::set_cache_limit(size_t limit) {
std::lock_guard lk(mutex_);
std::swap(limit, max_pool_size_);
return limit;
}
void CudaAllocator::clear_cache() {
std::lock_guard lk(mutex_);
buffer_cache_.clear();
}
void CudaAllocator::cuda_free(CudaBuffer* buf) {
// If cuda_free() is called from a unregistered thread, reschedule the call to
// worker.
{
std::lock_guard lock(worker_mutex_);
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
if (!worker_) {
worker_.reset(new Worker);
}
worker_->add_task([this, buf]() { this->cuda_free(buf); });
worker_->end_batch();
worker_->commit();
return;
}
}
cudaFree(buf->data);
delete buf;
}
CudaAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of CudaAllocator
// will not be called on exit and buffers in the cache will be leaked. This
@@ -138,17 +185,19 @@ size_t set_memory_limit(size_t limit) {
size_t get_memory_limit() {
return cu::allocator().get_memory_limit();
}
// TODO: Implement buffer cache.
size_t get_cache_memory() {
return 0;
return cu::allocator().get_cache_memory();
}
size_t set_cache_limit(size_t) {
return 0;
size_t set_cache_limit(size_t limit) {
return cu::allocator().set_cache_limit(limit);
}
void clear_cache() {
cu::allocator().clear_cache();
}
// Not supported in CUDA.
size_t set_wired_limit(size_t) {
return 0;
}
void clear_cache() {}
} // namespace mlx::core

View File

@@ -3,6 +3,7 @@
#pragma once
#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include <mutex>
#include <set>
@@ -38,17 +39,24 @@ class CudaAllocator : public allocator::Allocator {
void reset_peak_memory();
size_t get_memory_limit();
size_t set_memory_limit(size_t limit);
size_t get_cache_memory() const;
size_t set_cache_limit(size_t limit);
void clear_cache();
private:
CudaAllocator();
friend CudaAllocator& allocator();
void cuda_free(CudaBuffer* buf);
std::mutex worker_mutex_;
std::unique_ptr<Worker> worker_;
std::set<std::thread::id> allowed_threads_;
std::mutex mutex_;
size_t memory_limit_;
size_t max_pool_size_;
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
};

View File

@@ -30,141 +30,18 @@ void* Buffer::raw_ptr() {
namespace metal {
namespace {
BufferCache::BufferCache(ResidencySet& residency_set)
: head_(nullptr),
tail_(nullptr),
pool_size_(0),
residency_set_(residency_set) {}
BufferCache::~BufferCache() {
auto pool = metal::new_scoped_memory_pool();
clear();
}
int BufferCache::clear() {
int n_release = 0;
for (auto& [size, holder] : buffer_pool_) {
if (holder->buf) {
if (!holder->buf->heap()) {
residency_set_.erase(holder->buf);
}
holder->buf->release();
n_release++;
}
delete holder;
}
buffer_pool_.clear();
pool_size_ = 0;
head_ = nullptr;
tail_ = nullptr;
return n_release;
}
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
// Find the closest buffer in pool
MTL::Buffer* pbuf = nullptr;
auto it = buffer_pool_.lower_bound(size);
// Make sure we use most of the available memory
while (!pbuf && it != buffer_pool_.end() &&
it->first < std::min(2 * size, size + 2 * vm_page_size)) {
// Collect from the cache
pbuf = it->second->buf;
// Remove from cache
remove_from_list(it->second);
delete it->second;
it = buffer_pool_.erase(it);
}
if (pbuf) {
pool_size_ -= pbuf->length();
}
return pbuf;
}
void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
// Add to cache
if (buf) {
BufferHolder* bh = new BufferHolder(buf);
add_at_head(bh);
pool_size_ += buf->length();
buffer_pool_.insert({buf->length(), bh});
}
}
int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
if (min_bytes_to_free >= 0.9 * pool_size_) {
return clear();
} else {
int n_release = 0;
size_t total_bytes_freed = 0;
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
if (tail_->buf) {
total_bytes_freed += tail_->buf->length();
if (!tail_->buf->heap()) {
residency_set_.erase(tail_->buf);
}
tail_->buf->release();
tail_->buf = nullptr;
n_release++;
}
remove_from_list(tail_);
}
pool_size_ -= total_bytes_freed;
return n_release;
}
}
void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
if (!to_add)
return;
if (!head_) {
head_ = to_add;
tail_ = to_add;
} else {
head_->prev = to_add;
to_add->next = head_;
head_ = to_add;
}
}
void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
if (!to_remove) {
return;
}
// If in the middle
if (to_remove->prev && to_remove->next) {
to_remove->prev->next = to_remove->next;
to_remove->next->prev = to_remove->prev;
} else if (to_remove->prev && to_remove == tail_) { // If tail
tail_ = to_remove->prev;
tail_->next = nullptr;
} else if (to_remove == head_ && to_remove->next) { // If head
head_ = to_remove->next;
head_->prev = nullptr;
} else if (to_remove == head_ && to_remove == tail_) { // If only element
head_ = nullptr;
tail_ = nullptr;
}
to_remove->prev = nullptr;
to_remove->next = nullptr;
}
} // namespace
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_),
buffer_cache_(residency_set_) {
buffer_cache_(
vm_page_size,
[](MTL::Buffer* buf) { return buf->length(); },
[this](MTL::Buffer* buf) {
if (!buf->heap()) {
residency_set_.erase(buf);
}
buf->release();
}) {
auto pool = metal::new_scoped_memory_pool();
auto memsize = std::get<size_t>(device_info().at("memory_size"));
auto max_rec_size =
@@ -193,6 +70,7 @@ MetalAllocator::~MetalAllocator() {
if (heap_) {
heap_->release();
}
buffer_cache_.clear();
}
size_t MetalAllocator::set_cache_limit(size_t limit) {

View File

@@ -7,6 +7,7 @@
#include <vector>
#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/resident.h"
@@ -14,43 +15,6 @@ namespace mlx::core::metal {
using allocator::Buffer;
namespace {
class BufferCache {
public:
BufferCache(ResidencySet& residency_set);
~BufferCache();
MTL::Buffer* reuse_from_cache(size_t size);
void recycle_to_cache(MTL::Buffer* buf);
int release_cached_buffers(size_t min_bytes_to_free);
size_t cache_size() {
return pool_size_;
}
int clear();
private:
struct BufferHolder {
public:
BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
BufferHolder* prev;
BufferHolder* next;
MTL::Buffer* buf;
};
void add_at_head(BufferHolder* to_add);
void remove_from_list(BufferHolder* to_remove);
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_;
BufferHolder* tail_;
size_t pool_size_;
ResidencySet& residency_set_;
};
} // namespace
class MetalAllocator : public allocator::Allocator {
/** Allocator for Metal GPUs. */
public:
@@ -90,7 +54,7 @@ class MetalAllocator : public allocator::Allocator {
friend MetalAllocator& allocator();
// Caching allocator
BufferCache buffer_cache_;
BufferCache<MTL::Buffer> buffer_cache_;
ResidencySet residency_set_;

View File

@@ -14,11 +14,23 @@ using namespace metal;
MLX_MTL_CONST int SIMD_SIZE = 32;
MLX_MTL_CONST int QUAD_SIZE = 4;
template <int bits, int wsize = 8>
inline constexpr short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
U sum = 0;
@@ -57,6 +69,21 @@ inline U load_vector(const device T* x, thread U* x_thread) {
}
}
else if (bits == 5) {
for (int i = 0; i < values_per_thread; i += 8) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
x[i + 6] + x[i + 7];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 32.0f;
x_thread[i + 2] = x[i + 2] / 4.0f;
x_thread[i + 3] = x[i + 3] / 128.0f;
x_thread[i + 4] = x[i + 4] / 16.0f;
x_thread[i + 5] = x[i + 5] / 2.0f;
x_thread[i + 6] = x[i + 6] / 64.0f;
x_thread[i + 7] = x[i + 7] / 8.0f;
}
}
else if (bits == 6) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
@@ -80,8 +107,9 @@ inline U load_vector(const device T* x, thread U* x_thread) {
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
U sum = 0;
@@ -121,6 +149,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
}
}
else if (bits == 5) {
for (int i = 0; i < N; i += 8) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
x[i + 6] + x[i + 7];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 32.0f;
x_thread[i + 2] = x[i + 2] / 4.0f;
x_thread[i + 3] = x[i + 3] / 128.0f;
x_thread[i + 4] = x[i + 4] / 16.0f;
x_thread[i + 5] = x[i + 5] / 2.0f;
x_thread[i + 6] = x[i + 6] / 64.0f;
x_thread[i + 7] = x[i + 7] / 8.0f;
}
}
else if (bits == 6) {
for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
@@ -153,8 +196,9 @@ inline U qdot(
U bias,
U sum) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
U accum = 0;
@@ -199,6 +243,26 @@ inline U qdot(
}
}
else if (bits == 5) {
for (int i = 0; i < (values_per_thread / 8); i++) {
x_thread += 8 * i;
w += 5 * i;
accum += (w[0] & 0x1f) * x_thread[0];
accum += (w[0] & 0xe0) * x_thread[1];
accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
accum += (w[1] & 0x7c) * x_thread[2];
accum += (w[1] & 0x80) * x_thread[3];
accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
accum += (w[2] & 0xf0) * x_thread[4];
accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
accum += (w[3] & 0x3e) * x_thread[5];
accum += (w[3] & 0xc0) * x_thread[6];
accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
accum += (w[4] & 0xf8) * x_thread[7];
}
}
else if (bits == 6) {
for (int i = 0; i < (values_per_thread / 4); i++) {
x_thread += 4 * i;
@@ -234,8 +298,9 @@ inline U qdot_safe(
U sum,
int N) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
U accum = 0;
@@ -280,6 +345,26 @@ inline U qdot_safe(
}
}
else if (bits == 5) {
for (int i = 0; i < (N / 8); i++) {
x_thread += 8 * i;
w += 5 * i;
accum += (w[0] & 0x1f) * x_thread[0];
accum += (w[0] & 0xe0) * x_thread[1];
accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
accum += (w[1] & 0x7c) * x_thread[2];
accum += (w[1] & 0x80) * x_thread[3];
accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
accum += (w[2] & 0xf0) * x_thread[4];
accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
accum += (w[3] & 0x3e) * x_thread[5];
accum += (w[3] & 0xc0) * x_thread[6];
accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
accum += (w[4] & 0xf8) * x_thread[7];
}
}
else if (bits == 6) {
for (int i = 0; i < (N / 4); i++) {
x_thread += 4 * i;
@@ -310,8 +395,9 @@ template <typename U, int values_per_thread, int bits>
inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
if (bits == 2) {
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
@@ -348,8 +434,31 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
}
}
} else if (bits == 6) {
else if (bits == 5) {
for (int i = 0; i < (values_per_thread / 8); i++) {
uint8_t w0 = w[5 * i];
uint8_t w1 = w[5 * i + 1];
uint8_t w2 = w[5 * i + 2];
uint8_t w3 = w[5 * i + 3];
uint8_t w4 = w[5 * i + 4];
result[8 * i] += x * ((w0 & 0x1f) * scale + bias);
result[8 * i + 1] +=
x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias);
result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias);
result[8 * i + 3] +=
x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias);
result[8 * i + 4] +=
x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias);
result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias);
result[8 * i + 6] +=
x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias);
result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias);
}
}
else if (bits == 6) {
for (int i = 0; i < (values_per_thread / 4); i++) {
uint8_t w0 = w[3 * i];
uint8_t w1 = w[3 * i + 1];
@@ -375,8 +484,9 @@ template <typename U, int N, int bits>
inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
if (bits == 2) {
U s[4] = {
@@ -416,11 +526,26 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
}
}
else if (bits == 5) {
for (int i = 0; i < (N / 8); i++) {
w_local += 8 * i;
w += 5 * i;
w_local[0] = (w[0] & 0x1f) * scale + bias;
w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
}
}
else if (bits == 6) {
for (int i = 0; i < (N / 4); i++) {
w_local += 4 * i;
w += 3 * i;
w_local[0] = (w[0] & 0x3f) * scale + bias;
w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
@@ -452,11 +577,12 @@ struct QuantizedBlockLoader {
group_size % BCOLS == 0,
"The group size should be divisible by the columns");
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();
MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
MLX_MTL_CONST short n_reads =
(BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
@@ -632,12 +758,11 @@ METAL_FUNC void qmv_fast_impl(
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int packs_per_thread = bits == 2 ? 1 : 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int pack_factor = get_pack_factor<bits, 32>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
@@ -700,12 +825,12 @@ METAL_FUNC void qmv_impl(
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int packs_per_thread = 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int pack_factor = get_pack_factor<bits, 32>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
@@ -857,8 +982,9 @@ METAL_FUNC void qvm_impl(
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int num_simdgroups = 2;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
constexpr int pack_factor = get_pack_factor<bits, 32>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
constexpr int tn = 32 / pack_factor;
constexpr int block_size = SIMD_SIZE;
@@ -981,9 +1107,10 @@ METAL_FUNC void qmm_t_impl(
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::
@@ -1106,11 +1233,11 @@ METAL_FUNC void qmm_n_impl(
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::
@@ -2120,11 +2247,10 @@ template <
uint3 tid [[threadgroup_position_in_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
using mma_t = mlx::steel::BlockMMA<
T,
@@ -2305,13 +2431,13 @@ template <typename T, const int group_size, const int bits>
constexpr float eps = 1e-7;
constexpr int simd_size = 32;
constexpr float n_bins = (1 << bits) - 1;
constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
constexpr int values_per_reduce = group_size / simd_size;
constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
constexpr int writes_per_reduce = pack_factor / values_per_reduce;
constexpr int writes_per_pack =
writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
static_assert(
group_size % simd_size == 0,
@@ -2354,8 +2480,8 @@ template <typename T, const int group_size, const int bits>
biases[gindex] = static_cast<T>(bias);
}
// We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
uint32_t output = 0;
using OutType = metal::conditional_t<bits == 5, uint64_t, uint32_t>;
OutType output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
@@ -2363,27 +2489,35 @@ template <typename T, const int group_size, const int bits>
if (bits == 8) {
output = val;
} else {
output += val << (bits * (i % packs_per_int));
output |= val << (bits * (i % pack_factor));
}
if (packs_per_int < values_per_reduce &&
i % packs_per_int == packs_per_int - 1) {
out[out_index + i / packs_per_int] = output;
if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) {
out[out_index + i / pack_factor] = output;
output = 0;
} else {
#pragma clang loop unroll(full)
for (int j = 1; j < writes_per_reduce; j++) {
uint8_t sval = simd_shuffle_down(val, j);
output += sval << (bits * (j * values_per_reduce + i));
output |= static_cast<OutType>(sval)
<< (bits * (j * values_per_reduce + i));
}
}
}
if (bits == 3 || bits == 6) {
if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) {
if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
out[out_index] = output & 0xff;
out[out_index + 1] = (output & 0xff00) >> 8;
out[out_index + 2] = (output & 0xff0000) >> 16;
}
} else if (bits == 5) {
if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
out[out_index] = output & 0xff;
out[out_index + 1] = (output & 0xff00) >> 8;
out[out_index + 2] = (output & 0xff0000) >> 16;
out[out_index + 3] = (output & 0xff000000) >> 24;
out[out_index + 4] = (output & 0xff00000000) >> 32;
}
} else {
if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
out[out_index / writes_per_reduce] = output;
@@ -2399,12 +2533,11 @@ template <typename T, const int group_size, const int bits>
device T* out [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
size_t offset = index.x + grid_dim.x * size_t(index.y);
size_t oindex = offset * packs_per_int;
size_t oindex = offset * pack_factor;
size_t gindex = oindex / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
@@ -2421,7 +2554,16 @@ template <typename T, const int group_size, const int bits>
out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
} else if (bits == 5) {
w += offset * bytes_per_pack;
out[0] = (w[0] & 0x1f) * scale + bias;
out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
out[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
out[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
out[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
} else if (bits == 6) {
w += offset * bytes_per_pack;
out[0] = (w[0] & 0x3f) * scale + bias;
@@ -2431,7 +2573,7 @@ template <typename T, const int group_size, const int bits>
} else {
uint val = w[offset];
#pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) {
for (int i = 0; i < pack_factor; i++) {
uint8_t d;
if (bits == 2) {
d = (val >> (bits * i)) & 0x03;

View File

@@ -136,6 +136,7 @@
instantiate_quantized_groups(2) \
instantiate_quantized_groups(3) \
instantiate_quantized_groups(4) \
instantiate_quantized_groups(5) \
instantiate_quantized_groups(6) \
instantiate_quantized_groups(8)

View File

@@ -976,7 +976,9 @@ void fast::AffineQuantize::eval_gpu(
// Treat uint32 as uint8 in kernel
constexpr int uint8_per_uint32 = 4;
constexpr int simd_size = 32;
int packs_per_int = bits_ == 3 ? 8 : bits_ == 6 ? 4 : 8 / bits_;
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
: bits_ == 6 ? 4
: 8 / bits_;
int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size;
size_t nthreads =
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;

View File

@@ -839,14 +839,14 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
if (group_size != 32 && group_size != 64 && group_size != 128) {
std::ostringstream msg;
msg << "[quantize] The requested group size " << group_size
<< " is not supported. The supported group sizes are 64 and 128.";
<< " is not supported. The supported group sizes are 32, 64, and 128.";
throw std::invalid_argument(msg.str());
}
if (bits != 2 && bits != 3 && bits != 4 && bits != 6 && bits != 8) {
if (bits < 2 || bits > 8 || bits == 7) {
std::ostringstream msg;
msg << "[quantize] The requested number of bits " << bits
<< " is not supported. The supported bits are 2, 3, 4, 6 and 8.";
<< " is not supported. The supported bits are 2, 3, 4, 5, 6 and 8.";
throw std::invalid_argument(msg.str());
}

View File

@@ -11,7 +11,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
def test_quantize_dequantize(self):
w = mx.random.normal(shape=(128, 512))
for gs in [32, 64, 128]:
for b in [2, 3, 6, 4, 8]:
for b in [2, 3, 5, 6, 4, 8]:
with self.subTest(gs=gs, b=b):
w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b)
w_hat = mx.dequantize(w_q, scales, biases, gs, b)
@@ -22,7 +22,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
# test quantize/dequantize 0s
a = mx.zeros((256, 512))
for gs in [32, 64, 128]:
for b in [2, 3, 4, 6, 8]:
for b in [2, 3, 4, 5, 6, 8]:
w_q, scales, biases = mx.quantize(a, gs, b)
a_hat = mx.dequantize(w_q, scales, biases, gs, b)
self.assertTrue(mx.all(a_hat == 0))
@@ -146,7 +146,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
k1, k2 = mx.random.split(key)
tests = product(
[128, 64, 32], # group_size
[2, 3, 4, 6, 8], # bits
[2, 3, 4, 5, 6, 8], # bits
[256, 512, 67], # M
[64, 128], # N
[0, 1, 3, 8], # B
@@ -173,7 +173,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
k1, k2 = mx.random.split(key)
tests = product(
[128, 64, 32], # group_size
[2, 3, 4, 6, 8], # bits
[2, 3, 4, 5, 6, 8], # bits
[32, 128, 256], # M
[128, 256, 67], # N
[0, 1, 3, 8], # B

View File

@@ -634,6 +634,7 @@ class TestVmap(mlx_tests.MLXTestCase):
self.assertEqual(fy.shape, (4, 5, 6, 7))
def test_leaks(self):
gc.collect()
mx.synchronize()
if mx.metal.is_available():
mem_pre = mx.get_active_memory()
@@ -653,6 +654,7 @@ class TestVmap(mlx_tests.MLXTestCase):
outer()
gc.collect()
mx.synchronize()
if mx.metal.is_available():
mem_post = mx.get_active_memory()
else: