Compare commits


14 Commits

Author SHA1 Message Date
Alex Barron
726dbd9267 v0.20.0 (#1565) 2024-11-05 12:37:57 -08:00
Awni Hannun
54f05e7195 Fix gather vmap (#1563)
* fix gather

* fix
2024-11-05 11:29:20 -08:00
Alex Barron
26be608470 Add split_k qvm for long context (#1564)
* Add splitk qvm

* configurable splitk

* tuning

* remove extra instantiation

* remove refactor

* separate test

* cpu tolerance
2024-11-05 11:25:19 -08:00
Angelos Katharopoulos
248431eb3c Reductions update (#1351) 2024-11-04 22:25:16 -08:00
Awni Hannun
76f275b4df error in rms for wrong size (#1562) 2024-11-04 13:24:02 -08:00
Awni Hannun
f1951d6cce Use fewer barriers (#1561)
* use fewer barriers

* comment
2024-11-04 10:26:49 -08:00
Angelos Katharopoulos
62f297b51d Sdpa fix (#1558) 2024-11-02 21:25:46 -07:00
Awni Hannun
09bc32f62f No extra reshape (#1557)
* no extra reshape

* lint
2024-11-02 19:07:20 -07:00
Chris Offner
46d8b16ab4 Fix vmap example in docs (#1556) 2024-11-02 17:44:14 -07:00
Chris Offner
42533931fa Fix typo "it's" -> "its" (#1555) 2024-11-02 06:06:34 -07:00
Awni Hannun
9bd3a7102f add python 3.13 to circle (#1553) 2024-11-01 20:55:35 -07:00
Alex Barron
9e516b71ea Add dispatchThreads to custom kernel doc (#1551)
* add dispatchThreads info

* update

* add link
2024-11-01 13:07:48 -07:00
Awni Hannun
eac961ddb1 patch (#1550) 2024-10-31 16:10:14 -07:00
Awni Hannun
57c6aa7188 fix multi output leak (#1548) 2024-10-31 09:32:01 -07:00
33 changed files with 855 additions and 265 deletions

View File

@@ -349,7 +349,7 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
build_env: ["PYPI_RELEASE=1"]
- build_documentation:
@@ -386,7 +386,7 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
weekly_build:
when:
@@ -397,7 +397,7 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
build_env: ["DEV_RELEASE=1"]
linux_test_release:
@@ -409,5 +409,5 @@ workflows:
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]

View File

@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.19.2)
set(MLX_VERSION 0.20.0)
endif()
# --------------------- Processor tests -------------------------

View File

@@ -144,6 +144,13 @@ def reduction(op, axis, x):
mx.eval(ys)
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
mx.eval(z)
def softmax(axis, x):
ys = []
for i in range(100):
@@ -505,5 +512,8 @@ if __name__ == "__main__":
elif args.benchmark == "selu":
print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else:
raise ValueError("Unknown benchmark")

View File

@@ -1,3 +1,5 @@
.. _custom_metal_kernels:
Custom Metal Kernels
====================
@@ -76,6 +78,10 @@ Putting this all together, the generated function signature for ``myexp`` is as
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
This means we will launch ``mx.prod(grid)`` threads, subdivided into threadgroups of size ``threadgroup``.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
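For example, ``grid=(4096, 1, 1)`` with ``threadgroup=(256, 1, 1)`` launches 4096 threads in 16 threadgroups of 256. A minimal sketch of defining and calling the ``myexp`` kernel referenced above with an explicit grid and threadgroup (the input name ``inp`` and output name ``out`` follow the doc's convention; treat the exact shapes as illustrative)::

    import mlx.core as mx

    source = """
        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);
    """

    kernel = mx.fast.metal_kernel(
        name="myexp",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )

    a = mx.random.normal(shape=(4096,))
    out = kernel(
        inputs=[a],
        template=[("T", mx.float32)],
        grid=(a.size, 1, 1),      # mx.prod(grid) == 4096 threads in total
        threadgroup=(256, 1, 1),  # subdivided into threadgroups of 256
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
        verbose=True,             # print the generated Metal source
    )[0]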
Using Shape/Strides

View File

@@ -161,7 +161,7 @@ A naive way to add the elements from two sets of vectors is with a loop:
ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys):
return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
Instead you can use :func:`vmap` to automatically vectorize the addition:
@@ -169,7 +169,7 @@ Instead you can use :func:`vmap` to automatically vectorize the addition:
# Vectorize over the first dimension of x and the
# second dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs.
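Putting the corrected example together (a minimal sketch using the shapes from the snippet above)::

    import mlx.core as mx

    xs = mx.random.uniform(shape=(4096, 100))
    ys = mx.random.uniform(shape=(100, 4096))

    # Vectorize over axis 0 of xs and axis 1 of ys; each call of the
    # inner function sees x of shape (100,) and y of shape (100,).
    vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
    out = vmap_add(xs, ys)
    print(out.shape)  # (4096, 100)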

View File

@@ -77,7 +77,7 @@ from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.
Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which outputs
general, MLX has limited support for operations for which output
*shapes* are dependent on input *data*. Other examples of these types of
operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`.
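For instance, the output shape of :func:`numpy.nonzero` is only known once the values themselves are (an illustrative NumPy snippet, since MLX does not provide these operations)::

    import numpy as np

    a = np.array([0, 1, 0, 2])
    b = np.array([0, 1, 0, 0])

    # Same input shape, but the output shape depends on the data:
    np.nonzero(a)  # (array([1, 3]),)  -> two indices
    np.nonzero(b)  # (array([1]),)     -> one index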

View File

@@ -109,7 +109,7 @@ Here is a concrete example:
An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an
:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array.
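A short sketch of the implicit evaluation points described above (the values are illustrative)::

    import mlx.core as mx
    import numpy as np

    a = mx.array([1.0, 2.0]) * 2  # lazy: builds the graph, computes nothing

    print(a)           # printing evaluates the graph
    b = np.array(a)    # converting to a numpy.ndarray evaluates
    m = memoryview(a)  # accessing the memory directly evaluates
    mx.save("a", a)    # saving (here to a.npy) also evaluates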

View File

@@ -271,6 +271,9 @@ array::ArrayDesc::~ArrayDesc() {
for (array& a : ad.inputs) {
if (a.array_desc_) {
input_map.insert({a.id(), a});
for (auto& s : a.siblings()) {
input_map.insert({s.id(), s});
}
}
}
ad.inputs.clear();

View File

@@ -136,13 +136,8 @@ void CommandEncoder::set_input_array(
int64_t offset /* = 0 */) {
all_inputs_.insert(a.buffer().ptr());
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
// Insert a barrier
enc_->memoryBarrier(&r_buf, 1);
// Remove the output
outputs_.erase(it);
}
needs_barrier_ =
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
@@ -161,19 +156,32 @@ void CommandEncoder::set_output_array(
if (concurrent_) {
concurrent_outputs_.insert(buf);
} else {
outputs_.insert(buf);
next_outputs_.insert(buf);
}
}
void CommandEncoder::maybeInsertBarrier() {
if (needs_barrier_) {
enc_->memoryBarrier(MTL::BarrierScopeBuffers);
needs_barrier_ = false;
prev_outputs_ = std::move(next_outputs_);
} else {
prev_outputs_.insert(next_outputs_.begin(), next_outputs_.end());
}
next_outputs_.clear();
}
void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
enc_->dispatchThreadgroups(grid_dims, group_dims);
}
void CommandEncoder::dispatchThreads(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
enc_->dispatchThreads(grid_dims, group_dims);
}

View File

@@ -49,7 +49,7 @@ struct CommandEncoder {
}
~ConcurrentContext() {
enc.concurrent_ = false;
enc.outputs_.insert(
enc.prev_outputs_.insert(
enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
enc.concurrent_outputs_.clear();
}
@@ -66,6 +66,7 @@ struct CommandEncoder {
void set_output_array(array& a, int idx, int64_t offset = 0);
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
void maybeInsertBarrier();
ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
@@ -84,8 +85,10 @@ struct CommandEncoder {
private:
MTL::ComputeCommandEncoder* enc_;
bool needs_barrier_{false};
bool concurrent_{false};
std::unordered_set<MTL::Resource*> outputs_;
std::unordered_set<MTL::Resource*> prev_outputs_;
std::unordered_set<MTL::Resource*> next_outputs_;
std::unordered_set<MTL::Resource*> concurrent_outputs_;
std::unordered_set<const void*> all_inputs_;
std::unordered_set<const void*> all_outputs_;

View File

@@ -319,16 +319,18 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const array& out) {
auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source;
std::string op_type = op_name(out);
op_type[0] = std::toupper(op_name(out)[0]);
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
auto out_type = get_type_string(out.dtype());
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
kernel_source << get_template_definition(
kernel_name, "init_reduce", out_type, op);
kernel_name, func_name, out_type, op);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);

View File

@@ -79,6 +79,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const array& out);
MTL::ComputePipelineState* get_reduce_kernel(

View File

@@ -650,8 +650,8 @@ METAL_FUNC void qvm_impl(
const device T* biases,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
const int in_vec_size,
const int out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1298,6 +1298,61 @@ template <typename T, const int group_size, const int bits, bool batched>
simd_lid);
}
template <typename T, const int group_size, const int bits, int split_k = 32>
[[kernel]] void qvm_split_k(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int& final_block_size [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
// When (in_vec_size % split_k != 0) the final block needs to be smaller
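// (illustrative, assumed values: an unsplit in_vec_size of 100 with
// split_k = 8 gives splits of ceil(100 / 8) = 13 elements each, and the
// final block covers only the remaining 100 - 7 * 13 = 9)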
int in_vec_size_adj =
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
qvm_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size_adj,
out_vec_size,
tid,
simd_gid,
simd_lid);
}
template <
typename T,
const int group_size,

View File

@@ -51,6 +51,15 @@
D, \
batched)
#define instantiate_quantized_split_k(name, type, group_size, bits, split_k) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_spk_" #split_k, \
name, \
type, \
group_size, \
bits, \
split_k)
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 0)
@@ -84,11 +93,16 @@
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
#define instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits)
instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_all_splitk(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \

View File

@@ -113,9 +113,12 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, dim) \
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, dim)
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
@@ -123,9 +126,14 @@ instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
col_reduce_looped, \
itype, otype, op, dim, bm, bn)
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, dim, bm, bn)
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)
#define instantiate_col_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_small(name, itype, otype, op, 0) \

View File

@@ -1,11 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
template <
typename T,
typename U,
typename Op,
int NDIMS,
int N_READS = REDUCE_N_READS>
template <typename T, typename U, typename Op, int NDIMS>
[[kernel]] void col_reduce_small(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
@@ -20,170 +15,128 @@ template <
const constant size_t& non_col_reductions [[buffer(10)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[thread_position_in_grid]],
uint3 tsize [[threads_per_grid]]) {
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
constexpr int n_reads = 4;
Op op;
looped_elem_to_loc<NDIMS> loop;
const device T* row;
// Case 1: Small row small column
if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
U totals[31];
for (int i = 0; i < 31; i++) {
totals[i] = Op::init;
U totals[n_reads];
for (int i = 0; i < n_reads; i++) {
totals[i] = Op::init;
}
size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;
if (column >= reduction_stride) {
return;
}
bool safe = column + n_reads <= reduction_stride;
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + column;
size_t total_rows = non_col_reductions * reduction_size;
loop.next(lid.y, reduce_shape, reduce_strides);
for (size_t r = lid.y; r < total_rows; r += lsize.y) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (int i = 0; i < n_reads; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]);
}
} else {
U vals[n_reads];
for (int i = 0; i < n_reads; i++) {
vals[i] =
(column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
}
for (int i = 0; i < n_reads; i++) {
totals[i] = op(vals[i], totals[i]);
}
}
loop.next(lsize.y, reduce_shape, reduce_strides);
}
short stride = reduction_stride;
short size = reduction_size;
short blocks = stride / N_READS;
short extra = stride - blocks * N_READS;
size_t out_idx = tid.x + tsize.y * size_t(tid.y);
in += elem_to_loc(out_idx, shape, strides, ndim);
for (uint r = 0; r < non_col_reductions; r++) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
for (short i = 0; i < size; i++) {
for (short j = 0; j < blocks; j++) {
for (short k = 0; k < N_READS; k++) {
totals[j * N_READS + k] =
op(totals[j * N_READS + k],
static_cast<U>(row[i * stride + j * N_READS + k]));
}
}
for (short k = 0; k < extra; k++) {
totals[blocks * N_READS + k] =
op(totals[blocks * N_READS + k],
static_cast<U>(row[i * stride + blocks * N_READS + k]));
if (lsize.y > 1) {
// lsize.y should be <= 8
threadgroup U shared_vals[32 * 8 * n_reads];
for (int i = 0; i < n_reads; i++) {
shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (lid.y == 0) {
for (int i = 0; i < n_reads; i++) {
totals[i] = shared_vals[lid.x * n_reads + i];
}
for (uint j = 1; j < lsize.y; j++) {
for (int i = 0; i < n_reads; i++) {
totals[i] =
op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],
totals[i]);
}
}
loop.next(reduce_shape, reduce_strides);
}
out += out_idx * reduction_stride;
for (short j = 0; j < stride; j++) {
out[j] = totals[j];
}
}
// Case 2: Long row small column
else if (reduction_size * non_col_reductions < 32) {
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = Op::init;
}
short size = reduction_size;
size_t offset = size_t(tid.x) * N_READS;
bool safe = offset + N_READS <= reduction_stride;
short extra = reduction_stride - offset;
size_t out_idx = tid.y + tsize.z * size_t(tid.z);
in += elem_to_loc(out_idx, shape, strides, ndim) + offset;
for (uint r = 0; r < non_col_reductions; r++) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (short i = 0; i < size; i++) {
for (short j = 0; j < N_READS; j++) {
totals[j] =
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
}
}
} else {
for (short i = 0; i < size; i++) {
for (short j = 0; j < extra; j++) {
totals[j] =
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
}
}
}
loop.next(reduce_shape, reduce_strides);
}
out += out_idx * reduction_stride + offset;
if (lid.y == 0) {
out += out_idx * reduction_stride + column;
if (safe) {
for (short i = 0; i < N_READS; i++) {
for (int i = 0; i < n_reads; i++) {
out[i] = totals[i];
}
} else {
for (short i = 0; i < extra; i++) {
for (int i = 0; column + i < reduction_stride; i++) {
out[i] = totals[i];
}
}
}
}
// Case 3: Long row medium column
else {
threadgroup U shared_vals[1024];
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = Op::init;
}
short stride = reduction_stride;
short lid = simd_group_id * simd_size + simd_lane_id;
short2 tile((stride + N_READS - 1) / N_READS, 32);
short2 offset((lid % tile.x) * N_READS, lid / tile.x);
short sm_stride = tile.x * N_READS;
bool safe = offset.x + N_READS <= stride;
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x;
// Read cooperatively and contiguously and aggregate the partial results.
size_t total = non_col_reductions * reduction_size;
loop.next(offset.y, reduce_shape, reduce_strides);
for (size_t r = offset.y; r < total; r += simd_size) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (int i = 0; i < N_READS; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]);
}
} else {
U vals[N_READS];
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset.x + i < stride) ? static_cast<U>(row[i]) : op.init;
}
for (int i = 0; i < N_READS; i++) {
totals[i] = op(vals[i], totals[i]);
}
}
loop.next(simd_size, reduce_shape, reduce_strides);
}
// Each thread holds N_READS partial results but the simdgroups are not
// aligned to do the reduction across the simdgroup so we write our results
// in the shared memory and read them back according to the simdgroup.
for (int i = 0; i < N_READS; i++) {
shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_READS; i++) {
totals[i] = op.simd_reduce(
shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);
}
// Write the output.
if (simd_lane_id == 0) {
short column = simd_group_id * N_READS;
out += out_idx * reduction_stride + column;
if (column + N_READS <= stride) {
for (int i = 0; i < N_READS; i++) {
out[i] = totals[i];
}
} else {
for (int i = 0; column + i < stride; i++) {
out[i] = totals[i];
}
}
template <typename T, typename U, typename Op, int NDIMS>
[[kernel]] void col_reduce_longcolumn(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
const constant size_t& out_size [[buffer(11)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
Op op;
looped_elem_to_loc<NDIMS> loop;
const device T* row;
size_t out_idx = gid.x + gsize.x * size_t(gid.y);
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + lid.x;
U total = Op::init;
size_t total_rows = non_col_reductions * reduction_size;
loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
for (size_t r = gid.z * lsize.y + lid.y; r < total_rows;
r += lsize.y * gsize.z) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
total = op(static_cast<U>(*row), total);
loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
}
threadgroup U shared_vals[32 * 32];
shared_vals[lid.y * lsize.x + lid.x] = total;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (lid.y == 0) {
for (uint i = 1; i < lsize.y; i++) {
total = op(total, shared_vals[i * lsize.x + lid.x]);
}
out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;
}
}
@@ -216,7 +169,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
constexpr int n_simdgroups = 4;
constexpr int n_simdgroups = 8;
constexpr short tgp_size = n_simdgroups * simd_size;
constexpr short n_reads = (BM * BN) / tgp_size;
constexpr short n_read_blocks = BN / n_reads;
@@ -329,3 +282,103 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
}
}
}
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
[[kernel]] void col_reduce_2pass(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
const constant size_t& out_size [[buffer(11)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
constexpr int n_simdgroups = 8;
constexpr short tgp_size = n_simdgroups * simd_size;
constexpr short n_reads = (BM * BN) / tgp_size;
constexpr short n_read_blocks = BN / n_reads;
constexpr int n_outputs = BN / n_simdgroups;
constexpr short outer_blocks = 32;
static_assert(BM == 32, "BM should be equal to 32");
threadgroup U shared_vals[BN * BM];
U totals[n_reads];
looped_elem_to_loc<NDIMS> loop;
const device T* row;
for (int i = 0; i < n_reads; i++) {
totals[i] = Op::init;
}
short lid = simd_group_id * simd_size + simd_lane_id;
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
size_t column = BN * gid.x + offset.x;
bool safe = column + n_reads <= reduction_stride;
size_t full_idx = gid.y + gsize.y * size_t(gid.z);
size_t block_idx = full_idx / out_size;
size_t out_idx = full_idx % out_size;
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + column;
size_t total = non_col_reductions * reduction_size;
loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
for (size_t r = offset.y + block_idx * BM; r < total;
r += outer_blocks * BM) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (int i = 0; i < n_reads; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]);
}
} else {
U vals[n_reads];
for (int i = 0; i < n_reads; i++) {
vals[i] =
(column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
}
for (int i = 0; i < n_reads; i++) {
totals[i] = op(vals[i], totals[i]);
}
}
loop.next(outer_blocks * BM, reduce_shape, reduce_strides);
}
// We can use a simd reduction to accumulate across BM so each thread writes
// the partial output to SM and then each simdgroup does BN / n_simdgroups
// accumulations.
for (int i = 0; i < n_reads; i++) {
shared_vals[offset.y * BN + offset.x + i] = totals[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
for (int i = 0; i < n_outputs; i++) {
totals[i] =
op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
}
// Write the output.
if (simd_lane_id == 0) {
size_t out_column = BN * gid.x + out_offset.x;
out += full_idx * reduction_stride + out_column;
if (out_column + n_outputs <= reduction_stride) {
for (int i = 0; i < n_outputs; i++) {
out[i] = totals[i];
}
} else {
for (int i = 0; out_column + i < reduction_stride; i++) {
out[i] = totals[i];
}
}
}
}

View File

@@ -936,6 +936,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
const constant int& gqa_factor, \
const constant int& N, \
const constant size_t& k_stride, \
const constant size_t& v_stride, \
const constant float& scale, \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \

View File

@@ -13,6 +13,7 @@ template <typename T, int D>
const constant int& gqa_factor,
const constant int& N,
const constant size_t& k_stride,
const constant size_t& v_stride,
const constant float& scale,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -38,7 +39,7 @@ template <typename T, int D>
const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + simd_lid * elem_per_thread;
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
out += head_idx * D + simd_gid * elem_per_thread;
// Read the query and zero the output accumulator

View File

@@ -97,6 +97,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string&,
const std::string&,
const array&) {
return d.get_kernel(kernel_name);
}

View File

@@ -6,6 +6,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
@@ -148,6 +149,125 @@ void launch_qmm(
d.add_temporaries(std::move(copies), s.index);
}
void qvm_split_k(
const std::vector<array>& inputs,
array& out,
int group_size,
int bits,
int D,
int O,
int B,
int N,
const Stream& s) {
int split_k = D > 8192 ? 32 : 8;
int split_D = (D + split_k - 1) / split_k;
N *= split_k;
int bo = 64;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, N);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
// Ensure that the last two dims are row contiguous.
// TODO: Check if we really need this for x as well...
std::vector<array> copies;
auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return arr_copy;
}
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
int x_batch_ndims = x.ndim() - 2;
auto x_shape = x.shape();
auto x_strides = x.strides();
int w_batch_ndims = w.ndim() - 2;
auto w_shape = w.shape();
auto w_strides = w.strides();
auto s_strides = scales.strides();
auto b_strides = biases.strides();
// Add split_k dim with reshapes
x_shape.insert(x_shape.end() - 2, split_k);
x_shape.back() /= split_k;
x_strides.insert(x_strides.end() - 2, split_D);
x_strides[x.ndim() - 1] = split_D;
x_batch_ndims += 1;
w_shape.insert(w_shape.end() - 2, split_k);
w_shape[w.ndim() - 1] /= split_k;
w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1));
w_batch_ndims += 1;
s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1));
b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1));
int final_block_size = D - (split_k - 1) * split_D;
auto& d = metal::device(s.device);
auto temp_shape = out.shape();
temp_shape.insert(temp_shape.end() - 2, split_k);
array intermediate(temp_shape, x.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qvm_split_k" << "_" << type_string << "_gs_" << group_size << "_b_"
<< bits << "_spk_" << split_k;
auto template_def = get_template_definition(
kname.str(), "qvm_split_k", type_string, group_size, bits, split_k);
// Encode and dispatch kernel
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(intermediate, 4);
compute_encoder->setBytes(&split_D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 7);
set_vector_bytes(compute_encoder, x_shape, 8);
set_vector_bytes(compute_encoder, x_strides, 9);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 10);
set_vector_bytes(compute_encoder, w_shape, 11);
set_vector_bytes(compute_encoder, w_strides, 12);
set_vector_bytes(compute_encoder, s_strides, 13);
set_vector_bytes(compute_encoder, b_strides, 14);
compute_encoder->setBytes(&final_block_size, sizeof(int), 15);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
int axis = intermediate.ndim() - 3;
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce,
{intermediate.shape(axis)},
{intermediate.strides(axis)});
strided_reduce_general_dispatch(
intermediate, out, "sum", plan, {axis}, compute_encoder, d, s);
}
void qmm_op(
const std::vector<array>& inputs,
array& out,
@@ -211,7 +331,9 @@ void qmm_op(
aligned = true;
}
} else {
if (B < 4) {
if (B < 4 && D >= 1024 && !gather) {
return qvm_split_k(inputs, out, group_size, bits, D, O, B, N, s);
} else if (B < 4) {
name += "qvm";
int bo = 64;
int bd = 32;

View File

@@ -141,6 +141,20 @@ struct ColReduceArgs {
ndim = shape.size();
}
/**
* Create the col reduce arguments for reducing the 1st axis of the row
* contiguous intermediate array.
*/
ColReduceArgs(const array& intermediate) {
assert(intermediate.flags().row_contiguous);
reduction_size = intermediate.shape(0);
reduction_stride = intermediate.size() / reduction_size;
non_col_reductions = 1;
reduce_ndim = 0;
ndim = 0;
}
void encode(CommandEncoder& compute_encoder) {
// Push 0s to avoid encoding empty vectors.
if (reduce_ndim == 0) {
@@ -231,8 +245,10 @@ void init_reduce(
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
auto kernel = get_reduce_init_kernel(
d, "init_reduce_" + op_name + type_to_name(out), out);
std::ostringstream kname;
const std::string func_name = "init_reduce";
kname << func_name << "_" << op_name << type_to_name(out);
auto kernel = get_reduce_init_kernel(d, kname.str(), func_name, op_name, out);
size_t nthreads = out.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@@ -251,8 +267,7 @@ void all_reduce_dispatch(
const std::string& op_name,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s,
std::vector<array>& copies) {
const Stream& s) {
// Set the kernel
std::ostringstream kname;
const std::string func_name = "all_reduce";
@@ -293,7 +308,7 @@ void all_reduce_dispatch(
// Allocate an intermediate tensor to hold results if needed
array intermediate({n_rows}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
copies.push_back(intermediate);
d.add_temporary(intermediate, s.index);
// 1st pass
size_t row_size = (in_size + n_rows - 1) / n_rows;
@@ -469,39 +484,11 @@ void strided_reduce_small(
// Figure out the grid dims
MTL::Size grid_dims, group_dims;
// Case 1: Small row small column
if (args.reduction_size * args.non_col_reductions < 64 &&
args.reduction_stride < 32) {
grid_dims = output_grid_for_col_reduce(out, args);
int threadgroup_size = (grid_dims.width > 128) ? 128 : grid_dims.width;
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
// Prepare the arguments for the kernel
args.reduce_shape.push_back(args.reduction_size);
args.reduce_strides.push_back(args.reduction_stride);
args.reduce_ndim++;
// Case 2: Long row small column
else if (args.reduction_size * args.non_col_reductions < 32) {
auto out_grid_dims = output_grid_for_col_reduce(out, args);
int threads_x =
(args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
int threadgroup_x = std::min(threads_x, 128);
grid_dims = MTL::Size(threads_x, out_grid_dims.width, out_grid_dims.height);
group_dims = MTL::Size(threadgroup_x, 1, 1);
}
// Case 3: Long row medium column
else {
args.reduce_shape.push_back(args.reduction_size);
args.reduce_strides.push_back(args.reduction_stride);
args.reduce_ndim++;
int simdgroups =
(args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
int threadgroup_size = simdgroups * 32;
auto out_grid_dims = output_grid_for_col_reduce(out, args);
grid_dims =
MTL::Size(threadgroup_size, out_grid_dims.width, out_grid_dims.height);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
// Set the kernel
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
const std::string func_name = "col_reduce_small";
@@ -510,10 +497,113 @@ void strided_reduce_small(
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
const int n_reads = 4;
size_t reduction_stride_blocks =
(args.reduction_stride + n_reads - 1) / n_reads;
size_t total = args.reduction_size * args.non_col_reductions;
size_t threadgroup_x = std::min(reduction_stride_blocks, 32ul);
size_t threadgroup_y = std::min(
8ul,
std::min(kernel->maxTotalThreadsPerThreadgroup() / threadgroup_x, total));
group_dims = MTL::Size(threadgroup_x, threadgroup_y, 1);
grid_dims = output_grid_for_col_reduce(out, args);
grid_dims = MTL::Size(
(reduction_stride_blocks + threadgroup_x - 1) / threadgroup_x,
grid_dims.width,
grid_dims.height);
// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void strided_reduce_longcolumn(
const array& in,
array& out,
const std::string& op_name,
ColReduceArgs& args,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
size_t total_reduction_size = args.reduction_size * args.non_col_reductions;
size_t outer_blocks = 32;
if (total_reduction_size >= 32768) {
outer_blocks = 128;
}
// Prepare the temporary accumulator
std::vector<int> intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.push_back(outer_blocks);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
// Prepare the arguments for the kernel
args.reduce_shape.push_back(args.reduction_size);
args.reduce_strides.push_back(args.reduction_stride);
args.reduce_ndim++;
// Figure out the grid dims
size_t out_size = out.size();
size_t threadgroup_x = args.reduction_stride;
size_t threadgroup_y =
(args.non_col_reductions * args.reduction_size + outer_blocks - 1) /
outer_blocks;
threadgroup_y = std::min(32ul, threadgroup_y);
auto out_grid_size = output_grid_for_col_reduce(out, args);
MTL::Size grid_dims(out_grid_size.width, out_grid_size.height, outer_blocks);
MTL::Size group_dims(threadgroup_x, threadgroup_y, 1);
// Set the kernel
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
const std::string func_name = "col_reduce_longcolumn";
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
args.encode(compute_encoder);
compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Make the 2nd pass arguments and grid_dims
ColReduceArgs second_args(intermediate);
second_args.reduce_shape.push_back(outer_blocks);
second_args.reduce_strides.push_back(out.size());
second_args.reduce_ndim++;
int BN = 32;
grid_dims = MTL::Size(256 * ((out.size() + BN - 1) / BN), 1, 1);
group_dims = MTL::Size(256, 1, 1);
// Set the 2nd kernel
const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
op_name + type_to_name(intermediate);
kernel = get_reduce_kernel(
d,
second_kernel,
"col_reduce_looped",
op_name,
intermediate,
out,
1,
32,
32);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1);
second_args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
@@ -532,9 +622,9 @@ void strided_reduce_looped(
// Figure out the grid dims
auto out_grid_size = output_grid_for_col_reduce(out, args);
int BN = (args.reduction_stride <= 1024) ? 32 : 128;
int BN = 32;
int BM = 1024 / BN;
int threadgroup_size = 4 * 32;
int threadgroup_size = 8 * 32;
MTL::Size grid_dims(
threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
out_grid_size.width,
@@ -558,6 +648,87 @@ void strided_reduce_looped(
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void strided_reduce_2pass(
const array& in,
array& out,
const std::string& op_name,
ColReduceArgs& args,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
// Prepare the temporary accumulator
std::vector<int> intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.push_back(32);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
// Prepare the arguments for the kernel
args.reduce_shape.push_back(args.reduction_size);
args.reduce_strides.push_back(args.reduction_stride);
args.reduce_ndim++;
// Figure out the grid dims
size_t out_size = out.size() / args.reduction_stride;
auto out_grid_size = output_grid_for_col_reduce(out, args);
int outer_blocks = 32;
int BN = 32;
int BM = 1024 / BN;
int threadgroup_size = 8 * 32;
MTL::Size grid_dims(
threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
out_grid_size.width * outer_blocks,
out_grid_size.height);
MTL::Size group_dims(threadgroup_size, 1, 1);
// Set the kernel
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
const std::string func_name = "col_reduce_2pass";
kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
<< op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
compute_encoder->setComputePipelineState(kernel);
// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
args.encode(compute_encoder);
compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Make the 2nd pass arguments and grid_dims
ColReduceArgs second_args(intermediate);
second_args.reduce_shape.push_back(outer_blocks);
second_args.reduce_strides.push_back(out.size());
second_args.reduce_ndim++;
grid_dims = MTL::Size(threadgroup_size * ((out.size() + BN - 1) / BN), 1, 1);
// Set the 2nd kernel
const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
op_name + type_to_name(intermediate);
kernel = get_reduce_kernel(
d,
second_kernel,
"col_reduce_looped",
op_name,
intermediate,
out,
1,
32,
32);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1);
second_args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void strided_reduce_general_dispatch(
const array& in,
array& out,
@@ -570,11 +741,23 @@ void strided_reduce_general_dispatch(
// Prepare the arguments for the kernel
ColReduceArgs args(in, plan, axes);
if (args.reduction_stride < 32 ||
args.reduction_size * args.non_col_reductions < 32) {
// Small column
if (args.reduction_size * args.non_col_reductions < 32) {
return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s);
}
// Long column but small row
if (args.reduction_stride < 32 &&
args.reduction_size * args.non_col_reductions >= 1024) {
return strided_reduce_longcolumn(
in, out, op_name, args, compute_encoder, d, s);
}
if (args.reduction_size * args.non_col_reductions > 256 &&
out.size() / 32 < 1024) {
return strided_reduce_2pass(in, out, op_name, args, compute_encoder, d, s);
}
return strided_reduce_looped(in, out, op_name, args, compute_encoder, d, s);
}
@@ -620,7 +803,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// Reduce
if (in.size() > 0) {
std::vector<array> copies;
ReductionPlan plan = get_reduction_plan(in, axes_);
// If it is a general reduce then copy the input to a contiguous array and
@@ -632,7 +814,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
if (plan.type == GeneralReduce) {
array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
copies.push_back(in_copy);
d.add_temporary(in_copy, s.index);
in = in_copy;
plan = get_reduction_plan(in, axes_);
}
@@ -640,7 +822,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// Reducing over everything and the data is all there; no broadcasting or
// slicing etc.
if (plan.type == ContiguousAllReduce) {
all_reduce_dispatch(in, out, op_name, compute_encoder, d, s, copies);
all_reduce_dispatch(in, out, op_name, compute_encoder, d, s);
}
// At least the last dimension is row contiguous and we are reducing over
@@ -659,8 +841,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
strided_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d, s);
}
d.add_temporaries(std::move(copies), s.index);
}
// Nothing to reduce just initialize the output

View File

@@ -16,8 +16,7 @@ void all_reduce_dispatch(
const std::string& op_name,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s,
std::vector<array>& copies);
const Stream& s);
void row_reduce_general_dispatch(
const array& in,

View File

@@ -162,7 +162,8 @@ void sdpa_vector(
int gqa_factor = q.shape(1) / k.shape(1);
int N = k.shape(2);
int B = q.shape(0) * q.shape(1);
size_t stride = k.strides()[1];
size_t k_stride = k.strides()[1];
size_t v_stride = v.strides()[1];
MTL::Size group_dims(1024, 1, 1);
MTL::Size grid_dims(1, B, 1);
@@ -178,8 +179,9 @@ void sdpa_vector(
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&gqa_factor, sizeof(int), 4);
compute_encoder->setBytes(&N, sizeof(int), 5);
compute_encoder->setBytes(&stride, sizeof(size_t), 6);
compute_encoder->setBytes(&scale, sizeof(float), 7);
compute_encoder->setBytes(&k_stride, sizeof(size_t), 6);
compute_encoder->setBytes(&v_stride, sizeof(size_t), 7);
compute_encoder->setBytes(&scale, sizeof(float), 8);
// Launch
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

View File

@@ -69,6 +69,14 @@ array rms_norm(
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
if (weight.size() != x.shape(-1)) {
std::ostringstream msg;
msg << "[rms_norm] weight must have the same size as the last dimension of"
" x but has "
<< weight.size() << " elements.";
throw std::invalid_argument(msg.str());
}
auto out_type = result_type(x, weight);
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;

View File

@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
@@ -1683,48 +1682,58 @@ std::pair<std::vector<array>, std::vector<int>> Gather::vmap(
auto gather_axes = axes_;
auto slice_sizes = slice_sizes_;
auto src_vmapped = axes[0] >= 0;
auto indices_vmapped =
std::any_of(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; });
auto out_ax =
*std::find_if(axes.begin(), axes.end(), [](int a) { return a >= 0; });
auto ind_vmap_ax_ptr =
std::find_if(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; });
int out_ax = -1;
bool indices_vmapped = (ind_vmap_ax_ptr != axes.end());
if (indices_vmapped) {
out_ax = *ind_vmap_ax_ptr;
} else if (src_vmapped) {
out_ax = axes[0];
}
// Reorder all the index arrays so the vmap axis is in the same spot.
for (int i = 1; i < axes.size(); ++i) {
if (out_ax != axes[i] && axes[i] >= 0) {
indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream());
if (indices_vmapped) {
for (int i = 1; i < axes.size(); ++i) {
if (out_ax != axes[i] && axes[i] >= 0) {
indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream());
} else if (axes[i] < 0) {
indices[i - 1] = expand_dims(indices[i - 1], out_ax, stream());
}
}
}
int idx_dims = indices.empty() ? 0 : indices[0].ndim();
if (src_vmapped) {
int max_dims = 0;
for (auto& idx : indices) {
max_dims = std::max(static_cast<int>(idx.ndim()), max_dims);
}
auto new_ax_loc =
std::find_if(gather_axes.begin(), gather_axes.end(), [&out_ax](int a) {
return a >= out_ax;
});
for (; new_ax_loc < gather_axes.end(); new_ax_loc++) {
(*new_ax_loc)++;
for (auto& ax : gather_axes) {
if (ax >= axes[0]) {
ax++;
}
}
if (indices_vmapped) {
// Make a new index array for the vmapped dimension
auto vmap_inds = arange(0, src.shape(axes[0]), stream());
// Reshape it so it broadcasts with other index arrays
{
auto shape = std::vector<int>(idx_dims, 1);
shape[out_ax] = vmap_inds.size();
vmap_inds = reshape(vmap_inds, std::move(shape), stream());
}
// Update gather axes and slice sizes accordingly
auto shape = std::vector<int>(max_dims - out_ax, 1);
auto vmap_inds = arange(0, src.shape(out_ax), stream());
shape[0] = vmap_inds.shape(0);
vmap_inds = reshape(vmap_inds, shape, stream());
slice_sizes.insert(slice_sizes.begin() + out_ax, 1);
auto new_ax_idx = new_ax_loc - gather_axes.begin();
gather_axes.insert(new_ax_loc, out_ax);
indices.insert(indices.begin() + new_ax_idx, vmap_inds);
slice_sizes.insert(slice_sizes.begin() + axes[0], 1);
gather_axes.push_back(axes[0]);
indices.push_back(vmap_inds);
} else {
slice_sizes.insert(slice_sizes.begin() + axes[0], src.shape(axes[0]));
out_ax = max_dims + axes[0];
slice_sizes.insert(slice_sizes.begin() + out_ax, src.shape(out_ax));
out_ax += idx_dims;
}
}
return {{gather(src, indices, gather_axes, slice_sizes, stream())}, {out_ax}};
auto out = gather(src, indices, gather_axes, slice_sizes, stream());
if (src_vmapped && indices_vmapped) {
out = squeeze(out, idx_dims + axes[0], stream());
}
return {{out}, {out_ax}};
}
std::vector<array> Gather::vjp(

View File

@@ -278,7 +278,9 @@ void init_fast(nb::module_& parent_module) {
output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
This will be passed to ``MTLComputeCommandEncoder::dispatchThreads``.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
This will be passed to ``MTLComputeCommandEncoder::dispatchThreads``.
template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
These will be added as template arguments to the kernel definition. Default: ``None``.
init_value (float, optional): Optional value to use to initialize all of the output arrays.
@@ -300,6 +302,8 @@ void init_fast(nb::module_& parent_module) {
R"pbdoc(
A jit-compiled custom Metal kernel defined from a source string.
Full documentation: :ref:`custom_metal_kernels`.
Args:
name (str): Name for the kernel.
input_names (List[str]): The parameter names of the inputs in the

View File

@@ -803,9 +803,10 @@ auto mlx_slice_update(
// Pre process tuple
auto upd = to_array(v, src.dtype());
// Remove leading singleton dimensions from the update
// Remove extra leading singleton dimensions from the update
int s = 0;
for (; s < upd.ndim() && upd.shape(s) == 1; s++) {
for (; s < upd.ndim() && upd.shape(s) == 1 && (upd.ndim() - s) > src.ndim();
s++) {
};
auto up_shape = std::vector<int>(upd.shape().begin() + s, upd.shape().end());
up_shape = up_shape.empty() ? std::vector{1} : up_shape;

View File

@@ -1771,6 +1771,19 @@ class TestArray(mlx_tests.MLXTestCase):
peak_2 = mx.metal.get_peak_memory()
self.assertEqual(peak_1, peak_2)
def fun():
a = mx.array([1.0, 2.0, 3.0, 4.0])
b, _ = mx.divmod(a, a)
return mx.log(b)
fun()
mx.synchronize()
peak_1 = mx.metal.get_peak_memory()
fun()
mx.synchronize()
peak_2 = mx.metal.get_peak_memory()
self.assertEqual(peak_1, peak_2)
def test_add_numpy(self):
x = mx.array(1)
y = np.array(2, dtype=np.int32)

View File

@@ -308,6 +308,11 @@ class TestFast(mlx_tests.MLXTestCase):
rx_fast = mx.fast.rms_norm(x, weight, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)
# Wrong size w raises
with self.assertRaises(ValueError):
x = mx.random.uniform(shape=(1, 5))
mx.fast.rms_norm(x, mx.ones((4,)), 1e-5)
def test_rms_norm_grad(self):
D = 32
eps = 1e-5

View File

@@ -167,6 +167,15 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))
q = mx.random.normal(shape=(1, 32, 1, Dk))
k = mx.random.normal(shape=(1, 32, 32, Dk))
v = mx.random.normal(shape=(1, 32, 128, Dk))
atol = 1e-6
y = mlx_primitives_sdpa(q, k, v[:, :, :32], scale)
y_hat = mx.fast.scaled_dot_product_attention(q, k, v[:, :, :32], scale=scale)
self.assertTrue(mx.allclose(y, y_hat, atol=atol))
if __name__ == "__main__":
unittest.main(failfast=True)

View File

@@ -163,6 +163,31 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_qvm_splitk(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
tests = product(
[128, 64, 32], # group_size
[2, 4, 8], # bits
[128], # M
[16384], # N
[1, 3], # B
)
for group_size, bits, M, N, B in tests:
with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
x_shape = (1, N) if B == 0 else (B, 1, N)
w_shape = (N, M) if B == 0 else (B, N, M)
x = mx.random.normal(shape=x_shape, key=k1)
w = mx.random.normal(shape=w_shape, key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
y_q = mx.quantized_matmul(
x, w_q, scales, biases, False, group_size, bits
)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 2e-3)
def test_throw(self):
x = mx.random.normal(shape=(10, 512))
w = mx.random.normal(shape=(32, 512))

View File

@@ -370,6 +370,51 @@ class TestVmap(mlx_tests.MLXTestCase):
mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
)
def test_vmap_gather(self):
def gather(a, idx):
return a[idx]
a = mx.array([[1, 2], [3, 4]])
idx = mx.array(0)
out = mx.vmap(gather, (0, None))(a, idx)
self.assertTrue(mx.array_equal(out, mx.array([1, 3])))
out = mx.vmap(gather, (1, None))(a, idx)
self.assertTrue(mx.array_equal(out, mx.array([1, 2])))
idx = mx.array([0, 1])
out = mx.vmap(gather, (0, 0))(a, idx)
self.assertTrue(mx.array_equal(out, mx.array([1, 4])))
a = mx.ones((2, 3, 4))
idx = mx.zeros(4, mx.int32)
out = mx.vmap(gather, (2, 0))(a, idx)
self.assertEqual(out.shape, (4, 3))
f = mx.vmap(gather, (0, None))
f = mx.vmap(gather, (0, 0))
out = f(mx.ones((2, 3, 4)), mx.zeros(2, dtype=mx.int32))
self.assertEqual(out.shape, (2, 4))
def gather(a, idxa, idxb):
return a[idxa, idxb]
a = mx.ones((2, 3, 4))
idxa = mx.zeros((2, 3), mx.int32)
idxb = mx.zeros(3, mx.int32)
out = mx.vmap(gather, (0, 0, None))(a, idxa, idxb)
self.assertEqual(out.shape, (2, 3))
idxa = mx.zeros((3, 1, 2), mx.int32)
idxb = mx.zeros((2, 3, 1, 2), mx.int32)
out = mx.vmap(gather, (0, None, 0))(a, idxa, idxb)
self.assertEqual(out.shape, (2, 3, 1, 2))
idxa = mx.zeros((3, 1, 2), mx.int32)
idxb = mx.zeros((3, 1, 2, 2), mx.int32)
out = mx.vmap(gather, (0, None, 3))(a, idxa, idxb)
self.assertEqual(out.shape, (2, 3, 1, 2))
def test_vmap_scatter(self):
def scatter(a):
a[mx.array(0)] = mx.array(0.0)

View File

@@ -165,7 +165,7 @@ if __name__ == "__main__":
setup(
name="mlx",
version=get_version("0.19.2"),
version=get_version("0.20.0"),
author="MLX Contributors",
author_email="mlx@group.apple.com",
description="A framework for machine learning on Apple silicon.",