Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00

Compare commits: 4cbe605214 ... v0.26.0 (12 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 0408ba0a76 |  |
|  | cbad6c3093 |  |
|  | 1b021f6984 |  |
|  | 95b7551d65 |  |
|  | db5a7c6192 |  |
|  | 6ef2f67e7f |  |
|  | f76ee1ffd2 |  |
|  | 54a71f270a |  |
|  | 55b4062dd8 |  |
|  | 79071bfba4 |  |
|  | 7774b87cbd |  |
|  | 35c87741cf |  |
@@ -231,6 +231,9 @@ target_include_directories(
   mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
   $<INSTALL_INTERFACE:include>)

+# Do not add mlx_EXPORTS define for shared library.
+set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
+
 FetchContent_Declare(
   fmt
   GIT_REPOSITORY https://github.com/fmtlib/fmt.git
@@ -10,7 +10,7 @@ import mlx.core as mx
 # -- Project information -----------------------------------------------------

 project = "MLX"
-copyright = "2023, MLX Contributors"
+copyright = "2023, Apple"
 author = "MLX Contributors"
 version = ".".join(mx.__version__.split(".")[:3])
 release = version
mlx/backend/common/buffer_cache.h (new file, 157 lines)
@@ -0,0 +1,157 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <algorithm>
#include <cassert>
#include <functional>
#include <map>

namespace mlx::core {

template <typename T>
class BufferCache {
 public:
  BufferCache(
      size_t page_size,
      std::function<size_t(T*)> get_size,
      std::function<void(T*)> free)
      : page_size_(page_size),
        get_size_(std::move(get_size)),
        free_(std::move(free)) {}

  ~BufferCache() {
    clear();
  }

  BufferCache(const BufferCache&) = delete;
  BufferCache& operator=(const BufferCache&) = delete;

  T* reuse_from_cache(size_t size) {
    // Find the closest buffer in pool.
    auto it = buffer_pool_.lower_bound(size);
    if (it == buffer_pool_.end() ||
        it->first >= std::min(2 * size, size + 2 * page_size_)) {
      return nullptr;
    }

    // Collect from the cache.
    T* buf = it->second->buf;
    pool_size_ -= it->first;

    // Remove from record.
    remove_from_list(it->second);
    buffer_pool_.erase(it);
    return buf;
  }

  void recycle_to_cache(T* buf) {
    assert(buf);
    // Add to cache.
    BufferHolder* bh = new BufferHolder(buf);
    add_at_head(bh);
    size_t size = get_size_(buf);
    pool_size_ += size;
    buffer_pool_.emplace(size, bh);
  }

  int release_cached_buffers(size_t min_bytes_to_free) {
    if (min_bytes_to_free >= 0.9 * pool_size_) {
      return clear();
    } else {
      int n_release = 0;
      size_t total_bytes_freed = 0;

      while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
        // Release buffer.
        size_t size = get_size_(tail_->buf);
        total_bytes_freed += size;
        free_(tail_->buf);
        n_release++;

        // Remove from record.
        auto its = buffer_pool_.equal_range(size);
        auto it = std::find_if(its.first, its.second, [this](const auto& el) {
          return el.second == tail_;
        });
        assert(it != buffer_pool_.end());
        buffer_pool_.erase(it);
        remove_from_list(tail_);
      }

      pool_size_ -= total_bytes_freed;
      return n_release;
    }
  }

  int clear() {
    int n_release = 0;
    for (auto& [size, holder] : buffer_pool_) {
      free_(holder->buf);
      n_release++;
      delete holder;
    }
    buffer_pool_.clear();
    pool_size_ = 0;
    head_ = nullptr;
    tail_ = nullptr;
    return n_release;
  }

  size_t cache_size() const {
    return pool_size_;
  }

  size_t page_size() const {
    return page_size_;
  }

 private:
  struct BufferHolder {
   public:
    explicit BufferHolder(T* buf_) : buf(buf_) {}

    BufferHolder* prev{nullptr};
    BufferHolder* next{nullptr};
    T* buf;
  };

  void add_at_head(BufferHolder* to_add) {
    if (!head_) {
      head_ = to_add;
      tail_ = to_add;
    } else {
      head_->prev = to_add;
      to_add->next = head_;
      head_ = to_add;
    }
  }

  void remove_from_list(BufferHolder* to_remove) {
    if (to_remove->prev && to_remove->next) { // if middle
      to_remove->prev->next = to_remove->next;
      to_remove->next->prev = to_remove->prev;
    } else if (to_remove->prev && to_remove == tail_) { // if tail
      tail_ = to_remove->prev;
      tail_->next = nullptr;
    } else if (to_remove == head_ && to_remove->next) { // if head
      head_ = to_remove->next;
      head_->prev = nullptr;
    } else if (to_remove == head_ && to_remove == tail_) { // if only element
      head_ = nullptr;
      tail_ = nullptr;
    }

    delete to_remove;
  }

  std::multimap<size_t, BufferHolder*> buffer_pool_;
  BufferHolder* head_{nullptr};
  BufferHolder* tail_{nullptr};
  size_t pool_size_{0};

  const size_t page_size_;
  std::function<size_t(T*)> get_size_;
  std::function<void(T*)> free_;
};

} // namespace mlx::core
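Since the header above is self-contained, a minimal sketch of driving it from a hypothetical backend follows; the DummyBuffer type and the sizes are illustrative only, not part of the change:

#include "mlx/backend/common/buffer_cache.h"

#include <cstdint>
#include <cstdlib>
#include <iostream>

// Hypothetical stand-in for CudaBuffer / MTL::Buffer.
struct DummyBuffer {
  void* data;
  size_t size;
};

int main() {
  mlx::core::BufferCache<DummyBuffer> cache(
      /* page_size */ 4096,
      /* get_size */ [](DummyBuffer* b) { return b->size; },
      /* free */ [](DummyBuffer* b) {
        std::free(b->data);
        delete b;
      });

  // Allocate outside the cache, then recycle instead of freeing.
  auto* buf = new DummyBuffer{std::malloc(8192), 8192};
  cache.recycle_to_cache(buf);

  // A request is served from the cache when a buffer exists in
  // [size, min(2 * size, size + 2 * page_size)); 8192 qualifies for 8000.
  DummyBuffer* reused = cache.reuse_from_cache(8000);
  std::cout << (reused == buf) << " " << cache.cache_size() << "\n"; // 1 0
  std::free(reused->data);
  delete reused;

  // Anything still cached is released through the injected callback.
  cache.release_cached_buffers(SIZE_MAX);
  return 0;
}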
@@ -1,9 +1,16 @@
 // Copyright © 2023-2024 Apple Inc.

 #include "mlx/backend/common/utils.h"
+#include "mlx/primitives.h"

 namespace mlx::core {

+std::string get_primitive_string(Primitive* primitive) {
+  std::ostringstream op_t;
+  primitive->print(op_t);
+  return op_t.str();
+}
+
 std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
     const Shape& shape,
     const std::vector<Strides>& strides,
@@ -101,4 +108,105 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
   return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
 }

+Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
+  int pows[3] = {0, 0, 0};
+  int sum = 0;
+  while (true) {
+    int presum = sum;
+    // Check all the pows
+    if (dim0 >= (1 << (pows[0] + 1))) {
+      pows[0]++;
+      sum++;
+    }
+    if (sum == 10) {
+      break;
+    }
+    if (dim1 >= (1 << (pows[1] + 1))) {
+      pows[1]++;
+      sum++;
+    }
+    if (sum == 10) {
+      break;
+    }
+    if (dim2 >= (1 << (pows[2] + 1))) {
+      pows[2]++;
+      sum++;
+    }
+    if (sum == presum || sum == pow2) {
+      break;
+    }
+  }
+  return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
+}
+
+Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
+  // Dims with strides of 0 are ignored as they
+  // correspond to broadcasted dimensions
+  size_t grid_x = 1;
+  size_t grid_y = 1;
+  for (int i = 0; i < shape.size(); ++i) {
+    if (strides[i] == 0) {
+      continue;
+    }
+    if (grid_x * shape[i] < UINT32_MAX) {
+      grid_x *= shape[i];
+    } else {
+      grid_y *= shape[i];
+    }
+  }
+  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
+    throw std::runtime_error("Unable to safely factor shape.");
+  }
+  if (grid_y > grid_x) {
+    std::swap(grid_x, grid_y);
+  }
+  return std::make_tuple(
+      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
+}
+
+Dims get_2d_grid_dims_common(
+    const Shape& shape,
+    const Strides& strides,
+    size_t divisor) {
+  // Compute the 2d grid dimensions such that the total size of the grid is
+  // divided by divisor.
+  size_t grid_x = 1;
+  size_t grid_y = 1;
+  for (int i = 0; i < shape.size(); ++i) {
+    if (strides[i] == 0) {
+      continue;
+    }
+
+    // No need to add this shape; we can just remove it from the divisor.
+    if (divisor % shape[i] == 0) {
+      divisor /= shape[i];
+      continue;
+    }
+
+    if (grid_x * shape[i] < UINT32_MAX) {
+      grid_x *= shape[i];
+    } else {
+      grid_y *= shape[i];
+    }
+
+    if (divisor > 1) {
+      if (grid_x % divisor == 0) {
+        grid_x /= divisor;
+        divisor = 1;
+      } else if (grid_y % divisor == 0) {
+        grid_y /= divisor;
+        divisor = 1;
+      }
+    }
+  }
+  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
+    throw std::runtime_error("Unable to safely factor shape.");
+  }
+  if (grid_y > grid_x) {
+    std::swap(grid_x, grid_y);
+  }
+  return std::make_tuple(
+      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
+}
+
 } // namespace mlx::core
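A worked example of the block-size heuristic (my own trace, with the include path assumed): for dims (30, 5, 1) the loop grows each dimension one power of two at a time and stops when every next power would overshoot, yielding a 16 x 4 x 1 block of 64 threads.

#include <cstdio>

#include "mlx/backend/common/utils.h" // assumed path for Dims and the helper

int main() {
  auto [bx, by, bz] = mlx::core::get_block_dims_common(30, 5, 1);
  // dim0 settles at 16 (32 > 30), dim1 at 4 (8 > 5), dim2 stays 1:
  // 16 * 4 * 1 = 64 threads, well under the default 2^10 cap.
  std::printf("%u %u %u\n", bx, by, bz); // prints: 16 4 1
  return 0;
}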
@@ -2,12 +2,15 @@

 #pragma once

+#include <tuple>
 #include <vector>

 #include "mlx/array.h"

 namespace mlx::core {

+std::string get_primitive_string(Primitive* primitive);
+
 inline int64_t
 elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
   int64_t loc = 0;
@@ -70,6 +73,28 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
     const array& a,
     int64_t size_cap = std::numeric_limits<int32_t>::max());

+// Compute the thread block dimensions which fit the given
+// input dimensions.
+// - The thread block dimensions will be powers of two
+// - The thread block size will be less than 2^pow2
+using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
+Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
+
+// Computes a 2D grid where each element is < UINT_MAX
+// Assumes:
+// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
+// - shape and strides correspond to a contiguous (no holes) but
+//   possibly broadcasted array
+Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
+
+// Same as above but we do an implicit division with divisor.
+// Basically, equivalent to factorizing
+// Prod(s \forall s in shape if strides[s] > 0) / divisor.
+Dims get_2d_grid_dims_common(
+    const Shape& shape,
+    const Strides& strides,
+    size_t divisor);
+
 struct ContiguousIterator {
   inline void step() {
     int dims = shape_.size();
@@ -13,9 +13,18 @@ namespace mlx::core {

 namespace {

+inline constexpr short get_pack_factor(int bits, int wsize = 8) {
+  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
+}
+
+inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {
+  auto power_of_2_bits = (bits & (bits - 1)) == 0;
+  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
+}
+
 template <typename T, int bits>
 void extract_bits(const uint8_t* w_in, T* w_out) {
-  assert(bits == 3 || bits == 6);
+  static_assert(bits == 3 || bits == 5 || bits == 6);
   if (bits == 3) {
     w_out[0] = static_cast<T>(w_in[0] & 0x7);
     w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
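To make the new pack geometry concrete, here is a small standalone check (my own sketch; the two helpers are copied from the hunk above): a pack always holds a whole number of bytes, e.g. eight 5-bit values fill exactly five bytes.

#include <cstdint>

// Copied from the hunk above so the checks are standalone.
inline constexpr short get_pack_factor(int bits, int wsize = 8) {
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}

inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {
  auto power_of_2_bits = (bits & (bits - 1)) == 0;
  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}

// values_per_pack * bits == bytes_per_pack * 8 for every supported width:
// 8 x 3 = 24 bits = 3 bytes, 8 x 5 = 40 bits = 5 bytes, 4 x 6 = 24 bits.
static_assert(get_pack_factor(2) * 2 == get_bytes_per_pack(2) * 8);
static_assert(get_pack_factor(3) * 3 == get_bytes_per_pack(3) * 8);
static_assert(get_pack_factor(4) * 4 == get_bytes_per_pack(4) * 8);
static_assert(get_pack_factor(5) * 5 == get_bytes_per_pack(5) * 8);
static_assert(get_pack_factor(6) * 6 == get_bytes_per_pack(6) * 8);
static_assert(get_pack_factor(8) * 8 == get_bytes_per_pack(8) * 8);

int main() {
  return 0;
}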
@@ -25,6 +34,16 @@ void extract_bits(const uint8_t* w_in, T* w_out) {
     w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
     w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
     w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
+  } else if (bits == 5) {
+    w_out[0] = static_cast<T>(w_in[0] & 0x1f);
+    w_out[1] = static_cast<T>(((w_in[0] & 0xe0) >> 5) + ((w_in[1] & 0x3) << 3));
+    w_out[2] = static_cast<T>((w_in[1] & 0x7c) >> 2);
+    w_out[3] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0xf) << 1));
+    w_out[4] = static_cast<T>(((w_in[2] & 0xf0) >> 4) + ((w_in[3] & 0x1) << 4));
+    w_out[5] = static_cast<T>((w_in[3] & 0x3e) >> 1);
+    w_out[6] = static_cast<T>(((w_in[3] & 0xc0) >> 6) + ((w_in[4] & 0x7) << 2));
+    w_out[7] = static_cast<T>((w_in[4] & 0xf8) >> 3);
+
   } else if (bits == 6) {
     w_out[0] = static_cast<T>(w_in[0] & 0x3f);
     w_out[1] =
@@ -46,8 +65,8 @@ void _qmm(
     int N,
     int K) {
   constexpr int bitmask = (1 << bits) - 1;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
-  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
+  constexpr int pack_factor = get_pack_factor(bits, 8);
+  constexpr int bytes_per_pack = get_bytes_per_pack(bits);
   constexpr int packs_in_group = group_size / pack_factor;

   for (int m = 0; m < M; m++) {
@@ -65,7 +84,7 @@ void _qmm(
       T scale = *scales_local++;
       T bias = *biases_local++;
       for (int ng = 0; ng < packs_in_group; ng++) {
-        if (bits == 3 || bits == 6) {
+        if constexpr (bits == 3 || bits == 5 || bits == 6) {
           T wl[pack_factor];
           extract_bits<T, bits>(w_local, wl);
 #pragma clang loop unroll(full)
@@ -104,8 +123,9 @@ void _qmm_t(
     int N,
     int K) {
   constexpr int bitmask = (1 << bits) - 1;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
-  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
+
+  constexpr int pack_factor = get_pack_factor(bits, 8);
+  constexpr int bytes_per_pack = get_bytes_per_pack(bits);
   constexpr int packs_in_group = group_size / pack_factor;

   for (int m = 0; m < M; m++) {
@@ -121,7 +141,7 @@ void _qmm_t(
       T bias = *biases_local++;

       for (int kw = 0; kw < packs_in_group; kw++) {
-        if (bits == 3 || bits == 6) {
+        if constexpr (bits == 3 || bits == 5 || bits == 6) {
           T wl[pack_factor];
           extract_bits<T, bits>(w_local, wl);
 #pragma clang loop unroll(full)
@@ -304,6 +324,10 @@ void _qmm_dispatch_typed(
       _qmm_dispatch_group<T, 4>(
           result, x, w, scales, biases, M, N, K, group_size, transposed_w);
       break;
+    case 5:
+      _qmm_dispatch_group<T, 5>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
     case 6:
       _qmm_dispatch_group<T, 6>(
           result, x, w, scales, biases, M, N, K, group_size, transposed_w);
@@ -613,9 +637,8 @@ void quantize(
   float eps = 1e-7;

   bool power_of_2_bits = is_power_of_2(bits);
-  int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
-  // For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
-  int bytes_per_pack = power_of_2_bits ? 1 : 3;
+  int el_per_int = get_pack_factor(bits, 32);
+  int bytes_per_pack = get_bytes_per_pack(bits);
   int int_per_group = group_size * bytes_per_pack / el_per_int;
   size_t n_groups = w_size / group_size;
@@ -640,15 +663,21 @@ void quantize(
     }
     size_t out_idx = i * int_per_group;
     for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
-      uint32_t out_el = 0;
+      uint64_t out_el = 0;
       for (int k = 0; k < el_per_int; ++k) {
         float w_el = w[w_idx + j * el_per_int + k];
         w_el = std::rint((w_el - bias) / scale);
         w_el = std::min(std::max(w_el, 0.0f), n_bins);
-        out_el |= static_cast<uint32_t>(w_el) << (k * bits);
+        out_el |= static_cast<uint64_t>(w_el) << (k * bits);
       }
       if (power_of_2_bits) {
         out[out_idx + j] = out_el;
+      } else if (bits == 5) {
+        out[out_idx + bytes_per_pack * j] = out_el & 0xff;
+        out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
+        out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
+        out[out_idx + bytes_per_pack * j + 3] = (out_el & 0xff000000) >> 24;
+        out[out_idx + bytes_per_pack * j + 4] = (out_el & 0xff00000000) >> 32;
       } else {
         out[out_idx + bytes_per_pack * j] = out_el & 0xff;
         out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
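The switch from uint32_t to uint64_t above is forced by the 5-bit path: with el_per_int = 8, eight 5-bit codes occupy 40 bits before being spilled to five bytes. A standalone sketch of that accumulation (the values are made up):

#include <cstdint>
#include <cstdio>

int main() {
  uint8_t vals[8] = {31, 0, 17, 5, 9, 30, 2, 21}; // 5-bit codes in [0, 31]
  uint64_t out_el = 0;
  for (int k = 0; k < 8; ++k) {
    out_el |= static_cast<uint64_t>(vals[k]) << (k * 5); // bits 0..39
  }
  // Spill the 40-bit pack into 5 bytes, exactly as the quantize() loop does.
  for (int b = 0; b < 5; ++b) {
    std::printf("%02x ", static_cast<unsigned>((out_el >> (8 * b)) & 0xff));
  }
  std::printf("\n");
  return 0;
}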
@@ -11,13 +11,12 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

 target_compile_definitions(mlx PUBLIC MLX_USE_CUDA)

 # Enable defining device lambda functions.
 target_compile_options(mlx
                        PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
@@ -25,7 +24,7 @@ target_compile_options(mlx
 # Compute capability 7 is required for synchronization between CPU/GPU with
 # managed memory. TODO: Add more architectures for potential performance gain.
 set(MLX_CUDA_ARCHITECTURES
-    "75;80"
+    "70;80"
     CACHE STRING "CUDA architectures")
 message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
 set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
@@ -6,6 +6,7 @@
 #include <cuda_runtime.h>
 #include <fmt/format.h>
+#include <unistd.h>

 #include <cassert>
@@ -13,24 +14,47 @@ namespace mlx::core {

 namespace cu {

-CudaAllocator::CudaAllocator() {
+CudaAllocator::CudaAllocator()
+    : buffer_cache_(
+          getpagesize(),
+          [](CudaBuffer* buf) { return buf->size; },
+          [this](CudaBuffer* buf) { cuda_free(buf); }) {
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
   memory_limit_ = total * 0.8;
+  max_pool_size_ = memory_limit_;
 }

 Buffer CudaAllocator::malloc(size_t size) {
-  // TODO: Check memory limit.
-  auto* buf = new CudaBuffer{nullptr, size};
-  cudaError_t err = cudaMallocManaged(&buf->data, size);
-  if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
-    throw std::runtime_error(
-        fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
+  // Find available buffer from cache.
+  std::unique_lock lock(mutex_);
+  CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
+  if (!buf) {
+    // If we have a lot of memory pressure or are over the maximum cache size,
+    // try to reclaim memory from the cache.
+    size_t mem_required = get_active_memory() + get_cache_memory() + size;
+    if (mem_required >= memory_limit_) {
+      buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
+    }
+
+    lock.unlock();
+    buf = new CudaBuffer{nullptr, size};
+    cudaError_t err = cudaMallocManaged(&buf->data, size);
+    if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
+      throw std::runtime_error(fmt::format(
+          "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
+    }
+    lock.lock();
   }
-  std::lock_guard lock(mutex_);
   active_memory_ += size;
   peak_memory_ = std::max(active_memory_, peak_memory_);
+
+  // Maintain the cache below the requested limit.
+  if (get_cache_memory() > max_pool_size_) {
+    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
+  }
+
   return Buffer{buf};
 }
@@ -40,26 +64,14 @@ void CudaAllocator::free(Buffer buffer) {
     return;
   }

-  // If free() is called from an unregistered thread, reschedule the call to
-  // worker.
-  {
-    std::lock_guard lock(worker_mutex_);
-    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
-      if (!worker_) {
-        worker_.reset(new Worker);
-      }
-      worker_->add_task([buffer]() { allocator().free(buffer); });
-      worker_->end_batch();
-      worker_->commit();
-      return;
-    }
+  std::unique_lock lock(mutex_);
+  active_memory_ -= buf->size;
+  if (get_cache_memory() < max_pool_size_) {
+    buffer_cache_.recycle_to_cache(buf);
+  } else {
+    lock.unlock();
+    cuda_free(buf);
   }
-
-  size_t size = buf->size;
-  cudaFree(buf->data);
-  delete buf;
-  std::lock_guard lock(mutex_);
-  active_memory_ -= size;
 }

 size_t CudaAllocator::size(Buffer buffer) const {
@@ -98,6 +110,41 @@ size_t CudaAllocator::set_memory_limit(size_t limit) {
   return limit;
 }

+size_t CudaAllocator::get_cache_memory() const {
+  return buffer_cache_.cache_size();
+}
+
+size_t CudaAllocator::set_cache_limit(size_t limit) {
+  std::lock_guard lk(mutex_);
+  std::swap(limit, max_pool_size_);
+  return limit;
+}
+
+void CudaAllocator::clear_cache() {
+  std::lock_guard lk(mutex_);
+  buffer_cache_.clear();
+}
+
+void CudaAllocator::cuda_free(CudaBuffer* buf) {
+  // If cuda_free() is called from an unregistered thread, reschedule the call
+  // to worker.
+  {
+    std::lock_guard lock(worker_mutex_);
+    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
+      if (!worker_) {
+        worker_.reset(new Worker);
+      }
+      worker_->add_task([this, buf]() { this->cuda_free(buf); });
+      worker_->end_batch();
+      worker_->commit();
+      return;
+    }
+  }
+
+  cudaFree(buf->data);
+  delete buf;
+}
+
 CudaAllocator& allocator() {
   // By creating the |allocator_| on heap, the destructor of CudaAllocator
   // will not be called on exit and buffers in the cache will be leaked. This
@@ -138,17 +185,19 @@ size_t set_memory_limit(size_t limit) {
 size_t get_memory_limit() {
   return cu::allocator().get_memory_limit();
 }

-// TODO: Implement buffer cache.
 size_t get_cache_memory() {
-  return 0;
+  return cu::allocator().get_cache_memory();
 }
-size_t set_cache_limit(size_t) {
-  return 0;
+size_t set_cache_limit(size_t limit) {
+  return cu::allocator().set_cache_limit(limit);
 }
+void clear_cache() {
+  cu::allocator().clear_cache();
+}

 // Not supported in CUDA.
 size_t set_wired_limit(size_t) {
   return 0;
 }
-void clear_cache() {}

 } // namespace mlx::core
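A hedged sketch of how the new knobs compose from user code; the mlx/memory.h include path is my assumption, the diff only shows the definitions in mlx::core:

#include <cstdio>

#include "mlx/memory.h" // assumed header for the declarations below

int main() {
  namespace mc = mlx::core;
  // Cap the CUDA buffer cache at 1 GiB; the setter returns the old limit.
  size_t prev = mc::set_cache_limit(1ull << 30);
  std::printf("previous cache limit: %zu\n", prev);

  // ... run some computation; freed buffers are now recycled ...

  // Inspect and then drop whatever the allocator is currently holding.
  std::printf("cached bytes: %zu\n", mc::get_cache_memory());
  mc::clear_cache();
  return 0;
}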
@@ -3,6 +3,7 @@
 #pragma once

 #include "mlx/allocator.h"
+#include "mlx/backend/common/buffer_cache.h"

 #include <mutex>
 #include <set>
@@ -38,17 +39,24 @@ class CudaAllocator : public allocator::Allocator {
   void reset_peak_memory();
   size_t get_memory_limit();
   size_t set_memory_limit(size_t limit);
+  size_t get_cache_memory() const;
+  size_t set_cache_limit(size_t limit);
+  void clear_cache();

  private:
   CudaAllocator();
   friend CudaAllocator& allocator();

+  void cuda_free(CudaBuffer* buf);
+
   std::mutex worker_mutex_;
   std::unique_ptr<Worker> worker_;
   std::set<std::thread::id> allowed_threads_;

   std::mutex mutex_;
   size_t memory_limit_;
+  size_t max_pool_size_;
+  BufferCache<CudaBuffer> buffer_cache_;
   size_t active_memory_{0};
   size_t peak_memory_{0};
 };
mlx/backend/cuda/kernel_utils.cu (new file, 26 lines)
@@ -0,0 +1,26 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/kernel_utils.cuh"

namespace mlx::core {

dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2) {
  Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2);
  return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}

dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides) {
  Dims dims = get_2d_grid_dims_common(shape, strides);
  return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}

dim3 get_2d_grid_dims(
    const Shape& shape,
    const Strides& strides,
    size_t divisor) {
  Dims dims = get_2d_grid_dims_common(shape, strides, divisor);
  return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}

} // namespace mlx::core
@@ -1,7 +1,13 @@
 // Copyright © 2025 Apple Inc.

+// This file includes host-only utilities for writing CUDA kernels; the
+// difference from backend/cuda/kernels/utils.cuh is that the latter file only
+// includes device-only code.
+
 #pragma once

+#include "mlx/array.h"
+
 #include <cuComplex.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
@@ -32,4 +38,12 @@ struct CTypeToCudaType<complex64_t> {
 template <typename T>
 using cuda_type_t = typename CTypeToCudaType<T>::type;

+// Compute the grid and block dimensions, check backend/common/utils.h for docs.
+dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
+dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
+dim3 get_2d_grid_dims(
+    const Shape& shape,
+    const Strides& strides,
+    size_t divisor);
+
 } // namespace mlx::core
@@ -1,7 +1,7 @@
 // Copyright © 2025 Apple Inc.

 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/dtype_utils.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/backend/cuda/kernels/arange.cuh"
 #include "mlx/backend/cuda/kernels/fp16_math.cuh"
 #include "mlx/distributed/primitives.h"
@@ -43,12 +43,29 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   });
 }

+bool fast::ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  return true;
+}
+
 #define NO_GPU_MULTI(func)                                              \
   void func::eval_gpu(                                                  \
       const std::vector<array>& inputs, std::vector<array>& outputs) { \
     throw std::runtime_error(#func " has no CUDA implementation.");    \
   }

+#define NO_GPU_USE_FALLBACK(func)     \
+  bool func::use_fallback(Stream s) { \
+    return true;                      \
+  }                                   \
+  NO_GPU_MULTI(func)
+
 #define NO_GPU(func)                                                  \
   void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
     throw std::runtime_error(#func " has no CUDA implementation.");   \
@@ -144,11 +161,11 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)

 namespace fast {
-NO_GPU_MULTI(LayerNorm)
+NO_GPU_USE_FALLBACK(LayerNorm)
 NO_GPU_MULTI(LayerNormVJP)
-NO_GPU_MULTI(RMSNorm)
+NO_GPU_USE_FALLBACK(RMSNorm)
 NO_GPU_MULTI(RMSNormVJP)
-NO_GPU_MULTI(RoPE)
+NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
@@ -1,5 +1,7 @@
 // Copyright © 2025 Apple Inc.

+// This file includes utilities that are used by C++ code (i.e. .cpp files).
+
 #pragma once

 #include <cuda_runtime.h>
@@ -30,141 +30,18 @@ void* Buffer::raw_ptr() {

 namespace metal {

-namespace {
-
-BufferCache::BufferCache(ResidencySet& residency_set)
-    : head_(nullptr),
-      tail_(nullptr),
-      pool_size_(0),
-      residency_set_(residency_set) {}
-
-BufferCache::~BufferCache() {
-  auto pool = metal::new_scoped_memory_pool();
-  clear();
-}
-
-int BufferCache::clear() {
-  int n_release = 0;
-  for (auto& [size, holder] : buffer_pool_) {
-    if (holder->buf) {
-      if (!holder->buf->heap()) {
-        residency_set_.erase(holder->buf);
-      }
-      holder->buf->release();
-      n_release++;
-    }
-    delete holder;
-  }
-  buffer_pool_.clear();
-  pool_size_ = 0;
-  head_ = nullptr;
-  tail_ = nullptr;
-  return n_release;
-}
-
-MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
-  // Find the closest buffer in pool
-  MTL::Buffer* pbuf = nullptr;
-
-  auto it = buffer_pool_.lower_bound(size);
-
-  // Make sure we use most of the available memory
-  while (!pbuf && it != buffer_pool_.end() &&
-         it->first < std::min(2 * size, size + 2 * vm_page_size)) {
-    // Collect from the cache
-    pbuf = it->second->buf;
-
-    // Remove from cache
-    remove_from_list(it->second);
-    delete it->second;
-    it = buffer_pool_.erase(it);
-  }
-
-  if (pbuf) {
-    pool_size_ -= pbuf->length();
-  }
-
-  return pbuf;
-}
-
-void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
-  // Add to cache
-  if (buf) {
-    BufferHolder* bh = new BufferHolder(buf);
-    add_at_head(bh);
-    pool_size_ += buf->length();
-    buffer_pool_.insert({buf->length(), bh});
-  }
-}
-
-int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
-  if (min_bytes_to_free >= 0.9 * pool_size_) {
-    return clear();
-  } else {
-    int n_release = 0;
-    size_t total_bytes_freed = 0;
-
-    while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
-      if (tail_->buf) {
-        total_bytes_freed += tail_->buf->length();
-        if (!tail_->buf->heap()) {
-          residency_set_.erase(tail_->buf);
-        }
-        tail_->buf->release();
-        tail_->buf = nullptr;
-        n_release++;
-      }
-      remove_from_list(tail_);
-    }
-    pool_size_ -= total_bytes_freed;
-    return n_release;
-  }
-}
-
-void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
-  if (!to_add)
-    return;
-
-  if (!head_) {
-    head_ = to_add;
-    tail_ = to_add;
-  } else {
-    head_->prev = to_add;
-    to_add->next = head_;
-    head_ = to_add;
-  }
-}
-
-void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
-  if (!to_remove) {
-    return;
-  }
-
-  // If in the middle
-  if (to_remove->prev && to_remove->next) {
-    to_remove->prev->next = to_remove->next;
-    to_remove->next->prev = to_remove->prev;
-  } else if (to_remove->prev && to_remove == tail_) { // If tail
-    tail_ = to_remove->prev;
-    tail_->next = nullptr;
-  } else if (to_remove == head_ && to_remove->next) { // If head
-    head_ = to_remove->next;
-    head_->prev = nullptr;
-  } else if (to_remove == head_ && to_remove == tail_) { // If only element
-    head_ = nullptr;
-    tail_ = nullptr;
-  }
-
-  to_remove->prev = nullptr;
-  to_remove->next = nullptr;
-}
-
-} // namespace
-
 MetalAllocator::MetalAllocator()
     : device_(device(mlx::core::Device::gpu).mtl_device()),
       residency_set_(device_),
-      buffer_cache_(residency_set_) {
+      buffer_cache_(
+          vm_page_size,
+          [](MTL::Buffer* buf) { return buf->length(); },
+          [this](MTL::Buffer* buf) {
+            if (!buf->heap()) {
+              residency_set_.erase(buf);
+            }
+            buf->release();
+          }) {
   auto pool = metal::new_scoped_memory_pool();
   auto memsize = std::get<size_t>(device_info().at("memory_size"));
   auto max_rec_size =
@@ -193,6 +70,7 @@ MetalAllocator::~MetalAllocator() {
   if (heap_) {
     heap_->release();
   }
+  buffer_cache_.clear();
 }

 size_t MetalAllocator::set_cache_limit(size_t limit) {
@@ -7,6 +7,7 @@
 #include <vector>

 #include "mlx/allocator.h"
+#include "mlx/backend/common/buffer_cache.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/resident.h"

@@ -14,43 +15,6 @@ namespace mlx::core::metal {

 using allocator::Buffer;

-namespace {
-
-class BufferCache {
- public:
-  BufferCache(ResidencySet& residency_set);
-  ~BufferCache();
-
-  MTL::Buffer* reuse_from_cache(size_t size);
-  void recycle_to_cache(MTL::Buffer* buf);
-  int release_cached_buffers(size_t min_bytes_to_free);
-  size_t cache_size() {
-    return pool_size_;
-  }
-  int clear();
-
- private:
-  struct BufferHolder {
-   public:
-    BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
-
-    BufferHolder* prev;
-    BufferHolder* next;
-    MTL::Buffer* buf;
-  };
-
-  void add_at_head(BufferHolder* to_add);
-  void remove_from_list(BufferHolder* to_remove);
-
-  std::multimap<size_t, BufferHolder*> buffer_pool_;
-  BufferHolder* head_;
-  BufferHolder* tail_;
-  size_t pool_size_;
-  ResidencySet& residency_set_;
-};
-
-} // namespace
-
 class MetalAllocator : public allocator::Allocator {
   /** Allocator for Metal GPUs. */
  public:
@@ -90,7 +54,7 @@ class MetalAllocator : public allocator::Allocator {
   friend MetalAllocator& allocator();

   // Caching allocator
-  BufferCache buffer_cache_;
+  BufferCache<MTL::Buffer> buffer_cache_;

   ResidencySet residency_set_;
@@ -103,8 +103,8 @@ template <typename T, typename AccT = float, int N_READS = 4>
     }
   } else {
     for (int i = 0; i < N_READS; i++) {
-      vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
-                                         : Limits<AccT>::finite_min;
+      vals[i] =
+          (offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
     }
   }
   prevmax = maxval;
@@ -134,10 +134,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
   threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  if (simd_group_id == 0) {
-    normalizer = simd_sum(local_normalizer[simd_lane_id]);
-    if (simd_lane_id == 0) {
-      out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
-    }
+  normalizer = simd_sum(local_normalizer[simd_lane_id]);
+  if (lid == 0) {
+    out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
   }
 }
@@ -14,11 +14,23 @@ using namespace metal;
 MLX_MTL_CONST int SIMD_SIZE = 32;
 MLX_MTL_CONST int QUAD_SIZE = 4;

+template <int bits, int wsize = 8>
+inline constexpr short get_pack_factor() {
+  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
+}
+
+template <int bits, int wsize = 8>
+inline constexpr short get_bytes_per_pack() {
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
+}
+
 template <typename T, typename U, int values_per_thread, int bits>
 inline U load_vector(const device T* x, thread U* x_thread) {
   static_assert(
-      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
-      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

   U sum = 0;
@@ -57,6 +69,21 @@ inline U load_vector(const device T* x, thread U* x_thread) {
     }
   }

+  else if (bits == 5) {
+    for (int i = 0; i < values_per_thread; i += 8) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
+          x[i + 6] + x[i + 7];
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 32.0f;
+      x_thread[i + 2] = x[i + 2] / 4.0f;
+      x_thread[i + 3] = x[i + 3] / 128.0f;
+      x_thread[i + 4] = x[i + 4] / 16.0f;
+      x_thread[i + 5] = x[i + 5] / 2.0f;
+      x_thread[i + 6] = x[i + 6] / 64.0f;
+      x_thread[i + 7] = x[i + 7] / 8.0f;
+    }
+  }
+
   else if (bits == 6) {
     for (int i = 0; i < values_per_thread; i += 4) {
       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
@@ -80,8 +107,9 @@ inline U load_vector(const device T* x, thread U* x_thread) {
 template <typename T, typename U, int values_per_thread, int bits>
 inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
   static_assert(
-      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
-      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

   U sum = 0;
@@ -121,6 +149,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
     }
   }

+  else if (bits == 5) {
+    for (int i = 0; i < N; i += 8) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
+          x[i + 6] + x[i + 7];
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 32.0f;
+      x_thread[i + 2] = x[i + 2] / 4.0f;
+      x_thread[i + 3] = x[i + 3] / 128.0f;
+      x_thread[i + 4] = x[i + 4] / 16.0f;
+      x_thread[i + 5] = x[i + 5] / 2.0f;
+      x_thread[i + 6] = x[i + 6] / 64.0f;
+      x_thread[i + 7] = x[i + 7] / 8.0f;
+    }
+  }
+
   else if (bits == 6) {
     for (int i = 0; i < N; i += 4) {
       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
@@ -153,8 +196,9 @@ inline U qdot(
     U bias,
     U sum) {
   static_assert(
-      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
-      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

   U accum = 0;
@@ -199,6 +243,26 @@ inline U qdot(
     }
   }

+  else if (bits == 5) {
+    for (int i = 0; i < (values_per_thread / 8); i++) {
+      x_thread += 8 * i;
+      w += 5 * i;
+
+      accum += (w[0] & 0x1f) * x_thread[0];
+      accum += (w[0] & 0xe0) * x_thread[1];
+      accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
+      accum += (w[1] & 0x7c) * x_thread[2];
+      accum += (w[1] & 0x80) * x_thread[3];
+      accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
+      accum += (w[2] & 0xf0) * x_thread[4];
+      accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
+      accum += (w[3] & 0x3e) * x_thread[5];
+      accum += (w[3] & 0xc0) * x_thread[6];
+      accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
+      accum += (w[4] & 0xf8) * x_thread[7];
+    }
+  }
+
   else if (bits == 6) {
     for (int i = 0; i < (values_per_thread / 4); i++) {
       x_thread += 4 * i;
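The 5-bit qdot above multiplies masked-but-unshifted weight bytes against activations that load_vector pre-divided by each value's bit offset, so the inner loop needs no shifts. A standalone check (my own, with made-up values) that this equals the plain decode-then-multiply dot product:

#include <cassert>
#include <cstdint>

int main() {
  uint8_t v[8] = {7, 19, 31, 0, 12, 25, 3, 30}; // 5-bit codes
  // Pack little-endian into 5 bytes, as the quantizer does.
  uint64_t packed = 0;
  for (int i = 0; i < 8; ++i)
    packed |= static_cast<uint64_t>(v[i]) << (5 * i);
  uint8_t w[5];
  for (int b = 0; b < 5; ++b)
    w[b] = (packed >> (8 * b)) & 0xff;

  float x[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  // Pre-scaled activations, mirroring the 5-bit branch of load_vector.
  float xt[8] = {x[0], x[1] / 32.f, x[2] / 4.f, x[3] / 128.f,
                 x[4] / 16.f, x[5] / 2.f, x[6] / 64.f, x[7] / 8.f};

  float accum = 0;
  accum += (w[0] & 0x1f) * xt[0];
  accum += (w[0] & 0xe0) * xt[1];
  accum += (w[1] & 0x3) * (xt[1] * 256.f);
  accum += (w[1] & 0x7c) * xt[2];
  accum += (w[1] & 0x80) * xt[3];
  accum += (w[2] & 0xf) * (xt[3] * 256.f);
  accum += (w[2] & 0xf0) * xt[4];
  accum += (w[3] & 0x1) * (xt[4] * 256.f);
  accum += (w[3] & 0x3e) * xt[5];
  accum += (w[3] & 0xc0) * xt[6];
  accum += (w[4] & 0x7) * (xt[6] * 256.f);
  accum += (w[4] & 0xf8) * xt[7];

  float expected = 0;
  for (int i = 0; i < 8; ++i)
    expected += v[i] * x[i];
  assert(accum == expected); // exact: every factor is a power of two
  return 0;
}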
@@ -234,8 +298,9 @@ inline U qdot_safe(
     U sum,
     int N) {
   static_assert(
-      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
-      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

   U accum = 0;
@@ -280,6 +345,26 @@ inline U qdot_safe(
     }
   }

+  else if (bits == 5) {
+    for (int i = 0; i < (N / 8); i++) {
+      x_thread += 8 * i;
+      w += 5 * i;
+
+      accum += (w[0] & 0x1f) * x_thread[0];
+      accum += (w[0] & 0xe0) * x_thread[1];
+      accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
+      accum += (w[1] & 0x7c) * x_thread[2];
+      accum += (w[1] & 0x80) * x_thread[3];
+      accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
+      accum += (w[2] & 0xf0) * x_thread[4];
+      accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
+      accum += (w[3] & 0x3e) * x_thread[5];
+      accum += (w[3] & 0xc0) * x_thread[6];
+      accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
+      accum += (w[4] & 0xf8) * x_thread[7];
+    }
+  }
+
   else if (bits == 6) {
     for (int i = 0; i < (N / 4); i++) {
       x_thread += 4 * i;
@@ -310,8 +395,9 @@ template <typename U, int values_per_thread, int bits>
 inline void
 qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
   static_assert(
-      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
-      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

   if (bits == 2) {
     U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
@@ -348,8 +434,31 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
       result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
       result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
     }
-  } else if (bits == 6) {
+  }
+
+  else if (bits == 5) {
+    for (int i = 0; i < (values_per_thread / 8); i++) {
+      uint8_t w0 = w[5 * i];
+      uint8_t w1 = w[5 * i + 1];
+      uint8_t w2 = w[5 * i + 2];
+      uint8_t w3 = w[5 * i + 3];
+      uint8_t w4 = w[5 * i + 4];
+      result[8 * i] += x * ((w0 & 0x1f) * scale + bias);
+      result[8 * i + 1] +=
+          x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias);
+      result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias);
+      result[8 * i + 3] +=
+          x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias);
+      result[8 * i + 4] +=
+          x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias);
+      result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias);
+      result[8 * i + 6] +=
+          x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias);
+      result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias);
+    }
+  }
+
+  else if (bits == 6) {
     for (int i = 0; i < (values_per_thread / 4); i++) {
       uint8_t w0 = w[3 * i];
       uint8_t w1 = w[3 * i + 1];
@@ -375,8 +484,9 @@ template <typename U, int N, int bits>
 inline void
 dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
   static_assert(
-      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
-      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

   if (bits == 2) {
     U s[4] = {
@@ -416,11 +526,26 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
     }
   }

+  else if (bits == 5) {
+    for (int i = 0; i < (N / 8); i++) {
+      w_local += 8 * i;
+      w += 5 * i;
+
+      w_local[0] = (w[0] & 0x1f) * scale + bias;
+      w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
+      w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
+      w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
+      w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
+      w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
+      w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
+      w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
+    }
+  }
+
   else if (bits == 6) {
     for (int i = 0; i < (N / 4); i++) {
       w_local += 4 * i;
       w += 3 * i;

       w_local[0] = (w[0] & 0x3f) * scale + bias;
       w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
       w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
@@ -452,11 +577,12 @@ struct QuantizedBlockLoader {
       group_size % BCOLS == 0,
       "The group size should be divisible by the columns");
   static_assert(
-      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
-      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

-  MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
-  MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
+  MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();
+  MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();
   MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
   MLX_MTL_CONST short n_reads =
       (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
@@ -632,12 +758,11 @@ METAL_FUNC void qmv_fast_impl(
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
   constexpr int packs_per_thread = bits == 2 ? 1 : 2;
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
-  constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
+  constexpr int pack_factor = get_pack_factor<bits, 32>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();
   constexpr int values_per_thread = pack_factor * packs_per_thread;
   constexpr int block_size = values_per_thread * SIMD_SIZE;
   constexpr int scale_step_per_thread = group_size / values_per_thread;
@@ -700,12 +825,12 @@ METAL_FUNC void qmv_impl(
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
   constexpr int packs_per_thread = 1;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
-  constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
+  constexpr int pack_factor = get_pack_factor<bits, 32>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();

   constexpr int values_per_thread = pack_factor * packs_per_thread;
   constexpr int block_size = values_per_thread * SIMD_SIZE;
   constexpr int scale_step_per_thread = group_size / values_per_thread;
@@ -857,8 +982,9 @@ METAL_FUNC void qvm_impl(
     uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
   constexpr int num_simdgroups = 2;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
-  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
+  constexpr int pack_factor = get_pack_factor<bits, 32>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();

   constexpr int tn = 32 / pack_factor;
   constexpr int block_size = SIMD_SIZE;
@@ -981,9 +1107,10 @@ METAL_FUNC void qmm_t_impl(

   constexpr int WM = 2;
   constexpr int WN = 2;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int pack_factor = get_pack_factor<bits, 8>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();

   constexpr int BK_padded = (BK + 16 / sizeof(T));
-  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;

   // Instantiate the appropriate BlockMMA and Loader
   using mma_t = mlx::steel::
@@ -1106,11 +1233,11 @@ METAL_FUNC void qmm_n_impl(

   constexpr int WM = 2;
   constexpr int WN = 2;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int pack_factor = get_pack_factor<bits, 8>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();

   constexpr int BK_padded = (BK + 16 / sizeof(T));
   constexpr int BN_padded = (BN + 16 / sizeof(T));
-  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
-  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;

   // Instantiate the appropriate BlockMMA and Loader
   using mma_t = mlx::steel::
@@ -2120,11 +2247,10 @@ template <
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]]) {
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int pack_factor = get_pack_factor<bits, 8>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
   constexpr int BK_padded = (BK + 16 / sizeof(T));
   constexpr int BN_padded = (BN + 16 / sizeof(T));
-  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
-  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;

   using mma_t = mlx::steel::BlockMMA<
       T,
@@ -2305,13 +2431,13 @@ template <typename T, const int group_size, const int bits>
   constexpr float eps = 1e-7;
   constexpr int simd_size = 32;
   constexpr float n_bins = (1 << bits) - 1;
-  constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int pack_factor = get_pack_factor<bits, 8>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
   constexpr int values_per_reduce = group_size / simd_size;
-  constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
+  constexpr int writes_per_reduce = pack_factor / values_per_reduce;
   constexpr int writes_per_pack =
-      writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
-  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
-  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
+      writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;

   static_assert(
       group_size % simd_size == 0,
@@ -2354,8 +2480,8 @@ template <typename T, const int group_size, const int bits>
     biases[gindex] = static_cast<T>(bias);
   }

-  // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
-  uint32_t output = 0;
+  using OutType = metal::conditional_t<bits == 5, uint64_t, uint32_t>;
+  OutType output = 0;

 #pragma clang loop unroll(full)
   for (int i = 0; i < values_per_reduce; i++) {
@@ -2363,27 +2489,35 @@ template <typename T, const int group_size, const int bits>
     if (bits == 8) {
       output = val;
     } else {
-      output += val << (bits * (i % packs_per_int));
+      output |= val << (bits * (i % pack_factor));
     }

-    if (packs_per_int < values_per_reduce &&
-        i % packs_per_int == packs_per_int - 1) {
-      out[out_index + i / packs_per_int] = output;
+    if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) {
+      out[out_index + i / pack_factor] = output;
       output = 0;
     } else {
 #pragma clang loop unroll(full)
       for (int j = 1; j < writes_per_reduce; j++) {
         uint8_t sval = simd_shuffle_down(val, j);
-        output += sval << (bits * (j * values_per_reduce + i));
+        output |= static_cast<OutType>(sval)
+            << (bits * (j * values_per_reduce + i));
       }
     }
   }
   if (bits == 3 || bits == 6) {
-    if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) {
+    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
       out[out_index] = output & 0xff;
       out[out_index + 1] = (output & 0xff00) >> 8;
       out[out_index + 2] = (output & 0xff0000) >> 16;
     }
+  } else if (bits == 5) {
+    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
+      out[out_index] = output & 0xff;
+      out[out_index + 1] = (output & 0xff00) >> 8;
+      out[out_index + 2] = (output & 0xff0000) >> 16;
+      out[out_index + 3] = (output & 0xff000000) >> 24;
+      out[out_index + 4] = (output & 0xff00000000) >> 32;
+    }
   } else {
     if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
       out[out_index / writes_per_reduce] = output;
@@ -2399,12 +2533,11 @@ template <typename T, const int group_size, const int bits>
     device T* out [[buffer(3)]],
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
-  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
-  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
+  constexpr int pack_factor = get_pack_factor<bits, 8>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();

   size_t offset = index.x + grid_dim.x * size_t(index.y);
-  size_t oindex = offset * packs_per_int;
+  size_t oindex = offset * pack_factor;
   size_t gindex = oindex / group_size;
   T scale = scales[gindex];
   T bias = biases[gindex];
@@ -2421,7 +2554,16 @@ template <typename T, const int group_size, const int bits>
     out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
     out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
     out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
+
+  } else if (bits == 5) {
+    w += offset * bytes_per_pack;
+    out[0] = (w[0] & 0x1f) * scale + bias;
+    out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
+    out[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
+    out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
+    out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
+    out[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
+    out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
+    out[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
   } else if (bits == 6) {
     w += offset * bytes_per_pack;
     out[0] = (w[0] & 0x3f) * scale + bias;
@@ -2431,7 +2573,7 @@ template <typename T, const int group_size, const int bits>
   } else {
     uint val = w[offset];
 #pragma clang loop unroll(full)
-    for (int i = 0; i < packs_per_int; i++) {
+    for (int i = 0; i < pack_factor; i++) {
       uint8_t d;
       if (bits == 2) {
         d = (val >> (bits * i)) & 0x03;
@@ -136,6 +136,7 @@
   instantiate_quantized_groups(2) \
   instantiate_quantized_groups(3) \
   instantiate_quantized_groups(4) \
+  instantiate_quantized_groups(5) \
   instantiate_quantized_groups(6) \
   instantiate_quantized_groups(8)
@@ -128,8 +128,8 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
     }
   } else {
     for (int i = 0; i < N_READS; i++) {
-      vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
-                                         : Limits<AccT>::finite_min;
+      vals[i] =
+          (offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
     }
   }
   prevmax = maxval;
@@ -10,6 +10,10 @@

 namespace mlx::core::fast {

+bool RMSNorm::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
 void RMSNorm::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
@@ -207,6 +211,10 @@ void RMSNormVJP::eval_gpu(
   }
 }

+bool LayerNorm::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
 void LayerNorm::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
@@ -976,7 +976,9 @@ void fast::AffineQuantize::eval_gpu(
  // Treat uint32 as uint8 in kernel
  constexpr int uint8_per_uint32 = 4;
  constexpr int simd_size = 32;
  int packs_per_int = bits_ == 3 ? 8 : bits_ == 6 ? 4 : 8 / bits_;
  int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
      : bits_ == 6                               ? 4
                                                 : 8 / bits_;
  int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size;
  size_t nthreads =
      dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;

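The widened ternary encodes the values-per-pack for each width: 3- and 5-bit values pack 8 to a group (3 and 5 bytes respectively), 6-bit packs 4 values into 3 bytes, and power-of-two widths pack 8 / bits values per byte. A constexpr restatement with compile-time checks (illustrative, not the kernel's own helper):

constexpr int packs_per_int(int bits) {
  return (bits == 3 || bits == 5) ? 8 : bits == 6 ? 4 : 8 / bits;
}
static_assert(packs_per_int(2) == 4 && packs_per_int(3) == 8, "");
static_assert(packs_per_int(4) == 2 && packs_per_int(5) == 8, "");
static_assert(packs_per_int(6) == 4 && packs_per_int(8) == 1, "");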
@@ -7,6 +7,10 @@ namespace mlx::core::fast {

constexpr int n_per_thread = 4;

bool RoPE::use_fallback(Stream s) {
  return s.device == Device::cpu;
}

void RoPE::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {

@@ -4,10 +4,10 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"

#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"

namespace mlx::core::fast {
@@ -339,6 +339,46 @@ void sdpa_vector_2pass(

} // namespace

bool ScaledDotProductAttention::use_fallback(
    const array& q,
    const array& k,
    const array& v,
    bool has_mask,
    bool has_arr_mask,
    bool do_causal,
    Stream s) {
  if (detail::in_grad_tracing()) {
    return true;
  }
  if (s.device == Device::cpu) {
    return true;
  }

  const int value_head_dim = v.shape(-1);
  const int query_head_dim = q.shape(-1);
  const int query_sequence_length = q.shape(2);
  const int key_sequence_length = k.shape(2);

  const bool sdpa_vector_supported_head_dim =
      query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
       query_head_dim == 256);
  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);

  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
      (query_sequence_length <= key_sequence_length && do_causal);

  const bool supports_sdpa_full =
      sdpa_full_supported_mask && sdpa_full_supported_head_dim;

  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
      (query_sequence_length <= key_sequence_length) &&
      sdpa_vector_supported_head_dim;

  return !(supports_sdpa_full || supports_sdpa_vector);
}

void ScaledDotProductAttention::eval_gpu(
    const std::vector<array>& inputs,
    array& out) {

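As a concrete reading of the predicate: a decode-style call with query length 1 and a supported head dim takes the vector path, while an odd head dim forces the fallback. A reduced sketch of just the vector-path condition (hypothetical helper, for illustration only):

#include <cassert>

// Reduced restatement of the vector-path condition above (illustrative).
bool sdpa_vector_path(int q_head_dim, int v_head_dim, int q_len, int k_len) {
  const bool head_dim_ok = q_head_dim == v_head_dim &&
      (q_head_dim == 64 || q_head_dim == 96 || q_head_dim == 128 ||
       q_head_dim == 256);
  return q_len <= 8 && q_len <= k_len && head_dim_ok;
}

int main() {
  assert(sdpa_vector_path(128, 128, 1, 4096));   // decode: fast path
  assert(!sdpa_vector_path(72, 72, 1, 4096));    // unsupported head dim
  assert(!sdpa_vector_path(128, 128, 16, 4096)); // longer queries go to
                                                 // sdpa_full or the fallback
  return 0;
}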
@@ -1,8 +1,7 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/utils.h"

using namespace mlx;
#include "mlx/backend/common/utils.h"

namespace mlx::core {

@@ -59,109 +58,20 @@ std::string type_to_name(const array& a) {
  return type_to_name(a.dtype());
}

MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
  int pows[3] = {0, 0, 0};
  int sum = 0;
  while (true) {
    int presum = sum;
    // Check all the pows
    if (dim0 >= (1 << (pows[0] + 1))) {
      pows[0]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim1 >= (1 << (pows[1] + 1))) {
      pows[1]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim2 >= (1 << (pows[2] + 1))) {
      pows[2]++;
      sum++;
    }
    if (sum == presum || sum == pow2) {
      break;
    }
  }
  return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2) {
  Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2);
  return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}

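The removed loop now lives behind get_block_dims_common in backend/common/utils.h: it round-robins over the three dimensions, doubling whichever still has room, until the block reaches 2^pow2 threads. A standalone sketch with a worked example (note the removed body compared against a literal 10 in its first two early exits; this sketch uses pow2 consistently):

#include <cassert>
#include <tuple>

// Round-robin doubling, as in the removed body above (sketch only; not the
// actual get_block_dims_common signature).
std::tuple<int, int, int> block_dims(int dim0, int dim1, int dim2, int pow2 = 10) {
  int pows[3] = {0, 0, 0};
  int sum = 0;
  while (true) {
    int presum = sum;
    if (dim0 >= (1 << (pows[0] + 1))) { pows[0]++; sum++; }
    if (sum == pow2) break;
    if (dim1 >= (1 << (pows[1] + 1))) { pows[1]++; sum++; }
    if (sum == pow2) break;
    if (dim2 >= (1 << (pows[2] + 1))) { pows[2]++; sum++; }
    if (sum == presum || sum == pow2) break; // no progress or block full
  }
  return {1 << pows[0], 1 << pows[1], 1 << pows[2]};
}

int main() {
  // 100 x 3 x 1 threads fit a 64 x 2 x 1 block (128 <= 1024 threads).
  assert(block_dims(100, 3, 1) == std::make_tuple(64, 2, 1));
  return 0;
}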
MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides) {
  // Dims with strides of 0 are ignored as they
  // correspond to broadcasted dimensions
  size_t grid_x = 1;
  size_t grid_y = 1;
  for (int i = 0; i < shape.size(); ++i) {
    if (strides[i] == 0) {
      continue;
    }
    if (grid_x * shape[i] < UINT32_MAX) {
      grid_x *= shape[i];
    } else {
      grid_y *= shape[i];
    }
  }
  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
    throw std::runtime_error("Unable to safely factor shape.");
  }
  if (grid_y > grid_x) {
    std::swap(grid_x, grid_y);
  }
  return MTL::Size(
      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
  Dims dims = get_2d_grid_dims_common(shape, strides);
  return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}

MTL::Size
get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor) {
  // Compute the 2d grid dimensions such that the total size of the grid is
  // divided by divisor.
  size_t grid_x = 1;
  size_t grid_y = 1;
  for (int i = 0; i < shape.size(); ++i) {
    if (strides[i] == 0) {
      continue;
    }

    // No need to add this shape we can just remove it from the divisor.
    if (divisor % shape[i] == 0) {
      divisor /= shape[i];
      continue;
    }

    if (grid_x * shape[i] < UINT32_MAX) {
      grid_x *= shape[i];
    } else {
      grid_y *= shape[i];
    }

    if (divisor > 1) {
      if (grid_x % divisor == 0) {
        grid_x /= divisor;
        divisor = 1;
      } else if (grid_y % divisor == 0) {
        grid_y /= divisor;
        divisor = 1;
      }
    }
  }
  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
    throw std::runtime_error("Unable to safely factor shape.");
  }
  if (grid_y > grid_x) {
    std::swap(grid_x, grid_y);
  }
  return MTL::Size(
      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}

std::string get_primitive_string(Primitive* primitive) {
  std::ostringstream op_t;
  primitive->print(op_t);
  return op_t.str();
  Dims dims = get_2d_grid_dims_common(shape, strides, divisor);
  return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}

} // namespace mlx::core

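The divisor overload, likewise forwarded to get_2d_grid_dims_common, folds the divisor into the grid so the launch covers product(shape) / divisor threads. A standalone re-trace of the removed body, under the assumption that all strides are nonzero:

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Re-trace of the removed divisor overload (sketch; the real code lives in
// backend/common/utils.h as get_2d_grid_dims_common).
std::pair<uint64_t, uint64_t> grid_2d(const std::vector<int>& shape, uint64_t divisor) {
  uint64_t grid_x = 1, grid_y = 1;
  for (int dim : shape) {
    // Cancel whole dimensions against the divisor when possible.
    if (divisor % dim == 0) {
      divisor /= dim;
      continue;
    }
    if (grid_x * dim < UINT32_MAX) grid_x *= dim;
    else grid_y *= dim;
    if (divisor > 1) {
      if (grid_x % divisor == 0) { grid_x /= divisor; divisor = 1; }
      else if (grid_y % divisor == 0) { grid_y /= divisor; divisor = 1; }
    }
  }
  if (grid_y > grid_x) std::swap(grid_x, grid_y);
  return {grid_x, grid_y};
}

int main() {
  // 8 * 128 * 64 elements with 64 elements handled per thread: 1024 x 1 grid.
  auto g = grid_2d({8, 128, 64}, 64);
  assert(g.first == 1024 && g.second == 1);
  return 0;
}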
@@ -13,22 +13,9 @@ namespace mlx::core {
std::string type_to_name(const Dtype& t);
std::string type_to_name(const array& a);

// Compute the thread block dimensions which fit the given
// input dimensions.
// - The thread block dimensions will be powers of two
// - The thread block size will be less than 2^pow2
// Compute the grid and block dimensions, check backend/common/utils.h for docs.
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);

// Computes a 2D grid where each element is < UINT_MAX
// Assumes:
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
// - shape and strides correspond to a contiguous (no holes) but
//   possibly broadcasted array
MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides);

// Same as above but we do an implicit division with divisor.
// Basically, equivalent to factorizing
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
MTL::Size
get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor);

@@ -58,8 +45,6 @@ inline void debug_set_primitive_buffer_label(
#endif
}

std::string get_primitive_string(Primitive* primitive);

template <typename T>
constexpr bool is_numeric_except_char = std::is_arithmetic_v<T> &&
    !std::is_same_v<T, char> && !std::is_same_v<T, signed char> &&

@@ -10,6 +10,12 @@
    throw std::runtime_error(#func " has no GPU implementation."); \
  }

#define NO_GPU_USE_FALLBACK(func)     \
  bool func::use_fallback(Stream s) { \
    return true;                      \
  }                                   \
  NO_GPU_MULTI(func)

#define NO_GPU(func)                                                  \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
    throw std::runtime_error(#func " has no GPU implementation.");    \
@@ -17,6 +23,17 @@

namespace mlx::core {

bool fast::ScaledDotProductAttention::use_fallback(
    const array& q,
    const array& k,
    const array& v,
    bool has_mask,
    bool has_arr_mask,
    bool do_causal,
    Stream s) {
  return true;
}

NO_GPU(Abs)
NO_GPU(Add)
NO_GPU(AddMM)
@@ -130,11 +147,11 @@ NO_GPU_MULTI(Eig)
NO_GPU(View)

namespace fast {
NO_GPU_MULTI(LayerNorm)
NO_GPU_USE_FALLBACK(LayerNorm)
NO_GPU_MULTI(LayerNormVJP)
NO_GPU_MULTI(RMSNorm)
NO_GPU_USE_FALLBACK(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
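For reference, NO_GPU_USE_FALLBACK(RMSNorm) expands to a use_fallback that always answers true plus the usual multi-output throwing stub, roughly:

// Approximate expansion of NO_GPU_USE_FALLBACK(RMSNorm):
bool RMSNorm::use_fallback(Stream s) {
  return true; // no GPU backend compiled in, always take the fallback
}
// ...followed by NO_GPU_MULTI(RMSNorm), whose eval_gpu throws
// "RMSNorm has no GPU implementation."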
mlx/fast.cpp
@@ -9,7 +9,6 @@
#include "mlx/fast_primitives.h"
#include "mlx/ops.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"

namespace mlx::core::fast {

@@ -112,7 +111,8 @@ array rms_norm(

  auto passed_weight =
      (has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
  if (s.device == Device::gpu) {

  if (!RMSNorm::use_fallback(s)) {
    return array(
        x.shape(),
        out_type,
@@ -256,7 +256,7 @@ array layer_norm(
  auto passed_bias =
      (has_bias) ? astype(*bias, out_type, s) : array(0, out_type);

  if (s.device == Device::gpu) {
  if (!LayerNorm::use_fallback(s)) {
    return array(
        x.shape(),
        out_type,
@@ -470,7 +470,7 @@ array rope(
    }
  };
  auto stream = to_stream(s);
  if (stream.device == Device::gpu) {
  if (!RoPE::use_fallback(stream)) {
    return array(
        x.shape(),
        x.dtype(),
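The same mechanical change lands in rms_norm, layer_norm, and rope: the inline s.device == Device::gpu check becomes a query of the primitive's static use_fallback, so the CPU-only stubs above and the Metal implementations share one dispatch shape. Schematically (names elided):

// if (!SomeFastPrimitive::use_fallback(stream)) {
//   return array(out_shape, out_type,
//                std::make_shared<SomeFastPrimitive>(...), inputs);
// }
// return fallback(inputs)[0]; // CPU or unsupported case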
@@ -727,31 +727,6 @@ array scaled_dot_product_attention(
  };

  auto stream = to_stream(s);
  const int value_head_dim = v.shape(-1);
  const int query_head_dim = q.shape(-1);
  const int query_sequence_length = q.shape(2);
  const int key_sequence_length = k.shape(2);

  const bool sdpa_vector_supported_head_dim =
      query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
       query_head_dim == 256);
  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);

  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
      (query_sequence_length <= key_sequence_length && do_causal);

  const bool supports_sdpa_full = sdpa_full_supported_mask &&
      sdpa_full_supported_head_dim && stream.device == Device::gpu;

  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
      (query_sequence_length <= key_sequence_length) &&
      sdpa_vector_supported_head_dim && stream.device == Device::gpu;

  const bool implementation_supports_use_case =
      supports_sdpa_full || supports_sdpa_vector;

  std::vector<array> inputs = {q, k, v};
  if (has_arr_mask) {
    // Check type
@@ -770,7 +745,8 @@ array scaled_dot_product_attention(
    mask_shape.back() = keys.shape(-2);
    inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
  }
  if (!detail::in_grad_tracing() && implementation_supports_use_case) {
  if (!ScaledDotProductAttention::use_fallback(
          q, k, v, has_mask, has_arr_mask, do_causal, stream)) {
    auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
    return array(
        std::move(out_shape),
@@ -779,7 +755,7 @@ array scaled_dot_product_attention(
            stream, fallback, scale, do_causal),
        std::move(inputs));
  }
  return fallback(inputs)[0];
  return fallback(std::move(inputs))[0];
}

bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {

@@ -839,14 +815,14 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
  if (group_size != 32 && group_size != 64 && group_size != 128) {
    std::ostringstream msg;
    msg << "[quantize] The requested group size " << group_size
        << " is not supported. The supported group sizes are 64 and 128.";
        << " is not supported. The supported group sizes are 32, 64, and 128.";
    throw std::invalid_argument(msg.str());
  }

  if (bits != 2 && bits != 3 && bits != 4 && bits != 6 && bits != 8) {
  if (bits < 2 || bits > 8 || bits == 7) {
    std::ostringstream msg;
    msg << "[quantize] The requested number of bits " << bits
        << " is not supported. The supported bits are 2, 3, 4, 6 and 8.";
        << " is not supported. The supported bits are 2, 3, 4, 5, 6 and 8.";
    throw std::invalid_argument(msg.str());
  }

@@ -43,6 +43,8 @@ class RMSNorm : public Custom {
      float eps)
      : Custom(stream, fallback), eps_(eps) {}

  static bool use_fallback(Stream stream);

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
@@ -65,7 +67,6 @@ class RMSNorm : public Custom {
  }

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  float eps_;
};

@@ -91,7 +92,6 @@ class RMSNormVJP : public Custom {
  }

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  float eps_;
};

@@ -103,6 +103,8 @@ class LayerNorm : public Custom {
      float eps)
      : Custom(stream, fallback), eps_(eps) {}

  static bool use_fallback(Stream s);

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
@@ -124,7 +126,6 @@ class LayerNorm : public Custom {
  }

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  float eps_;
};

@@ -150,7 +151,6 @@ class LayerNormVJP : public Custom {
  }

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  float eps_;
};

@@ -171,6 +171,8 @@ class RoPE : public Custom {
        scale_(scale),
        forward_(forward) {}

  static bool use_fallback(Stream s);

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
@@ -193,7 +195,6 @@ class RoPE : public Custom {
  }

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  int dims_;
  bool traditional_;
  float base_;
@@ -210,6 +211,15 @@ class ScaledDotProductAttention : public Custom {
      const bool do_causal)
      : Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {}

  static bool use_fallback(
      const array& q,
      const array& k,
      const array& v,
      bool has_mask,
      bool has_arr_mask,
      bool do_causal,
      Stream s);

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
@@ -230,7 +240,6 @@ class ScaledDotProductAttention : public Custom {
  }

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  float scale_;
  bool do_causal_;
};
@@ -263,7 +272,6 @@ class AffineQuantize : public Custom {
  }

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  int group_size_;
  int bits_;
  bool dequantize_;

mlx/ops.cpp
@@ -2862,21 +2862,30 @@ array matmul(
        << " second input with shape " << b.shape() << ".";
    throw std::invalid_argument(msg.str());
  }
  // Type promotion
  auto out_type = promote_types(a.dtype(), b.dtype());
  // Complex matmul in terms of real matmuls
  if (out_type == complex64) {

  // complex matmul using Karatsuba's Algorithm
  if (a.dtype() == complex64 || b.dtype() == complex64) {
    // Extract real and imaginary parts
    auto a_real = real(a, s);
    auto b_real = real(b, s);
    auto a_imag = imag(a, s);
    auto b_real = real(b, s);
    auto b_imag = imag(b, s);
    auto c_real =
        subtract(matmul(a_real, b_real, s), matmul(a_imag, b_imag, s), s);
    auto c_imag = add(matmul(a_real, b_imag, s), matmul(a_imag, b_real, s), s);

    // Compute real and imaginary components of the result
    auto m1 = matmul(a_real, b_real, s);
    auto m2 = matmul(a_imag, b_imag, s);
    auto m3 = matmul(add(a_real, a_imag, s), add(b_real, b_imag, s), s);

    auto c_real = subtract(m1, m2, s);
    auto c_imag = subtract(m3, add(m1, m2, s), s);

    return add(
        c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
  }

  // Type promotion
  auto out_type = promote_types(a.dtype(), b.dtype());

  if (!issubdtype(out_type, floating)) {
    std::ostringstream msg;
    msg << "[matmul] Only real floating point types are supported but "

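The rewrite is the standard Karatsuba trick, trading the four real matmuls of the removed code for three at the cost of two extra additions and two subtractions. With $A = A_r + i A_i$ and $B = B_r + i B_i$:

$m_1 = A_r B_r, \qquad m_2 = A_i B_i, \qquad m_3 = (A_r + A_i)(B_r + B_i)$

$AB = (m_1 - m_2) + i\,(m_3 - m_1 - m_2)$

Expanding $m_3 = A_r B_r + A_r B_i + A_i B_r + A_i B_i$ shows that $m_3 - m_1 - m_2 = A_r B_i + A_i B_r$, so the imaginary part is recovered exactly; the real part $A_r B_r - A_i B_i = m_1 - m_2$ is unchanged from the old formulation.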
@@ -208,9 +208,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
        // output arrays stream
        fences[it->second].wait(stream, in);
      } else if (in.event().valid()) {
        if (in.event().is_signaled()) {
          in.detach_event();
        } else if (in.event().stream() != stream) {
        if (in.event().stream() != stream) {
          // Use event to wait across async eval
          in.event().wait(stream);
        }

@@ -3,8 +3,8 @@
#pragma once

#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 25
#define MLX_VERSION_PATCH 2
#define MLX_VERSION_MINOR 26
#define MLX_VERSION_PATCH 0
#define MLX_VERSION_NUMERIC \
  (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

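For this release, v0.26.0, the numeric version works out to 100000 * 0 + 1000 * 26 + 0 = 26000. A quick compile-time check (header path assumed from this diff):

#include "mlx/version.h" // assumed location of the version macros

static_assert(MLX_VERSION_NUMERIC == 26000, "expected v0.26.0");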
@@ -1210,13 +1210,6 @@ class TestBlas(mlx_tests.MLXTestCase):
        self.assertTrue(np.allclose(c, c_np))

        # Test addmm
        M = 16
        K = 50
        N = 32

        def rand(shape):
            return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)

        a = rand((M, K))
        b = rand((K, N))
        c = rand((M, N))
@@ -1224,6 +1217,13 @@ class TestBlas(mlx_tests.MLXTestCase):
        out_np = 2.0 * np.matmul(a, b) + 2.0 * c
        self.assertTrue(np.allclose(out, out_np))

        # complex with real
        a = rand((M, K)).real
        b = rand((K, N))
        c = mx.matmul(a, b)
        c_np = np.matmul(a, b)
        self.assertTrue(np.allclose(c, c_np))


if __name__ == "__main__":
    unittest.main()

@@ -11,7 +11,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
    def test_quantize_dequantize(self):
        w = mx.random.normal(shape=(128, 512))
        for gs in [32, 64, 128]:
            for b in [2, 3, 6, 4, 8]:
            for b in [2, 3, 5, 6, 4, 8]:
                with self.subTest(gs=gs, b=b):
                    w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b)
                    w_hat = mx.dequantize(w_q, scales, biases, gs, b)
@@ -22,7 +22,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
        # test quantize/dequantize 0s
        a = mx.zeros((256, 512))
        for gs in [32, 64, 128]:
            for b in [2, 3, 4, 6, 8]:
            for b in [2, 3, 4, 5, 6, 8]:
                w_q, scales, biases = mx.quantize(a, gs, b)
                a_hat = mx.dequantize(w_q, scales, biases, gs, b)
                self.assertTrue(mx.all(a_hat == 0))
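The same round-trip is available from the C++ API. A hedged sketch of the newly accepted 5-bit mode (signatures as of this diff; check mlx/ops.h if quantize/dequantize or random::normal have drifted):

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto w = random::normal({128, 512});
  // 5-bit quantization is now accepted alongside 2/3/4/6/8.
  auto [w_q, scales, biases] = quantize(w, /* group_size */ 32, /* bits */ 5);
  auto w_hat = dequantize(w_q, scales, biases, 32, 5);
  eval(w_hat); // force computation of the round-trip
  return 0;
}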
@@ -146,7 +146,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
        k1, k2 = mx.random.split(key)
        tests = product(
            [128, 64, 32],  # group_size
            [2, 3, 4, 6, 8],  # bits
            [2, 3, 4, 5, 6, 8],  # bits
            [256, 512, 67],  # M
            [64, 128],  # N
            [0, 1, 3, 8],  # B
@@ -173,7 +173,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
        k1, k2 = mx.random.split(key)
        tests = product(
            [128, 64, 32],  # group_size
            [2, 3, 4, 6, 8],  # bits
            [2, 3, 4, 5, 6, 8],  # bits
            [32, 128, 256],  # M
            [128, 256, 67],  # N
            [0, 1, 3, 8],  # B

@@ -634,6 +634,7 @@ class TestVmap(mlx_tests.MLXTestCase):
        self.assertEqual(fy.shape, (4, 5, 6, 7))

    def test_leaks(self):
        gc.collect()
        mx.synchronize()
        if mx.metal.is_available():
            mem_pre = mx.get_active_memory()
@@ -653,6 +654,7 @@ class TestVmap(mlx_tests.MLXTestCase):
            outer()
            gc.collect()

        mx.synchronize()
        if mx.metal.is_available():
            mem_post = mx.get_active_memory()
        else:

@@ -1036,6 +1036,9 @@ TEST_CASE("test reduction ops") {
  x = array({-inf, -inf});
  CHECK_EQ(logsumexp(x).item<float>(), -inf);

  x = repeat(array(-inf), 5000);
  CHECK_EQ(logsumexp(x).item<float>(), -inf);

  x = array({0.0f, -inf});
  CHECK_EQ(logsumexp(x).item<float>(), 0.0f);