Mirror of https://github.com/ml-explore/mlx.git, synced 2025-09-06 08:24:39 +08:00
Compare commits
14 Commits
726dbd9267
54f05e7195
26be608470
248431eb3c
76f275b4df
f1951d6cce
62f297b51d
09bc32f62f
46d8b16ab4
42533931fa
9bd3a7102f
9e516b71ea
eac961ddb1
57c6aa7188
@@ -349,7 +349,7 @@ workflows:
            ignore: /.*/
      matrix:
        parameters:
          python_version: ["3.9", "3.10", "3.11", "3.12"]
          python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
          xcode_version: ["15.0.0", "15.2.0"]
          build_env: ["PYPI_RELEASE=1"]
  - build_documentation:
@@ -386,7 +386,7 @@ workflows:
  - build_release:
      matrix:
        parameters:
          python_version: ["3.9", "3.10", "3.11", "3.12"]
          python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
          xcode_version: ["15.0.0", "15.2.0"]
  weekly_build:
    when:
@@ -397,7 +397,7 @@ workflows:
  - build_release:
      matrix:
        parameters:
          python_version: ["3.9", "3.10", "3.11", "3.12"]
          python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
          xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
          build_env: ["DEV_RELEASE=1"]
  linux_test_release:
@@ -409,5 +409,5 @@ workflows:
  - build_linux_release:
      matrix:
        parameters:
          python_version: ["3.9", "3.10", "3.11", "3.12"]
          python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
          extra_env: ["PYPI_RELEASE=1"]
@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

if(NOT MLX_VERSION)
  set(MLX_VERSION 0.19.2)
  set(MLX_VERSION 0.20.0)
endif()

# --------------------- Processor tests -------------------------
@@ -144,6 +144,13 @@ def reduction(op, axis, x):
    mx.eval(ys)


def sum_and_add(axis, x, y):
    z = x.sum(axis=axis, keepdims=True)
    for i in range(50):
        z = (z + y).sum(axis=axis, keepdims=True)
    mx.eval(z)


def softmax(axis, x):
    ys = []
    for i in range(100):
@@ -505,5 +512,8 @@ if __name__ == "__main__":
    elif args.benchmark == "selu":
        print(bench(selu, x))

    elif args.benchmark == "sum_and_add":
        print(bench(sum_and_add, axis, *xs))

    else:
        raise ValueError("Unknown benchmark")
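For context, here is a self-contained sketch of timing the new ``sum_and_add`` workload with explicit evaluation. The ``bench`` helper used by the script is not shown in this diff, so a plain ``time.perf_counter`` loop and illustrative shapes stand in for it.

    import time
    import mlx.core as mx

    # Repeated from the hunk above so the snippet runs on its own.
    def sum_and_add(axis, x, y):
        z = x.sum(axis=axis, keepdims=True)
        for i in range(50):
            z = (z + y).sum(axis=axis, keepdims=True)
        mx.eval(z)

    x = mx.random.uniform(shape=(1024, 1024))
    y = mx.random.uniform(shape=(1024, 1024))
    mx.eval(x, y)  # materialize the inputs before timing

    start = time.perf_counter()
    for _ in range(10):
        sum_and_add(0, x, y)
    elapsed = (time.perf_counter() - start) / 10
    print(f"sum_and_add: {elapsed * 1e3:.3f} ms per iteration")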
@@ -1,3 +1,5 @@
.. _custom_metal_kernels:

Custom Metal Kernels
====================

@@ -76,6 +78,10 @@ Putting this all together, the generated function signature for ``myexp`` is as

  template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.

Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
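To ground the ``grid``, ``threadgroup``, and ``verbose`` parameters discussed above, here is a minimal sketch of defining and calling a kernel through ``mx.fast.metal_kernel``. The ``myexp`` body, the input name, and the launch sizes are illustrative assumptions, not taken from this diff.

    import mlx.core as mx

    # Illustrative element-wise exp kernel; "inp"/"out" are names chosen here.
    source = """
        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);
    """

    kernel = mx.fast.metal_kernel(
        name="myexp",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )

    a = mx.random.normal(shape=(4096,))
    outputs = kernel(
        inputs=[a],
        template=[("T", mx.float32)],
        grid=(a.size, 1, 1),      # mx.prod(grid) threads are launched ...
        threadgroup=(256, 1, 1),  # ... subdivided into threadgroups of this size
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
        verbose=True,             # print the generated Metal source for debugging
    )
    print(outputs[0])  # should match mx.exp(a)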
Using Shape/Strides
@@ -161,7 +161,7 @@ A naive way to add the elements from two sets of vectors is with a loop:
ys = mx.random.uniform(shape=(100, 4096))

def naive_add(xs, ys):
    return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
    return [xs[i] + ys[:, i] for i in range(xs.shape[0])]

Instead you can use :func:`vmap` to automatically vectorize the addition:

@@ -169,7 +169,7 @@ Instead you can use :func:`vmap` to automatically vectorize the addition:

# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))

The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
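As a concrete, self-contained illustration of ``in_axes`` and ``out_axes`` (a small sketch with assumed shapes, not part of the diff):

    import mlx.core as mx

    xs = mx.random.uniform(shape=(4096, 100))
    ys = mx.random.uniform(shape=(100, 4096))

    # Vectorize over axis 0 of xs and axis 1 of ys; each call sees a pair of
    # length-100 vectors, and the 4096 results are stacked on output axis 0.
    vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1), out_axes=0)
    out = vmap_add(xs, ys)
    print(out.shape)  # (4096, 100)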
@@ -77,7 +77,7 @@ from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.

Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which outputs
general, MLX has limited support for operations for which output
*shapes* are dependent on input *data*. Other examples of these types of
operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`.
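To make the distinction concrete, here is a small sketch (not part of the diff): the three-input ``mx.where`` has a data-independent output shape and is supported, while boolean-mask selection would require a data-dependent shape.

    import mlx.core as mx

    a = mx.array([1, -2, 3, -4])

    # Supported: the output shape is the broadcast input shape, regardless of the data.
    print(mx.where(a > 0, a, 0))  # [1, 0, 3, 0]

    # Not yet supported: a[a > 0], numpy.nonzero, or one-argument where would need
    # an output whose size depends on how many elements happen to be positive.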
@@ -109,7 +109,7 @@ Here is a concrete example:

An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an
:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array.
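A short sketch of the implicit-evaluation behavior described above (illustrative, not part of the diff):

    import mlx.core as mx
    import numpy as np

    a = mx.array([1.0, 2.0, 3.0])
    b = mx.exp(a) + 1  # lazily records the computation; nothing runs yet

    print(b)           # printing implicitly evaluates the graph
    c = np.array(b)    # so does converting to a numpy.ndarray
    mx.eval(b)         # or evaluate explicitly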
@@ -271,6 +271,9 @@ array::ArrayDesc::~ArrayDesc() {
    for (array& a : ad.inputs) {
      if (a.array_desc_) {
        input_map.insert({a.id(), a});
        for (auto& s : a.siblings()) {
          input_map.insert({s.id(), s});
        }
      }
    }
    ad.inputs.clear();
@@ -136,13 +136,8 @@ void CommandEncoder::set_input_array(
    int64_t offset /* = 0 */) {
  all_inputs_.insert(a.buffer().ptr());
  auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
  if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
    // Insert a barrier
    enc_->memoryBarrier(&r_buf, 1);

    // Remove the output
    outputs_.erase(it);
  }
  needs_barrier_ =
      needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
  auto base_offset = a.data<char>() -
      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
@@ -161,19 +156,32 @@ void CommandEncoder::set_output_array(
  if (concurrent_) {
    concurrent_outputs_.insert(buf);
  } else {
    outputs_.insert(buf);
    next_outputs_.insert(buf);
  }
}

void CommandEncoder::maybeInsertBarrier() {
  if (needs_barrier_) {
    enc_->memoryBarrier(MTL::BarrierScopeBuffers);
    needs_barrier_ = false;
    prev_outputs_ = std::move(next_outputs_);
  } else {
    prev_outputs_.insert(next_outputs_.begin(), next_outputs_.end());
  }
  next_outputs_.clear();
}

void CommandEncoder::dispatchThreadgroups(
    MTL::Size grid_dims,
    MTL::Size group_dims) {
  maybeInsertBarrier();
  enc_->dispatchThreadgroups(grid_dims, group_dims);
}

void CommandEncoder::dispatchThreads(
    MTL::Size grid_dims,
    MTL::Size group_dims) {
  maybeInsertBarrier();
  enc_->dispatchThreads(grid_dims, group_dims);
}
@@ -49,7 +49,7 @@ struct CommandEncoder {
  }
  ~ConcurrentContext() {
    enc.concurrent_ = false;
    enc.outputs_.insert(
    enc.prev_outputs_.insert(
        enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
    enc.concurrent_outputs_.clear();
  }
@@ -66,6 +66,7 @@ struct CommandEncoder {
  void set_output_array(array& a, int idx, int64_t offset = 0);
  void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
  void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
  void maybeInsertBarrier();

  ConcurrentContext start_concurrent() {
    return ConcurrentContext(*this);
@@ -84,8 +85,10 @@ struct CommandEncoder {

 private:
  MTL::ComputeCommandEncoder* enc_;
  bool needs_barrier_{false};
  bool concurrent_{false};
  std::unordered_set<MTL::Resource*> outputs_;
  std::unordered_set<MTL::Resource*> prev_outputs_;
  std::unordered_set<MTL::Resource*> next_outputs_;
  std::unordered_set<MTL::Resource*> concurrent_outputs_;
  std::unordered_set<const void*> all_inputs_;
  std::unordered_set<const void*> all_outputs_;
@@ -319,16 +319,18 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
MTL::ComputePipelineState* get_reduce_init_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& func_name,
    const std::string& op_name,
    const array& out) {
  auto lib = d.get_library(kernel_name, [&]() {
    std::ostringstream kernel_source;
    std::string op_type = op_name(out);
    op_type[0] = std::toupper(op_name(out)[0]);
    std::string op_type = op_name;
    op_type[0] = std::toupper(op_name[0]);
    auto out_type = get_type_string(out.dtype());
    std::string op = op_type + "<" + out_type + ">";
    kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
    kernel_source << get_template_definition(
        kernel_name, "init_reduce", out_type, op);
        kernel_name, func_name, out_type, op);
    return kernel_source.str();
  });
  return d.get_kernel(kernel_name, lib);

@@ -79,6 +79,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
MTL::ComputePipelineState* get_reduce_init_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& func_name,
    const std::string& op_name,
    const array& out);

MTL::ComputePipelineState* get_reduce_kernel(
@@ -650,8 +650,8 @@ METAL_FUNC void qvm_impl(
|
||||
const device T* biases,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
const constant int& out_vec_size,
|
||||
const int in_vec_size,
|
||||
const int out_vec_size,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
@@ -1298,6 +1298,61 @@ template <typename T, const int group_size, const int bits, bool batched>
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, const int bits, int split_k = 32>
|
||||
[[kernel]] void qvm_split_k(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* biases [[buffer(2)]],
|
||||
const device T* x [[buffer(3)]],
|
||||
device T* y [[buffer(4)]],
|
||||
const constant int& in_vec_size [[buffer(5)]],
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant size_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant size_t* w_strides [[buffer(12)]],
|
||||
const constant size_t* s_strides [[buffer(13)]],
|
||||
const constant size_t* b_strides [[buffer(14)]],
|
||||
const constant int& final_block_size [[buffer(15)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
adjust_matrix_offsets<T>(
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
y,
|
||||
out_vec_size,
|
||||
x_batch_ndims,
|
||||
x_shape,
|
||||
x_strides,
|
||||
w_batch_ndims,
|
||||
w_shape,
|
||||
w_strides,
|
||||
s_strides,
|
||||
b_strides,
|
||||
tid);
|
||||
|
||||
// When (in_vec_size % split_k != 0) the final block needs to be smaller
|
||||
int in_vec_size_adj =
|
||||
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
|
||||
|
||||
qvm_impl<T, group_size, bits>(
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
x,
|
||||
y,
|
||||
in_vec_size_adj,
|
||||
out_vec_size,
|
||||
tid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
|
@@ -51,6 +51,15 @@
|
||||
D, \
|
||||
batched)
|
||||
|
||||
#define instantiate_quantized_split_k(name, type, group_size, bits, split_k) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_" #group_size "_b_" #bits "_spk_" #split_k, \
|
||||
name, \
|
||||
type, \
|
||||
group_size, \
|
||||
bits, \
|
||||
split_k)
|
||||
|
||||
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
|
||||
instantiate_quantized_batched(name, type, group_size, bits, 1) \
|
||||
instantiate_quantized_batched(name, type, group_size, bits, 0)
|
||||
@@ -84,11 +93,16 @@
|
||||
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
|
||||
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
|
||||
|
||||
#define instantiate_quantized_all_splitk(type, group_size, bits) \
|
||||
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
|
||||
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
|
||||
|
||||
#define instantiate_quantized_funcs(type, group_size, bits) \
|
||||
instantiate_quantized_all_single(type, group_size, bits) \
|
||||
instantiate_quantized_all_batched(type, group_size, bits) \
|
||||
instantiate_quantized_all_aligned(type, group_size, bits) \
|
||||
instantiate_quantized_all_quad(type, group_size, bits)
|
||||
instantiate_quantized_all_quad(type, group_size, bits) \
|
||||
instantiate_quantized_all_splitk(type, group_size, bits)
|
||||
|
||||
#define instantiate_quantized_types(group_size, bits) \
|
||||
instantiate_quantized_funcs(float, group_size, bits) \
|
||||
|
@@ -113,9 +113,12 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
|
||||
// special case bool with larger output type
|
||||
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
|
||||
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
|
||||
col_reduce_small, \
|
||||
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
|
||||
col_reduce_small, \
|
||||
itype, otype, op, dim) \
|
||||
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
|
||||
col_reduce_longcolumn, \
|
||||
itype, otype, op, dim)
|
||||
|
||||
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
@@ -123,9 +126,14 @@ instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
col_reduce_looped, \
|
||||
itype, otype, op, dim, bm, bn)
|
||||
|
||||
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_2pass, \
|
||||
itype, otype, op, dim, bm, bn)
|
||||
|
||||
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
|
||||
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
|
||||
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
|
||||
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
|
||||
instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)
|
||||
|
||||
#define instantiate_col_reduce_general(name, itype, otype, op) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 0) \
|
||||
|
@@ -1,11 +1,6 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIMS,
|
||||
int N_READS = REDUCE_N_READS>
|
||||
template <typename T, typename U, typename Op, int NDIMS>
|
||||
[[kernel]] void col_reduce_small(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
@@ -20,170 +15,128 @@ template <
|
||||
const constant size_t& non_col_reductions [[buffer(10)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[thread_position_in_grid]],
|
||||
uint3 tsize [[threads_per_grid]]) {
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]]) {
|
||||
constexpr int n_reads = 4;
|
||||
Op op;
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
const device T* row;
|
||||
|
||||
// Case 1: Small row small column
|
||||
if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
|
||||
U totals[31];
|
||||
for (int i = 0; i < 31; i++) {
|
||||
totals[i] = Op::init;
|
||||
U totals[n_reads];
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;
|
||||
if (column >= reduction_stride) {
|
||||
return;
|
||||
}
|
||||
bool safe = column + n_reads <= reduction_stride;
|
||||
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
|
||||
in += in_idx + column;
|
||||
|
||||
size_t total_rows = non_col_reductions * reduction_size;
|
||||
loop.next(lid.y, reduce_shape, reduce_strides);
|
||||
for (size_t r = lid.y; r < total_rows; r += lsize.y) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
if (safe) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = op(static_cast<U>(row[i]), totals[i]);
|
||||
}
|
||||
} else {
|
||||
U vals[n_reads];
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
vals[i] =
|
||||
(column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
|
||||
}
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
}
|
||||
loop.next(lsize.y, reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
short stride = reduction_stride;
|
||||
short size = reduction_size;
|
||||
short blocks = stride / N_READS;
|
||||
short extra = stride - blocks * N_READS;
|
||||
|
||||
size_t out_idx = tid.x + tsize.y * size_t(tid.y);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim);
|
||||
|
||||
for (uint r = 0; r < non_col_reductions; r++) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
|
||||
for (short i = 0; i < size; i++) {
|
||||
for (short j = 0; j < blocks; j++) {
|
||||
for (short k = 0; k < N_READS; k++) {
|
||||
totals[j * N_READS + k] =
|
||||
op(totals[j * N_READS + k],
|
||||
static_cast<U>(row[i * stride + j * N_READS + k]));
|
||||
}
|
||||
}
|
||||
for (short k = 0; k < extra; k++) {
|
||||
totals[blocks * N_READS + k] =
|
||||
op(totals[blocks * N_READS + k],
|
||||
static_cast<U>(row[i * stride + blocks * N_READS + k]));
|
||||
if (lsize.y > 1) {
|
||||
// lsize.y should be <= 8
|
||||
threadgroup U shared_vals[32 * 8 * n_reads];
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (lid.y == 0) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = shared_vals[lid.x * n_reads + i];
|
||||
}
|
||||
for (uint j = 1; j < lsize.y; j++) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] =
|
||||
op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],
|
||||
totals[i]);
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(reduce_shape, reduce_strides);
|
||||
}
|
||||
out += out_idx * reduction_stride;
|
||||
for (short j = 0; j < stride; j++) {
|
||||
out[j] = totals[j];
|
||||
}
|
||||
}
|
||||
|
||||
// Case 2: Long row small column
|
||||
else if (reduction_size * non_col_reductions < 32) {
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
short size = reduction_size;
|
||||
size_t offset = size_t(tid.x) * N_READS;
|
||||
bool safe = offset + N_READS <= reduction_stride;
|
||||
short extra = reduction_stride - offset;
|
||||
|
||||
size_t out_idx = tid.y + tsize.z * size_t(tid.z);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim) + offset;
|
||||
|
||||
for (uint r = 0; r < non_col_reductions; r++) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
|
||||
if (safe) {
|
||||
for (short i = 0; i < size; i++) {
|
||||
for (short j = 0; j < N_READS; j++) {
|
||||
totals[j] =
|
||||
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (short i = 0; i < size; i++) {
|
||||
for (short j = 0; j < extra; j++) {
|
||||
totals[j] =
|
||||
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(reduce_shape, reduce_strides);
|
||||
}
|
||||
out += out_idx * reduction_stride + offset;
|
||||
if (lid.y == 0) {
|
||||
out += out_idx * reduction_stride + column;
|
||||
if (safe) {
|
||||
for (short i = 0; i < N_READS; i++) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
} else {
|
||||
for (short i = 0; i < extra; i++) {
|
||||
for (int i = 0; column + i < reduction_stride; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Case 3: Long row medium column
|
||||
else {
|
||||
threadgroup U shared_vals[1024];
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
short stride = reduction_stride;
|
||||
short lid = simd_group_id * simd_size + simd_lane_id;
|
||||
short2 tile((stride + N_READS - 1) / N_READS, 32);
|
||||
short2 offset((lid % tile.x) * N_READS, lid / tile.x);
|
||||
short sm_stride = tile.x * N_READS;
|
||||
bool safe = offset.x + N_READS <= stride;
|
||||
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x;
|
||||
|
||||
// Read cooperatively and contiguously and aggregate the partial results.
|
||||
size_t total = non_col_reductions * reduction_size;
|
||||
loop.next(offset.y, reduce_shape, reduce_strides);
|
||||
for (size_t r = offset.y; r < total; r += simd_size) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
|
||||
if (safe) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(static_cast<U>(row[i]), totals[i]);
|
||||
}
|
||||
} else {
|
||||
U vals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = (offset.x + i < stride) ? static_cast<U>(row[i]) : op.init;
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(simd_size, reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
// Each thread holds N_READS partial results but the simdgroups are not
|
||||
// aligned to do the reduction across the simdgroup so we write our results
|
||||
// in the shared memory and read them back according to the simdgroup.
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op.simd_reduce(
|
||||
shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);
|
||||
}
|
||||
|
||||
// Write the output.
|
||||
if (simd_lane_id == 0) {
|
||||
short column = simd_group_id * N_READS;
|
||||
out += out_idx * reduction_stride + column;
|
||||
if (column + N_READS <= stride) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; column + i < stride; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
}
|
||||
template <typename T, typename U, typename Op, int NDIMS>
|
||||
[[kernel]] void col_reduce_longcolumn(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
const constant size_t& non_col_reductions [[buffer(10)]],
|
||||
const constant size_t& out_size [[buffer(11)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]]) {
|
||||
Op op;
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
const device T* row;
|
||||
|
||||
size_t out_idx = gid.x + gsize.x * size_t(gid.y);
|
||||
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
|
||||
in += in_idx + lid.x;
|
||||
|
||||
U total = Op::init;
|
||||
size_t total_rows = non_col_reductions * reduction_size;
|
||||
loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
|
||||
for (size_t r = gid.z * lsize.y + lid.y; r < total_rows;
|
||||
r += lsize.y * gsize.z) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
total = op(static_cast<U>(*row), total);
|
||||
loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
threadgroup U shared_vals[32 * 32];
|
||||
shared_vals[lid.y * lsize.x + lid.x] = total;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (lid.y == 0) {
|
||||
for (uint i = 1; i < lsize.y; i++) {
|
||||
total = op(total, shared_vals[i * lsize.x + lid.x]);
|
||||
}
|
||||
out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,7 +169,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
Op op;
|
||||
constexpr int n_simdgroups = 4;
|
||||
constexpr int n_simdgroups = 8;
|
||||
constexpr short tgp_size = n_simdgroups * simd_size;
|
||||
constexpr short n_reads = (BM * BN) / tgp_size;
|
||||
constexpr short n_read_blocks = BN / n_reads;
|
||||
@@ -329,3 +282,103 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
[[kernel]] void col_reduce_2pass(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
const constant size_t& non_col_reductions [[buffer(10)]],
|
||||
const constant size_t& out_size [[buffer(11)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
Op op;
|
||||
constexpr int n_simdgroups = 8;
|
||||
constexpr short tgp_size = n_simdgroups * simd_size;
|
||||
constexpr short n_reads = (BM * BN) / tgp_size;
|
||||
constexpr short n_read_blocks = BN / n_reads;
|
||||
constexpr int n_outputs = BN / n_simdgroups;
|
||||
constexpr short outer_blocks = 32;
|
||||
static_assert(BM == 32, "BM should be equal to 32");
|
||||
|
||||
threadgroup U shared_vals[BN * BM];
|
||||
U totals[n_reads];
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
const device T* row;
|
||||
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
short lid = simd_group_id * simd_size + simd_lane_id;
|
||||
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
|
||||
size_t column = BN * gid.x + offset.x;
|
||||
bool safe = column + n_reads <= reduction_stride;
|
||||
|
||||
size_t full_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
size_t block_idx = full_idx / out_size;
|
||||
size_t out_idx = full_idx % out_size;
|
||||
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
|
||||
in += in_idx + column;
|
||||
|
||||
size_t total = non_col_reductions * reduction_size;
|
||||
loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
|
||||
for (size_t r = offset.y + block_idx * BM; r < total;
|
||||
r += outer_blocks * BM) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
|
||||
if (safe) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = op(static_cast<U>(row[i]), totals[i]);
|
||||
}
|
||||
} else {
|
||||
U vals[n_reads];
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
vals[i] =
|
||||
(column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
|
||||
}
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(outer_blocks * BM, reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
// We can use a simd reduction to accumulate across BM so each thread writes
|
||||
// the partial output to SM and then each simdgroup does BN / n_simdgroups
|
||||
// accumulations.
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
shared_vals[offset.y * BN + offset.x + i] = totals[i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
totals[i] =
|
||||
op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
|
||||
}
|
||||
|
||||
// Write the output.
|
||||
if (simd_lane_id == 0) {
|
||||
size_t out_column = BN * gid.x + out_offset.x;
|
||||
out += full_idx * reduction_stride + out_column;
|
||||
if (out_column + n_outputs <= reduction_stride) {
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; out_column + i < reduction_stride; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -936,6 +936,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
      const constant int& gqa_factor, \
      const constant int& N, \
      const constant size_t& k_stride, \
      const constant size_t& v_stride, \
      const constant float& scale, \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
@@ -13,6 +13,7 @@ template <typename T, int D>
    const constant int& gqa_factor,
    const constant int& N,
    const constant size_t& k_stride,
    const constant size_t& v_stride,
    const constant float& scale,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -38,7 +39,7 @@ template <typename T, int D>
  const int kv_head_idx = head_idx / gqa_factor;
  queries += head_idx * D + simd_lid * elem_per_thread;
  keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
  values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
  values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
  out += head_idx * D + simd_gid * elem_per_thread;

  // Read the query and 0 the output accumulator
@@ -97,6 +97,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
MTL::ComputePipelineState* get_reduce_init_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string&,
    const std::string&,
    const array&) {
  return d.get_kernel(kernel_name);
}
@@ -6,6 +6,7 @@
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/reduce.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -148,6 +149,125 @@ void launch_qmm(
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void qvm_split_k(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
int group_size,
|
||||
int bits,
|
||||
int D,
|
||||
int O,
|
||||
int B,
|
||||
int N,
|
||||
const Stream& s) {
|
||||
int split_k = D > 8192 ? 32 : 8;
|
||||
int split_D = (D + split_k - 1) / split_k;
|
||||
N *= split_k;
|
||||
|
||||
int bo = 64;
|
||||
int bd = 32;
|
||||
MTL::Size group_dims = MTL::Size(bd, 2, 1);
|
||||
MTL::Size grid_dims = MTL::Size(O / bo, B, N);
|
||||
|
||||
auto& x_pre = inputs[0];
|
||||
auto& w_pre = inputs[1];
|
||||
auto& scales_pre = inputs[2];
|
||||
auto& biases_pre = inputs[3];
|
||||
|
||||
// Ensure that the last two dims are row contiguous.
|
||||
// TODO: Check if we really need this for x as well...
|
||||
std::vector<array> copies;
|
||||
auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) {
|
||||
auto stride_0 = arr.strides()[arr.ndim() - 2];
|
||||
auto stride_1 = arr.strides()[arr.ndim() - 1];
|
||||
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
return arr_copy;
|
||||
}
|
||||
};
|
||||
auto x = ensure_row_contiguous_last_dims(x_pre);
|
||||
auto w = ensure_row_contiguous_last_dims(w_pre);
|
||||
auto scales = ensure_row_contiguous_last_dims(scales_pre);
|
||||
auto biases = ensure_row_contiguous_last_dims(biases_pre);
|
||||
|
||||
int x_batch_ndims = x.ndim() - 2;
|
||||
auto x_shape = x.shape();
|
||||
auto x_strides = x.strides();
|
||||
int w_batch_ndims = w.ndim() - 2;
|
||||
auto w_shape = w.shape();
|
||||
auto w_strides = w.strides();
|
||||
auto s_strides = scales.strides();
|
||||
auto b_strides = biases.strides();
|
||||
|
||||
// Add split_k dim with reshapes
|
||||
x_shape.insert(x_shape.end() - 2, split_k);
|
||||
x_shape.back() /= split_k;
|
||||
x_strides.insert(x_strides.end() - 2, split_D);
|
||||
x_strides[x.ndim() - 1] = split_D;
|
||||
x_batch_ndims += 1;
|
||||
|
||||
w_shape.insert(w_shape.end() - 2, split_k);
|
||||
w_shape[w.ndim() - 1] /= split_k;
|
||||
w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1));
|
||||
w_batch_ndims += 1;
|
||||
s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1));
|
||||
b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1));
|
||||
|
||||
int final_block_size = D - (split_k - 1) * split_D;
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto temp_shape = out.shape();
|
||||
temp_shape.insert(temp_shape.end() - 2, split_k);
|
||||
array intermediate(temp_shape, x.dtype(), nullptr, {});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
d.add_temporary(intermediate, s.index);
|
||||
|
||||
std::ostringstream kname;
|
||||
auto type_string = get_type_string(x.dtype());
|
||||
kname << "qvm_split_k" << "_" << type_string << "_gs_" << group_size << "_b_"
|
||||
<< bits << "_spk_" << split_k;
|
||||
auto template_def = get_template_definition(
|
||||
kname.str(), "qvm_split_k", type_string, group_size, bits, split_k);
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
compute_encoder.set_input_array(w, 0);
|
||||
compute_encoder.set_input_array(scales, 1);
|
||||
compute_encoder.set_input_array(biases, 2);
|
||||
compute_encoder.set_input_array(x, 3);
|
||||
compute_encoder.set_output_array(intermediate, 4);
|
||||
compute_encoder->setBytes(&split_D, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
|
||||
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 7);
|
||||
set_vector_bytes(compute_encoder, x_shape, 8);
|
||||
set_vector_bytes(compute_encoder, x_strides, 9);
|
||||
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 10);
|
||||
set_vector_bytes(compute_encoder, w_shape, 11);
|
||||
set_vector_bytes(compute_encoder, w_strides, 12);
|
||||
set_vector_bytes(compute_encoder, s_strides, 13);
|
||||
set_vector_bytes(compute_encoder, b_strides, 14);
|
||||
compute_encoder->setBytes(&final_block_size, sizeof(int), 15);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
|
||||
int axis = intermediate.ndim() - 3;
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce,
|
||||
{intermediate.shape(axis)},
|
||||
{intermediate.strides(axis)});
|
||||
strided_reduce_general_dispatch(
|
||||
intermediate, out, "sum", plan, {axis}, compute_encoder, d, s);
|
||||
}
|
||||
|
||||
void qmm_op(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
@@ -211,7 +331,9 @@ void qmm_op(
|
||||
aligned = true;
|
||||
}
|
||||
} else {
|
||||
if (B < 4) {
|
||||
if (B < 4 && D >= 1024 && !gather) {
|
||||
return qvm_split_k(inputs, out, group_size, bits, D, O, B, N, s);
|
||||
} else if (B < 4) {
|
||||
name += "qvm";
|
||||
int bo = 64;
|
||||
int bd = 32;
|
||||
|
@@ -141,6 +141,20 @@ struct ColReduceArgs {
|
||||
ndim = shape.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the col reduce arguments for reducing the 1st axis of the row
|
||||
* contiguous intermediate array.
|
||||
*/
|
||||
ColReduceArgs(const array& intermediate) {
|
||||
assert(intermediate.flags().row_contiguous);
|
||||
|
||||
reduction_size = intermediate.shape(0);
|
||||
reduction_stride = intermediate.size() / reduction_size;
|
||||
non_col_reductions = 1;
|
||||
reduce_ndim = 0;
|
||||
ndim = 0;
|
||||
}
|
||||
|
||||
void encode(CommandEncoder& compute_encoder) {
|
||||
// Push 0s to avoid encoding empty vectors.
|
||||
if (reduce_ndim == 0) {
|
||||
@@ -231,8 +245,10 @@ void init_reduce(
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
auto kernel = get_reduce_init_kernel(
|
||||
d, "init_reduce_" + op_name + type_to_name(out), out);
|
||||
std::ostringstream kname;
|
||||
const std::string func_name = "init_reduce";
|
||||
kname << func_name << "_" << op_name << type_to_name(out);
|
||||
auto kernel = get_reduce_init_kernel(d, kname.str(), func_name, op_name, out);
|
||||
size_t nthreads = out.size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
@@ -251,8 +267,7 @@ void all_reduce_dispatch(
|
||||
const std::string& op_name,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s,
|
||||
std::vector<array>& copies) {
|
||||
const Stream& s) {
|
||||
// Set the kernel
|
||||
std::ostringstream kname;
|
||||
const std::string func_name = "all_reduce";
|
||||
@@ -293,7 +308,7 @@ void all_reduce_dispatch(
|
||||
// Allocate an intermediate tensor to hold results if needed
|
||||
array intermediate({n_rows}, out.dtype(), nullptr, {});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
copies.push_back(intermediate);
|
||||
d.add_temporary(intermediate, s.index);
|
||||
|
||||
// 1st pass
|
||||
size_t row_size = (in_size + n_rows - 1) / n_rows;
|
||||
@@ -469,39 +484,11 @@ void strided_reduce_small(
|
||||
// Figure out the grid dims
|
||||
MTL::Size grid_dims, group_dims;
|
||||
|
||||
// Case 1: Small row small column
|
||||
if (args.reduction_size * args.non_col_reductions < 64 &&
|
||||
args.reduction_stride < 32) {
|
||||
grid_dims = output_grid_for_col_reduce(out, args);
|
||||
int threadgroup_size = (grid_dims.width > 128) ? 128 : grid_dims.width;
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
}
|
||||
// Prepare the arguments for the kernel
|
||||
args.reduce_shape.push_back(args.reduction_size);
|
||||
args.reduce_strides.push_back(args.reduction_stride);
|
||||
args.reduce_ndim++;
|
||||
|
||||
// Case 2: Long row small column
|
||||
else if (args.reduction_size * args.non_col_reductions < 32) {
|
||||
auto out_grid_dims = output_grid_for_col_reduce(out, args);
|
||||
int threads_x =
|
||||
(args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
|
||||
int threadgroup_x = std::min(threads_x, 128);
|
||||
grid_dims = MTL::Size(threads_x, out_grid_dims.width, out_grid_dims.height);
|
||||
group_dims = MTL::Size(threadgroup_x, 1, 1);
|
||||
}
|
||||
|
||||
// Case 3: Long row medium column
|
||||
else {
|
||||
args.reduce_shape.push_back(args.reduction_size);
|
||||
args.reduce_strides.push_back(args.reduction_stride);
|
||||
args.reduce_ndim++;
|
||||
int simdgroups =
|
||||
(args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
|
||||
int threadgroup_size = simdgroups * 32;
|
||||
auto out_grid_dims = output_grid_for_col_reduce(out, args);
|
||||
grid_dims =
|
||||
MTL::Size(threadgroup_size, out_grid_dims.width, out_grid_dims.height);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
}
|
||||
|
||||
// Set the kernel
|
||||
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
|
||||
std::ostringstream kname;
|
||||
const std::string func_name = "col_reduce_small";
|
||||
@@ -510,10 +497,113 @@ void strided_reduce_small(
|
||||
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
const int n_reads = 4;
|
||||
size_t reduction_stride_blocks =
|
||||
(args.reduction_stride + n_reads - 1) / n_reads;
|
||||
size_t total = args.reduction_size * args.non_col_reductions;
|
||||
size_t threadgroup_x = std::min(reduction_stride_blocks, 32ul);
|
||||
size_t threadgroup_y = std::min(
|
||||
8ul,
|
||||
std::min(kernel->maxTotalThreadsPerThreadgroup() / threadgroup_x, total));
|
||||
|
||||
group_dims = MTL::Size(threadgroup_x, threadgroup_y, 1);
|
||||
grid_dims = output_grid_for_col_reduce(out, args);
|
||||
grid_dims = MTL::Size(
|
||||
(reduction_stride_blocks + threadgroup_x - 1) / threadgroup_x,
|
||||
grid_dims.width,
|
||||
grid_dims.height);
|
||||
|
||||
// Launch
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
args.encode(compute_encoder);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void strided_reduce_longcolumn(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
ColReduceArgs& args,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
size_t total_reduction_size = args.reduction_size * args.non_col_reductions;
|
||||
size_t outer_blocks = 32;
|
||||
if (total_reduction_size >= 32768) {
|
||||
outer_blocks = 128;
|
||||
}
|
||||
|
||||
// Prepare the temporary accumulator
|
||||
std::vector<int> intermediate_shape;
|
||||
intermediate_shape.reserve(out.ndim() + 1);
|
||||
intermediate_shape.push_back(outer_blocks);
|
||||
intermediate_shape.insert(
|
||||
intermediate_shape.end(), out.shape().begin(), out.shape().end());
|
||||
array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
d.add_temporary(intermediate, s.index);
|
||||
|
||||
// Prepare the arguments for the kernel
|
||||
args.reduce_shape.push_back(args.reduction_size);
|
||||
args.reduce_strides.push_back(args.reduction_stride);
|
||||
args.reduce_ndim++;
|
||||
|
||||
// Figure out the grid dims
|
||||
size_t out_size = out.size();
|
||||
size_t threadgroup_x = args.reduction_stride;
|
||||
size_t threadgroup_y =
|
||||
(args.non_col_reductions * args.reduction_size + outer_blocks - 1) /
|
||||
outer_blocks;
|
||||
threadgroup_y = std::min(32ul, threadgroup_y);
|
||||
|
||||
auto out_grid_size = output_grid_for_col_reduce(out, args);
|
||||
MTL::Size grid_dims(out_grid_size.width, out_grid_size.height, outer_blocks);
|
||||
MTL::Size group_dims(threadgroup_x, threadgroup_y, 1);
|
||||
|
||||
// Set the kernel
|
||||
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
|
||||
std::ostringstream kname;
|
||||
const std::string func_name = "col_reduce_longcolumn";
|
||||
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
|
||||
auto kernel =
|
||||
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Launch
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(intermediate, 1);
|
||||
args.encode(compute_encoder);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
// Make the 2nd pass arguments and grid_dims
|
||||
ColReduceArgs second_args(intermediate);
|
||||
second_args.reduce_shape.push_back(outer_blocks);
|
||||
second_args.reduce_strides.push_back(out.size());
|
||||
second_args.reduce_ndim++;
|
||||
int BN = 32;
|
||||
grid_dims = MTL::Size(256 * ((out.size() + BN - 1) / BN), 1, 1);
|
||||
group_dims = MTL::Size(256, 1, 1);
|
||||
|
||||
// Set the 2nd kernel
|
||||
const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
|
||||
op_name + type_to_name(intermediate);
|
||||
kernel = get_reduce_kernel(
|
||||
d,
|
||||
second_kernel,
|
||||
"col_reduce_looped",
|
||||
op_name,
|
||||
intermediate,
|
||||
out,
|
||||
1,
|
||||
32,
|
||||
32);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
compute_encoder.set_input_array(intermediate, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
second_args.encode(compute_encoder);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
@@ -532,9 +622,9 @@ void strided_reduce_looped(
|
||||
|
||||
// Figure out the grid dims
|
||||
auto out_grid_size = output_grid_for_col_reduce(out, args);
|
||||
int BN = (args.reduction_stride <= 1024) ? 32 : 128;
|
||||
int BN = 32;
|
||||
int BM = 1024 / BN;
|
||||
int threadgroup_size = 4 * 32;
|
||||
int threadgroup_size = 8 * 32;
|
||||
MTL::Size grid_dims(
|
||||
threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
|
||||
out_grid_size.width,
|
||||
@@ -558,6 +648,87 @@ void strided_reduce_looped(
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void strided_reduce_2pass(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
ColReduceArgs& args,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Prepare the temporary accumulator
|
||||
std::vector<int> intermediate_shape;
|
||||
intermediate_shape.reserve(out.ndim() + 1);
|
||||
intermediate_shape.push_back(32);
|
||||
intermediate_shape.insert(
|
||||
intermediate_shape.end(), out.shape().begin(), out.shape().end());
|
||||
array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
d.add_temporary(intermediate, s.index);
|
||||
|
||||
// Prepare the arguments for the kernel
|
||||
args.reduce_shape.push_back(args.reduction_size);
|
||||
args.reduce_strides.push_back(args.reduction_stride);
|
||||
args.reduce_ndim++;
|
||||
|
||||
// Figure out the grid dims
|
||||
size_t out_size = out.size() / args.reduction_stride;
|
||||
auto out_grid_size = output_grid_for_col_reduce(out, args);
|
||||
int outer_blocks = 32;
|
||||
int BN = 32;
|
||||
int BM = 1024 / BN;
|
||||
int threadgroup_size = 8 * 32;
|
||||
MTL::Size grid_dims(
|
||||
threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
|
||||
out_grid_size.width * outer_blocks,
|
||||
out_grid_size.height);
|
||||
MTL::Size group_dims(threadgroup_size, 1, 1);
|
||||
|
||||
// Set the kernel
|
||||
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
|
||||
std::ostringstream kname;
|
||||
const std::string func_name = "col_reduce_2pass";
|
||||
kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
|
||||
<< op_name << type_to_name(in);
|
||||
auto kernel =
|
||||
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Launch
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(intermediate, 1);
|
||||
args.encode(compute_encoder);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Make the 2nd pass arguments and grid_dims
|
||||
ColReduceArgs second_args(intermediate);
|
||||
second_args.reduce_shape.push_back(outer_blocks);
|
||||
second_args.reduce_strides.push_back(out.size());
|
||||
second_args.reduce_ndim++;
|
||||
grid_dims = MTL::Size(threadgroup_size * ((out.size() + BN - 1) / BN), 1, 1);
|
||||
|
||||
// Set the 2nd kernel
|
||||
const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
|
||||
op_name + type_to_name(intermediate);
|
||||
kernel = get_reduce_kernel(
|
||||
d,
|
||||
second_kernel,
|
||||
"col_reduce_looped",
|
||||
op_name,
|
||||
intermediate,
|
||||
out,
|
||||
1,
|
||||
32,
|
||||
32);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
compute_encoder.set_input_array(intermediate, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
second_args.encode(compute_encoder);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void strided_reduce_general_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
@@ -570,11 +741,23 @@ void strided_reduce_general_dispatch(
|
||||
// Prepare the arguments for the kernel
|
||||
ColReduceArgs args(in, plan, axes);
|
||||
|
||||
if (args.reduction_stride < 32 ||
|
||||
args.reduction_size * args.non_col_reductions < 32) {
|
||||
// Small column
|
||||
if (args.reduction_size * args.non_col_reductions < 32) {
|
||||
return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s);
|
||||
}
|
||||
|
||||
// Long column but small row
|
||||
if (args.reduction_stride < 32 &&
|
||||
args.reduction_size * args.non_col_reductions >= 1024) {
|
||||
return strided_reduce_longcolumn(
|
||||
in, out, op_name, args, compute_encoder, d, s);
|
||||
}
|
||||
|
||||
if (args.reduction_size * args.non_col_reductions > 256 &&
|
||||
out.size() / 32 < 1024) {
|
||||
return strided_reduce_2pass(in, out, op_name, args, compute_encoder, d, s);
|
||||
}
|
||||
|
||||
return strided_reduce_looped(in, out, op_name, args, compute_encoder, d, s);
|
||||
}
|
||||
|
||||
@@ -620,7 +803,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Reduce
|
||||
if (in.size() > 0) {
|
||||
std::vector<array> copies;
|
||||
ReductionPlan plan = get_reduction_plan(in, axes_);
|
||||
|
||||
// If it is a general reduce then copy the input to a contiguous array and
|
||||
@@ -632,7 +814,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (plan.type == GeneralReduce) {
|
||||
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, in_copy, CopyType::General, s);
|
||||
copies.push_back(in_copy);
|
||||
d.add_temporary(in_copy, s.index);
|
||||
in = in_copy;
|
||||
plan = get_reduction_plan(in, axes_);
|
||||
}
|
||||
@@ -640,7 +822,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Reducing over everything and the data is all there no broadcasting or
|
||||
// slicing etc.
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
all_reduce_dispatch(in, out, op_name, compute_encoder, d, s, copies);
|
||||
all_reduce_dispatch(in, out, op_name, compute_encoder, d, s);
|
||||
}
|
||||
|
||||
// At least the last dimension is row contiguous and we are reducing over
|
||||
@@ -659,8 +841,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
strided_reduce_general_dispatch(
|
||||
in, out, op_name, plan, axes_, compute_encoder, d, s);
|
||||
}
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
// Nothing to reduce just initialize the output
|
||||
|
@@ -16,8 +16,7 @@ void all_reduce_dispatch(
    const std::string& op_name,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s,
    std::vector<array>& copies);
    const Stream& s);

void row_reduce_general_dispatch(
    const array& in,
@@ -162,7 +162,8 @@ void sdpa_vector(
  int gqa_factor = q.shape(1) / k.shape(1);
  int N = k.shape(2);
  int B = q.shape(0) * q.shape(1);
  size_t stride = k.strides()[1];
  size_t k_stride = k.strides()[1];
  size_t v_stride = v.strides()[1];
  MTL::Size group_dims(1024, 1, 1);
  MTL::Size grid_dims(1, B, 1);

@@ -178,8 +179,9 @@ void sdpa_vector(
  compute_encoder.set_output_array(out, 3);
  compute_encoder->setBytes(&gqa_factor, sizeof(int), 4);
  compute_encoder->setBytes(&N, sizeof(int), 5);
  compute_encoder->setBytes(&stride, sizeof(size_t), 6);
  compute_encoder->setBytes(&scale, sizeof(float), 7);
  compute_encoder->setBytes(&k_stride, sizeof(size_t), 6);
  compute_encoder->setBytes(&v_stride, sizeof(size_t), 7);
  compute_encoder->setBytes(&scale, sizeof(float), 8);

  // Launch
  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
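The host change above passes separate key and value strides to the kernel. At the Python level this path backs ``mx.fast.scaled_dot_product_attention``; a small usage sketch with grouped query heads (illustrative shapes, not from the diff):

    import mlx.core as mx

    B, L, D = 1, 128, 64
    n_q_heads, n_kv_heads = 32, 8  # gqa_factor = 32 // 8 = 4

    q = mx.random.normal(shape=(B, n_q_heads, 1, D))  # single query token
    k = mx.random.normal(shape=(B, n_kv_heads, L, D))
    v = mx.random.normal(shape=(B, n_kv_heads, L, D))

    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5)
    print(out.shape)  # (1, 32, 1, 64)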
@@ -69,6 +69,14 @@ array rms_norm(
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  if (weight.size() != x.shape(-1)) {
    std::ostringstream msg;
    msg << "[rms_norm] weight must have the same size as the last dimension of"
           " x but has "
        << weight.size() << " elements.";
    throw std::invalid_argument(msg.str());
  }

  auto out_type = result_type(x, weight);
  if (!issubdtype(out_type, floating)) {
    std::ostringstream msg;
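The added check requires the weight to have as many elements as the last dimension of ``x``. A small sketch of the corresponding Python op, ``mx.fast.rms_norm``, with illustrative shapes:

    import mlx.core as mx

    x = mx.random.normal(shape=(2, 8, 512))
    weight = mx.ones((512,))  # must have x.shape[-1] elements

    y = mx.fast.rms_norm(x, weight, eps=1e-5)
    print(y.shape)  # (2, 8, 512)

    # A weight of the wrong size, e.g. mx.ones((256,)), is rejected up front by
    # the check added above instead of failing later in the kernel.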
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
@@ -1683,48 +1682,58 @@ std::pair<std::vector<array>, std::vector<int>> Gather::vmap(
  auto gather_axes = axes_;
  auto slice_sizes = slice_sizes_;
  auto src_vmapped = axes[0] >= 0;
  auto indices_vmapped =
      std::any_of(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; });
  auto out_ax =
      *std::find_if(axes.begin(), axes.end(), [](int a) { return a >= 0; });
  auto ind_vmap_ax_ptr =
      std::find_if(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; });
  int out_ax = -1;
  bool indices_vmapped = (ind_vmap_ax_ptr != axes.end());
  if (indices_vmapped) {
    out_ax = *ind_vmap_ax_ptr;
  } else if (src_vmapped) {
    out_ax = axes[0];
  }

  // Reorder all the index arrays so the vmap axis is in the same spot.
  for (int i = 1; i < axes.size(); ++i) {
    if (out_ax != axes[i] && axes[i] >= 0) {
      indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream());
  if (indices_vmapped) {
    for (int i = 1; i < axes.size(); ++i) {
      if (out_ax != axes[i] && axes[i] >= 0) {
        indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream());
      } else if (axes[i] < 0) {
        indices[i - 1] = expand_dims(indices[i - 1], out_ax, stream());
      }
    }
  }

  int idx_dims = indices.empty() ? 0 : indices[0].ndim();

  if (src_vmapped) {
    int max_dims = 0;
    for (auto& idx : indices) {
      max_dims = std::max(static_cast<int>(idx.ndim()), max_dims);
    }
    auto new_ax_loc =
        std::find_if(gather_axes.begin(), gather_axes.end(), [&out_ax](int a) {
          return a >= out_ax;
        });
    for (; new_ax_loc < gather_axes.end(); new_ax_loc++) {
      (*new_ax_loc)++;
    for (auto& ax : gather_axes) {
      if (ax >= axes[0]) {
        ax++;
      }
    }
    if (indices_vmapped) {
      // Make a new index array for the vmapped dimension
      auto vmap_inds = arange(0, src.shape(axes[0]), stream());
      // Reshape it so it broadcasts with other index arrays
      {
        auto shape = std::vector<int>(idx_dims, 1);
        shape[out_ax] = vmap_inds.size();
        vmap_inds = reshape(vmap_inds, std::move(shape), stream());
      }
      // Update gather axes and slice sizes accordingly
      auto shape = std::vector<int>(max_dims - out_ax, 1);
      auto vmap_inds = arange(0, src.shape(out_ax), stream());
      shape[0] = vmap_inds.shape(0);
      vmap_inds = reshape(vmap_inds, shape, stream());
      slice_sizes.insert(slice_sizes.begin() + out_ax, 1);
      auto new_ax_idx = new_ax_loc - gather_axes.begin();
      gather_axes.insert(new_ax_loc, out_ax);
      indices.insert(indices.begin() + new_ax_idx, vmap_inds);
      slice_sizes.insert(slice_sizes.begin() + axes[0], 1);
      gather_axes.push_back(axes[0]);
      indices.push_back(vmap_inds);
    } else {
      slice_sizes.insert(slice_sizes.begin() + axes[0], src.shape(axes[0]));
      out_ax = max_dims + axes[0];
      slice_sizes.insert(slice_sizes.begin() + out_ax, src.shape(out_ax));
      out_ax += idx_dims;
    }
  }
  return {{gather(src, indices, gather_axes, slice_sizes, stream())}, {out_ax}};
  auto out = gather(src, indices, gather_axes, slice_sizes, stream());
  if (src_vmapped && indices_vmapped) {
    out = squeeze(out, idx_dims + axes[0], stream());
  }
  return {{out}, {out_ax}};
}

std::vector<array> Gather::vjp(

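A small Python sketch of what the reworked vmap rule handles; the expected outputs match the vmap tests added further down:

import mlx.core as mx

def take_row(a, idx):
    return a[idx]

a = mx.array([[1, 2], [3, 4]])

# Only the source array carries a vmapped axis.
print(mx.vmap(take_row, (0, None))(a, mx.array(0)))  # array([1, 3])

# Both the source and the indices are vmapped along axis 0.
print(mx.vmap(take_row, (0, 0))(a, mx.array([0, 1])))  # array([1, 4])
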
@@ -278,7 +278,9 @@ void init_fast(nb::module_& parent_module) {
        output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
        output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
        grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
          This will be passed to ``MTLComputeCommandEncoder::dispatchThreads``.
        threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
          This will be passed to ``MTLComputeCommandEncoder::dispatchThreads``.
        template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
          These will be added as template arguments to the kernel definition. Default: ``None``.
        init_value (float, optional): Optional value to use to initialize all of the output arrays.
@@ -300,6 +302,8 @@ void init_fast(nb::module_& parent_module) {
      R"pbdoc(
        A jit-compiled custom Metal kernel defined from a source string.

        Full documentation: :ref:`custom_metal_kernels`.

        Args:
          name (str): Name for the kernel.
          input_names (List[str]): The parameter names of the inputs in the

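As a rough illustration of how the documented arguments fit together, a minimal sketch; the kernel name, body, and sizes below are made up for this example and are not part of the diff:

import mlx.core as mx

# Hypothetical element-wise kernel that doubles its input.
source = """
    uint elem = thread_position_in_grid.x;
    out[elem] = T(2) * inp[elem];
"""

kernel = mx.fast.metal_kernel(
    name="double",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

a = mx.arange(8, dtype=mx.float32)
outputs = kernel(
    inputs=[a],
    template=[("T", mx.float32)],
    grid=(a.size, 1, 1),    # passed to dispatchThreads
    threadgroup=(8, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)
print(outputs[0])  # array([0, 2, 4, ..., 14], dtype=float32)
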
@@ -803,9 +803,10 @@ auto mlx_slice_update(
  // Pre process tuple
  auto upd = to_array(v, src.dtype());

  // Remove leading singletons dimensions from the update
  // Remove extra leading singletons dimensions from the update
  int s = 0;
  for (; s < upd.ndim() && upd.shape(s) == 1; s++) {
  for (; s < upd.ndim() && upd.shape(s) == 1 && (upd.ndim() - s) > src.ndim();
       s++) {
  };
  auto up_shape = std::vector<int>(upd.shape().begin() + s, upd.shape().end());
  up_shape = up_shape.empty() ? std::vector{1} : up_shape;

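A minimal Python sketch of the behavior this loop now guards, assuming numpy-style broadcasting of the update (the shapes are illustrative): only singleton dimensions beyond the destination's rank are stripped, so the update below still broadcasts into the slice.

import mlx.core as mx

a = mx.zeros((4, 4))
# The update has one extra leading singleton dimension relative to `a`;
# it is stripped and the remaining (2, 4) update fills the slice.
a[1:3, :] = mx.ones((1, 2, 4))
print(a[1:3, :])  # all ones
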
@@ -1771,6 +1771,19 @@ class TestArray(mlx_tests.MLXTestCase):
        peak_2 = mx.metal.get_peak_memory()
        self.assertEqual(peak_1, peak_2)

        def fun():
            a = mx.array([1.0, 2.0, 3.0, 4.0])
            b, _ = mx.divmod(a, a)
            return mx.log(b)

        fun()
        mx.synchronize()
        peak_1 = mx.metal.get_peak_memory()
        fun()
        mx.synchronize()
        peak_2 = mx.metal.get_peak_memory()
        self.assertEqual(peak_1, peak_2)

    def test_add_numpy(self):
        x = mx.array(1)
        y = np.array(2, dtype=np.int32)

@@ -308,6 +308,11 @@ class TestFast(mlx_tests.MLXTestCase):
            rx_fast = mx.fast.rms_norm(x, weight, eps)
            self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)

        # Wrong size w raises
        with self.assertRaises(ValueError):
            x = mx.random.uniform(shape=(1, 5))
            mx.fast.rms_norm(x, mx.ones((4,)), 1e-5)

    def test_rms_norm_grad(self):
        D = 32
        eps = 1e-5

@@ -167,6 +167,15 @@ class TestFastSDPA(mlx_tests.MLXTestCase):

        self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))

        q = mx.random.normal(shape=(1, 32, 1, Dk))
        k = mx.random.normal(shape=(1, 32, 32, Dk))
        v = mx.random.normal(shape=(1, 32, 128, Dk))

        atol = 1e-6
        y = mlx_primitives_sdpa(q, k, v[:, :, :32], scale)
        y_hat = mx.fast.scaled_dot_product_attention(q, k, v[:, :, :32], scale=scale)
        self.assertTrue(mx.allclose(y, y_hat, atol=atol))


if __name__ == "__main__":
    unittest.main(failfast=True)

@@ -163,6 +163,31 @@ class TestQuantized(mlx_tests.MLXTestCase):
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_qvm_splitk(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        tests = product(
            [128, 64, 32],  # group_size
            [2, 4, 8],  # bits
            [128],  # M
            [16384],  # N
            [1, 3],  # B
        )
        for group_size, bits, M, N, B in tests:
            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
                x_shape = (1, N) if B == 0 else (B, 1, N)
                w_shape = (N, M) if B == 0 else (B, N, M)
                x = mx.random.normal(shape=x_shape, key=k1)
                w = mx.random.normal(shape=w_shape, key=k2)
                w_q, scales, biases = mx.quantize(w, group_size, bits)
                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                y_q = mx.quantized_matmul(
                    x, w_q, scales, biases, False, group_size, bits
                )
                y_hat = x @ w_hat
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 2e-3)

    def test_throw(self):
        x = mx.random.normal(shape=(10, 512))
        w = mx.random.normal(shape=(32, 512))

@@ -370,6 +370,51 @@ class TestVmap(mlx_tests.MLXTestCase):
                mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
            )

    def test_vmap_gather(self):
        def gather(a, idx):
            return a[idx]

        a = mx.array([[1, 2], [3, 4]])
        idx = mx.array(0)
        out = mx.vmap(gather, (0, None))(a, idx)
        self.assertTrue(mx.array_equal(out, mx.array([1, 3])))

        out = mx.vmap(gather, (1, None))(a, idx)
        self.assertTrue(mx.array_equal(out, mx.array([1, 2])))

        idx = mx.array([0, 1])
        out = mx.vmap(gather, (0, 0))(a, idx)
        self.assertTrue(mx.array_equal(out, mx.array([1, 4])))

        a = mx.ones((2, 3, 4))
        idx = mx.zeros(4, mx.int32)
        out = mx.vmap(gather, (2, 0))(a, idx)
        self.assertEqual(out.shape, (4, 3))

        f = mx.vmap(gather, (0, None))
        f = mx.vmap(gather, (0, 0))
        out = f(mx.ones((2, 3, 4)), mx.zeros(2, dtype=mx.int32))
        self.assertEqual(out.shape, (2, 4))

        def gather(a, idxa, idxb):
            return a[idxa, idxb]

        a = mx.ones((2, 3, 4))
        idxa = mx.zeros((2, 3), mx.int32)
        idxb = mx.zeros(3, mx.int32)
        out = mx.vmap(gather, (0, 0, None))(a, idxa, idxb)
        self.assertEqual(out.shape, (2, 3))

        idxa = mx.zeros((3, 1, 2), mx.int32)
        idxb = mx.zeros((2, 3, 1, 2), mx.int32)
        out = mx.vmap(gather, (0, None, 0))(a, idxa, idxb)
        self.assertEqual(out.shape, (2, 3, 1, 2))

        idxa = mx.zeros((3, 1, 2), mx.int32)
        idxb = mx.zeros((3, 1, 2, 2), mx.int32)
        out = mx.vmap(gather, (0, None, 3))(a, idxa, idxb)
        self.assertEqual(out.shape, (2, 3, 1, 2))

    def test_vmap_scatter(self):
        def scatter(a):
            a[mx.array(0)] = mx.array(0.0)