Gather mm new kernel and small refactoring (#2040)

Angelos Katharopoulos 2025-04-14 16:37:36 -07:00 committed by GitHub
parent e9e268336b
commit 99eefd2ec0
23 changed files with 1260 additions and 378 deletions

View File

@ -0,0 +1,74 @@
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
from time_utils import time_fn

N = 1024
D = 1024
M = 1024
E = 32
I = 4


def gather_sort(x, indices):
    N, M = indices.shape
    indices = indices.flatten()
    order = mx.argsort(indices)
    inv_order = mx.argsort(order)
    return x.flatten(0, -3)[order // M], indices[order], inv_order


def scatter_unsort(x, inv_order, shape=None):
    x = x[inv_order]
    if shape is not None:
        x = mx.unflatten(x, 0, shape)
    return x


def gather_mm_simulate(x, w, indices):
    x, idx, inv_order = gather_sort(x, indices)
    for i in range(2):
        y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
        x = y[:, None]
    x = scatter_unsort(x, inv_order, indices.shape)
    return x


def time_gather_mm():
    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
    w1 = mx.random.normal((E, M, D)) / 1024**0.5
    w2 = mx.random.normal((E, D, M)) / 1024**0.5
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx = indices
        inv_order = None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
        x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    time_fn(gather_mm, x, w1, w2, indices, False)
    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
    time_fn(gather_mm, x, w1, w2, indices, True)

    x = mx.random.normal((N * I, D)) / 1024**0.5
    w1 = mx.random.normal((M, D)) / 1024**0.5
    w2 = mx.random.normal((D, M)) / 1024**0.5
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        x = x @ w1.T
        x = x @ w2.T
        return x

    time_fn(equivalent_matmul, x, w1, w2)


if __name__ == "__main__":
    time_gather_mm()
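The benchmark above only times the op. As a quick sanity sketch (names and shapes below are illustrative, not part of the benchmark), the sorted fast path added by this commit can be checked against an explicit per-token loop; with single-row matrices and only rhs_indices passed, GatherMM routes to the new specialized kernel:

import mlx.core as mx

E, T, D, N = 4, 6, 8, 16
x = mx.random.normal((T, 1, D))                       # T "tokens", each a (1, D) matrix
w = mx.random.normal((E, D, N))                       # E expert weight matrices
rhs = mx.array([0, 0, 1, 2, 2, 3], dtype=mx.uint32)   # expert id per token, already sorted

# sorted_indices=True is the new flag; with M == 1 this hits gather_mm_rhs on Metal.
y = mx.gather_mm(x, w, rhs_indices=rhs, sorted_indices=True)

# Reference: explicit per-token matmul.
y_ref = mx.stack([x[i] @ w[int(rhs[i])] for i in range(T)])
print(y.shape, mx.allclose(y, y_ref))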

View File

@ -1,6 +1,7 @@
 target_sources(
   mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp

View File

@ -0,0 +1,24 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/common/utils.h"

namespace mlx::core {

void broadcast(const array& in, array& out) {
  if (out.size() == 0) {
    out.set_data(nullptr);
    return;
  }
  Strides strides(out.ndim(), 0);
  int diff = out.ndim() - in.ndim();
  for (int i = in.ndim() - 1; i >= 0; --i) {
    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
  }
  auto flags = in.flags();
  if (out.size() > in.size()) {
    flags.row_contiguous = flags.col_contiguous = false;
  }
  out.copy_shared_buffer(in, strides, flags, in.data_size());
}

} // namespace mlx::core
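The stride trick above is the whole of the broadcast: size-1 input dimensions get stride 0, so the output aliases the input buffer without copying. A minimal Python sketch of the same rule (illustrative only, not library code):

def broadcast_strides(in_shape, in_strides, out_ndim):
    # Size-1 dims broadcast by stepping with stride 0; other dims keep their
    # original stride, right-aligned against the output shape.
    strides = [0] * out_ndim
    diff = out_ndim - len(in_shape)
    for i in range(len(in_shape) - 1, -1, -1):
        strides[i + diff] = 0 if in_shape[i] == 1 else in_strides[i]
    return strides

# A (1, 3) row-contiguous array broadcast to 3 dims: only the last axis moves.
print(broadcast_strides((1, 3), (3, 1), 3))  # [0, 0, 1]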

View File

@ -0,0 +1,11 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/array.h"

namespace mlx::core {

void broadcast(const array& in, array& out);

} // namespace mlx::core

View File

@ -1,6 +1,7 @@
 // Copyright © 2024 Apple Inc.
 
 #include <cassert>
 
+#include "mlx/backend/common/broadcasting.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"
@ -42,23 +43,6 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
   return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
 }
 
-void broadcast(const array& in, array& out) {
-  if (out.size() == 0) {
-    out.set_data(nullptr);
-    return;
-  }
-  Strides strides(out.ndim(), 0);
-  int diff = out.ndim() - in.ndim();
-  for (int i = in.ndim() - 1; i >= 0; --i) {
-    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
-  }
-  auto flags = in.flags();
-  if (out.size() > in.size()) {
-    flags.row_contiguous = flags.col_contiguous = false;
-  }
-  out.copy_shared_buffer(in, strides, flags, in.data_size());
-}
-
 void Broadcast::eval(const std::vector<array>& inputs, array& out) {
   broadcast(inputs[0], out);
 }

View File

@ -61,6 +61,7 @@ if(MLX_METAL_JIT)
     kernels/steel/gemm/transforms.h)
   make_jit_source(steel/gemm/kernels/steel_gemm_fused)
   make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
+  make_jit_source(steel/gemm/kernels/steel_gemm_gather)
   make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
   make_jit_source(
     steel/conv/conv

View File

@ -33,6 +33,7 @@ const char* gemm();
 const char* steel_gemm_fused();
 const char* steel_gemm_masked();
 const char* steel_gemm_splitk();
+const char* steel_gemm_gather();
 const char* conv();
 const char* steel_conv();
 const char* steel_conv_general();

View File

@ -584,6 +584,44 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
   return d.get_kernel(kernel_name, lib);
 }
 
+MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const array& out,
+    bool transpose_a,
+    bool transpose_b,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn,
+    bool rhs) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name, [&]() {
+    std::string kernel_source;
+    concatenate(
+        kernel_source,
+        metal::utils(),
+        metal::gemm(),
+        metal::steel_gemm_gather(),
+        get_template_definition(
+            lib_name,
+            rhs ? "gather_mm_rhs" : "gather_mm",
+            get_type_string(out.dtype()),
+            bm,
+            bn,
+            bk,
+            wm,
+            wn,
+            transpose_a,
+            transpose_b));
+    return kernel_source;
+  });
+  return d.get_kernel(kernel_name, lib, hash_name, func_consts);
+}
+
 MTL::ComputePipelineState* get_gemv_masked_kernel(
     metal::Device& d,
     const std::string& kernel_name,

View File

@ -160,6 +160,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
     bool mn_aligned,
     bool k_aligned);
 
+MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const array& out,
+    bool transpose_a,
+    bool transpose_b,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn,
+    bool rhs);
+
 MTL::ComputePipelineState* get_steel_conv_kernel(
     metal::Device& d,
     const std::string& kernel_name,

View File

@ -69,6 +69,7 @@ set(STEEL_HEADERS
     steel/gemm/loader.h
     steel/gemm/transforms.h
     steel/gemm/kernels/steel_gemm_fused.h
+    steel/gemm/kernels/steel_gemm_gather.h
     steel/gemm/kernels/steel_gemm_masked.h
     steel/gemm/kernels/steel_gemm_splitk.h
     steel/utils/type_traits.h
@ -116,6 +117,7 @@ if(NOT MLX_METAL_JIT)
   build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
   build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
   build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
+  build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
   build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
   build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
   build_kernel(gemv_masked steel/utils.h)

View File

@ -15,10 +15,6 @@ constant bool align_M [[function_constant(200)]];
 constant bool align_N [[function_constant(201)]];
 constant bool align_K [[function_constant(202)]];
 
-constant bool do_gather [[function_constant(300)]];
-
-constant bool gather_bias = do_gather && use_out_source;
-
 // clang-format off
 template <
     typename T,
@ -39,12 +35,6 @@ template <
     const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
     const constant int* batch_shape [[buffer(6)]],
     const constant int64_t* batch_strides [[buffer(7)]],
-    const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
-    const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
-    const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
-    const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
-    const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
-    const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint3 tid [[threadgroup_position_in_grid]],
@ -81,84 +71,26 @@ template <
   }
 
   // Adjust for batch
-  // Handle gather
-  if (do_gather) {
-    // Read indices
-    uint32_t indx_A, indx_B, indx_C;
-
-    if (has_batch) {
-      const constant auto* indx_A_bstrides = batch_strides;
-      const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim;
-
-      ulong2 indx_offsets = elem_to_loc_broadcast(
-          tid.z,
-          batch_shape,
-          indx_A_bstrides,
-          indx_B_bstrides,
-          params->batch_ndim);
-      indx_A = lhs_indices[indx_offsets.x];
-      indx_B = rhs_indices[indx_offsets.y];
-
-      if (use_out_source) {
-        const constant auto* indx_C_bstrides =
-            indx_B_bstrides + params->batch_ndim;
-        auto indx_offset_C = elem_to_loc(
-            tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
-        indx_C = C_indices[indx_offset_C];
-      }
-    } else {
-      indx_A = lhs_indices[params->batch_stride_a * tid.z];
-      indx_B = rhs_indices[params->batch_stride_b * tid.z];
-
-      if (use_out_source) {
-        indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
-      }
-    }
-
-    // Translate indices to offsets
-    int batch_ndim_A = operand_batch_ndim.x;
-    const constant int* batch_shape_A = operand_shape;
-    const constant auto* batch_strides_A = operand_strides;
-    A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
-
-    int batch_ndim_B = operand_batch_ndim.y;
-    const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
-    const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A;
-    B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
-
-    if (use_out_source) {
-      int batch_ndim_C = operand_batch_ndim.z;
-      const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
-      const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B;
-      C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
-    }
-  }
-
-  // Handle regular batch
-  else {
-    if (has_batch) {
-      const constant auto* A_bstrides = batch_strides;
-      const constant auto* B_bstrides = batch_strides + params->batch_ndim;
-
-      ulong2 batch_offsets = elem_to_loc_broadcast(
-          tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
-
-      A += batch_offsets.x;
-      B += batch_offsets.y;
-
-      if (use_out_source) {
-        const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
-        C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
-      }
-    } else {
-      A += params->batch_stride_a * tid.z;
-      B += params->batch_stride_b * tid.z;
-
-      if (use_out_source) {
-        C += addmm_params->batch_stride_c * tid.z;
-      }
-    }
-  }
+  if (has_batch) {
+    const constant auto* A_bstrides = batch_strides;
+    const constant auto* B_bstrides = batch_strides + params->batch_ndim;
+
+    ulong2 batch_offsets = elem_to_loc_broadcast(
+        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
+
+    A += batch_offsets.x;
+    B += batch_offsets.y;
+
+    if (use_out_source) {
+      const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
+      C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
+    }
+  } else {
+    A += params->batch_stride_a * tid.z;
+    B += params->batch_stride_b * tid.z;
+
+    if (use_out_source) {
+      C += addmm_params->batch_stride_c * tid.z;
+    }
+  }

View File

@ -0,0 +1,459 @@
// Copyright © 2024 Apple Inc.
using namespace mlx::steel;
constant bool has_batch [[function_constant(10)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm_rhs(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* rhs_indices [[buffer(2)]],
device T* C [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Find the block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = rhs_indices[c_row];
short offset_next = 0;
int n = 0;
while (n < tgp_bm) {
n++;
offset = offset_next;
index = index_next;
offset_next = tgp_bm;
for (; n < tgp_bm; n++) {
if (rhs_indices[c_row + n] != index) {
offset_next = n;
index_next = rhs_indices[c_row + n];
break;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(
B + index * params->batch_stride_b,
params->ldb,
Bs,
simd_group_id,
simd_lane_id);
// Prepare iterations
const int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
// Matrix level aligned never check
if (align_M && align_N) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(C, params->ldd);
} else {
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
} else {
const short lbk = 0;
// Tile aligned don't check
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, true, true>{});
if (offset_next - offset == BM) {
mma_op.store_result(C, params->ldd);
} else {
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, true, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(BN, offset_next));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, false, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
}
// Nothing aligned so check both rows and cols
else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, false, true>{});
mma_op.store_result_slice(
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
}
}
}
}
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* lhs_indices [[buffer(2)]],
const device uint32_t* rhs_indices [[buffer(3)]],
device T* C [[buffer(4)]],
const constant GEMMParams* params [[buffer(5)]],
const constant int* indices_shape [[buffer(6)]],
const constant int64_t* lhs_strides [[buffer(7)]],
const constant int64_t* rhs_strides [[buffer(8)]],
const constant int& batch_ndim_a [[buffer(9)]],
const constant int* batch_shape_a [[buffer(10)]],
const constant int64_t* batch_strides_a [[buffer(11)]],
const constant int& batch_ndim_b [[buffer(12)]],
const constant int* batch_shape_b [[buffer(13)]],
const constant int64_t* batch_strides_b [[buffer(14)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Move A and B to the locations pointed by lhs_indices and rhs_indices.
uint32_t indx_A, indx_B;
if (has_batch) {
ulong2 indices_offsets = elem_to_loc_broadcast(
tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim);
indx_A = lhs_indices[indices_offsets.x];
indx_B = rhs_indices[indices_offsets.y];
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
}
A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a);
B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b);
C += params->batch_stride_d * tid.z;
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Just make sure everybody's finished with the indexing math above.
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
// Prepare iterations
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
// Matrix level aligned never check
if (align_M && align_N) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Store results to device memory
mma_op.store_result(C, params->ldd);
} else {
const short lbk = 0;
// Tile aligned don't check
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, true, true>{});
mma_op.store_result(C, params->ldd);
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, true, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<true, false, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Nothing aligned so check both rows and cols
else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
lbk,
LoopAlignment<false, false, true>{});
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}
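Functionally, gather_mm_rhs computes one output row per row of A, multiplying it by the B matrix named by that row's index; sorting only matters for performance, since a whole BM-row tile can then reuse a single B operand. A rough Python reference for the computed result (illustrative only, not the Metal implementation):

import mlx.core as mx

def gather_mm_rhs_reference(A, B, rhs_indices):
    # A: (M, K), B: (E, K, N), rhs_indices: (M,) uint32, assumed sorted.
    rows = [A[i : i + 1] @ B[int(rhs_indices[i])] for i in range(A.shape[0])]
    return mx.concatenate(rows, axis=0)

A = mx.random.normal((6, 8))
B = mx.random.normal((3, 8, 4))
idx = mx.array([0, 0, 1, 1, 2, 2], dtype=mx.uint32)
print(gather_mm_rhs_reference(A, B, idx).shape)  # (6, 4)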

View File

@ -0,0 +1,59 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h"
#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gather_mm_rhs_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
gather_mm_rhs, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_gather_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gather_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
gather_mm, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
// clang-format on
instantiate_gather_mm_shapes_helper(float16, half, float16, half);
instantiate_gather_mm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gather_mm_shapes_helper(float32, float, float32, float);
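The macros above stamp out one Metal kernel per transpose mode, dtype, and tile configuration, and the host code reconstructs the same name when it looks a kernel up. A sketch of the naming scheme (values illustrative):

tname, iname, oname = "nt", "float16", "float16"
bm, bn, bk, wm, wn = 16, 64, 16, 1, 2
print(f"steel_gather_mm_rhs_{tname}_{iname}_{oname}_bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}")
# -> steel_gather_mm_rhs_nt_float16_float16_bm16_bn64_bk16_wm1_wn2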

View File

@ -142,6 +142,42 @@ struct BaseMMAFrag<T, 8, 8> {
     }
   }
 
+  template <
+      typename DstPtrType,
+      typename StrX,
+      typename StrY,
+      typename StartX,
+      typename StopX,
+      typename StartY,
+      typename StopY,
+      typename OffX,
+      typename OffY>
+  METAL_FUNC static constexpr void store_slice(
+      const thread frag_type& src,
+      DstPtrType dst,
+      StrX str_x,
+      StrY str_y,
+      StartX start_x,
+      StopX stop_x,
+      StartY start_y,
+      StopY stop_y,
+      OffX off_x = Int<0>{},
+      OffY off_y = Int<0>{}) {
+    using U = pointer_element_t<DstPtrType>;
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kElemRows; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kElemCols; j++) {
+        if ((off_x + i) < stop_x && (off_x + i) >= start_x &&
+            (off_y + j) < stop_y && (off_y + j) >= start_y) {
+          dst[(off_x + i) * str_x + (off_y + j) * str_y] =
+              static_cast<U>(src[i * kElemCols + j]);
+        }
+      }
+    }
+  }
+
   METAL_FUNC static constexpr void mma(
       thread frag_type& D,
       thread frag_type& A,
@ -335,6 +371,31 @@ struct MMATile {
       }
     }
   }
 
+  template <typename U, int w_x, int w_y>
+  METAL_FUNC void store_slice(
+      device U* dst,
+      const int ld,
+      const short2 start,
+      const short2 stop) const {
+    STEEL_PRAGMA_UNROLL
+    for (int i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (int j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::store_slice(
+            frag_at(i, j),
+            dst,
+            ld,
+            Int<1>{},
+            start.y,
+            stop.y,
+            start.x,
+            stop.x,
+            (i * kFragRows) * w_x,
+            (j * kFragCols) * w_y);
+      }
+    }
+  }
 };
 
 template <typename T, typename U, int M, int N, int K>
@ -474,6 +535,26 @@ struct BlockMMA {
     Ctile.template store<U, WM, WN>(D, ldd);
   }
 
+  METAL_FUNC void
+  store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
+    // Apply epilogue
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
+      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
+    }
+
+    D += sm * ldd + sn;
+    start -= short2(sn, sm);
+    stop -= short2(sn, sm);
+
+    // TODO: Check the start as well
+    if (stop.y <= 0 || stop.x <= 0) {
+      return;
+    }
+
+    Ctile.template store_slice<U, WM, WN>(D, ldd, start, stop);
+  }
+
   METAL_FUNC void
   store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
     // Apply epilogue

View File

@ -5,6 +5,7 @@
 #include <numeric>
 #include <sstream>
 
+#include "mlx/backend/common/broadcasting.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
@ -102,6 +103,47 @@ std::tuple<bool, int64_t, array> check_transpose(
   }
 };
 
+inline array
+ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
+  if (!x.flags().row_contiguous) {
+    array x_copy(x.shape(), x.dtype(), nullptr, {});
+    copy_gpu(x, x_copy, CopyType::General, s);
+    d.add_temporary(x_copy, s.index);
+    return x_copy;
+  } else {
+    return x;
+  }
+}
+
+inline std::tuple<bool, int64_t, array>
+ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
+  if (x.flags().row_contiguous) {
+    return std::make_tuple(false, x.strides()[x.ndim() - 2], x);
+  }
+
+  bool rc = true;
+  for (int i = 0; i < x.ndim() - 3; i++) {
+    rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i];
+  }
+  if (rc) {
+    auto stx = x.strides()[x.ndim() - 2];
+    auto sty = x.strides()[x.ndim() - 1];
+    auto K = x.shape(-2);
+    auto N = x.shape(-1);
+    if (sty == 1 && (N != 1 || stx == N)) {
+      return std::make_tuple(false, stx, x);
+    }
+    if (stx == 1 && (N != 1 || sty == K)) {
+      return std::make_tuple(true, sty, x);
+    }
+  }
+
+  array x_copy(x.shape(), x.dtype(), nullptr, {});
+  copy_gpu(x, x_copy, CopyType::General, s);
+  d.add_temporary(x_copy, s.index);
+  return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
+}
+
 } // namespace
 
 ///////////////////////////////////////////////////////////////////////////////
@ -230,7 +272,6 @@ void steel_matmul_regular(
   const bool align_M = (M % bm) == 0;
   const bool align_N = (N % bn) == 0;
   const bool align_K = (K % bk) == 0;
-  const bool do_gather = false;
 
   metal::MTLFCList func_consts = {
       {&has_batch, MTL::DataType::DataTypeBool, 10},
@ -239,7 +280,6 @@ void steel_matmul_regular(
       {&align_M, MTL::DataType::DataTypeBool, 200},
       {&align_N, MTL::DataType::DataTypeBool, 201},
       {&align_K, MTL::DataType::DataTypeBool, 202},
-      {&do_gather, MTL::DataType::DataTypeBool, 300},
   };
 
   // clang-format off
@ -248,8 +288,7 @@ void steel_matmul_regular(
<< "_do_axpby_" << (do_axpby ? 't' : 'n') << "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n') << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str(); std::string hash_name = kname.str();
@ -975,7 +1014,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   const bool align_M = (M % bm) == 0;
   const bool align_N = (N % bn) == 0;
   const bool align_K = (K % bk) == 0;
-  const bool do_gather = false;
 
   metal::MTLFCList func_consts = {
       {&has_batch, MTL::DataType::DataTypeBool, 10},
@ -984,7 +1022,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       {&align_M, MTL::DataType::DataTypeBool, 200},
       {&align_N, MTL::DataType::DataTypeBool, 201},
       {&align_K, MTL::DataType::DataTypeBool, 202},
-      {&do_gather, MTL::DataType::DataTypeBool, 300},
   };
 
   // clang-format off
@ -993,8 +1030,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
<< "_do_axpby_" << (do_axpby ? 't' : 'n') << "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n') << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str(); std::string hash_name = kname.str();
@ -1464,267 +1500,337 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) { void gather_mm_rhs(
using namespace mlx::steel; const array& a_,
// assert(inputs.size() == 2); const array& b_,
if (!issubdtype(out.dtype(), floating)) { const array& indices_,
throw std::runtime_error( array& out,
"[GatherMM] Does not yet support non-floating point types."); metal::Device& d,
} const Stream& s) {
auto& s = stream(); array indices = ensure_row_contiguous(indices_, d, s);
auto& d = metal::device(s.device); auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);
auto& a_pre = inputs[0]; // Broadcast a with indices. If we are here that means lhs_indices were not
auto& b_pre = inputs[1]; // provided so the lhs_indices are implied to be the shape of a broadcasted
// Return 0s if either input is empty // with rhs_indices. We need only broadcast a and copy it as if applying the
if (a_pre.size() == 0 || b_pre.size() == 0) { // lhs_indices.
array zero = array(0, a_pre.dtype()); auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
fill_gpu(zero, out, s); if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
d.add_temporary(std::move(zero), s.index); return ensure_row_contiguous(x, d, s);
return; }
}
out.set_data(allocator::malloc(out.nbytes())); auto x_shape = indices.shape();
x_shape.push_back(x.shape(-2));
x_shape.push_back(x.shape(-1));
array new_x(std::move(x_shape), x.dtype(), nullptr, {});
broadcast(x, new_x);
return ensure_row_contiguous(new_x, d, s);
};
array a = broadcast_with_indices(a_);
///////////////////////////////////////////////////////////////////////////// // Extract the matmul shapes
// Init checks and prep int K = a.shape(-1);
int M = a.size() / K;
int N = b.shape(-1);
int lda = a.strides()[a.ndim() - 2]; // should be K
int M = a_pre.shape(-2); // Define the dispatch blocks
int N = b_pre.shape(-1); int bm = 16, bn = 64, bk = 16;
int K = a_pre.shape(-1); int wm = 1, wn = 2;
// Keep a vector with copies to be cleared in the completed buffer to release const bool align_M = (M % bm) == 0;
// the arrays const bool align_N = (N % bn) == 0;
std::vector<array> copies; const bool align_K = (K % bk) == 0;
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
int lda = a_cols; // Define the kernel name
int ldb = b_cols; std::string base_name;
base_name.reserve(64);
concatenate(
base_name,
"steel_gather_mm_rhs_n",
transpose_b ? 't' : 'n',
'_',
type_to_name(a),
'_',
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
///////////////////////////////////////////////////////////////////////////// metal::MTLFCList func_consts = {
// Check and collapse batch dimensions {&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
auto get_batch_dims = [](const auto& v) { {&align_K, MTL::DataType::DataTypeBool, 202},
return decltype(v){v.begin(), v.end() - 2};
}; };
auto& lhs_indices = inputs[2]; // And the kernel hash that includes the function constants
auto& rhs_indices = inputs[3]; std::string hash_name;
hash_name.reserve(128);
concatenate(
hash_name,
base_name,
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n',
"_align_K_",
align_K ? 't' : 'n');
Shape batch_shape = get_batch_dims(out.shape()); // Get and set the kernel
Strides batch_strides; auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_gather_kernel(
d,
base_name,
hash_name,
func_consts,
out,
false,
transpose_b,
bm,
bn,
bk,
wm,
wn,
true);
compute_encoder.set_compute_pipeline_state(kernel);
batch_strides.insert( // Prepare the matmul params
batch_strides.end(), auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
lhs_indices.strides().begin(), steel::GEMMParams params{
lhs_indices.strides().end()); /* const int M = */ M,
auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back(); /* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ static_cast<int>(ldb),
/* const int ldd = */ N,
/* const int tiles_n = */ (N + bn - 1) / bn,
/* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */ 0,
/* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
/* const int64_t batch_stride_d = */ 0,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ 0};
batch_strides.insert( // Prepare the grid
batch_strides.end(), MTL::Size group_dims = MTL::Size(32, wn, wm);
rhs_indices.strides().begin(), MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);
rhs_indices.strides().end());
auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
int batch_ndim = batch_shape.size(); // Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(indices, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4);
if (batch_ndim == 0) { compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
batch_shape = {1}; }
batch_strides = {0};
}
int batch_ndim_A = a.ndim() - 2; void gather_mv(
int batch_ndim_B = b.ndim() - 2; const array& mat_,
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B}; const array& vec_,
const array& mat_indices_,
const array& vec_indices_,
array& out,
int N,
int K,
bool is_mv,
metal::Device& d,
const Stream& s) {
// Copy if needed
std::vector<array> copies;
auto [transpose_mat, mat_cols, mat] =
check_transpose(copies, s, mat_, N == 1);
auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true);
d.add_temporaries(std::move(copies), s.index);
Shape batch_shape_A = get_batch_dims(a.shape()); // If we are doing vector matrix instead of matrix vector we need to flip the
Strides batch_strides_A = get_batch_dims(a.strides()); // matrix transposition. Basically m @ v = v @ m.T assuming that v is treated
Shape batch_shape_B = get_batch_dims(b.shape()); // as a one dimensional array.
Strides batch_strides_B = get_batch_dims(b.strides()); transpose_mat = (!is_mv) ^ transpose_mat;
if (batch_ndim_A == 0) { // Define some shapes
batch_shape_A = {1}; int in_vector_len = K;
batch_strides_A = {0}; int out_vector_len = N;
} int mat_ld = mat_cols;
if (batch_ndim_B == 0) { int batch_size_out = out.size() / N;
batch_shape_B = {1}; int batch_ndim = out.ndim() - 2;
batch_strides_B = {0}; int batch_ndim_mat = mat.ndim() - 2;
} int batch_ndim_vec = vec.ndim() - 2;
Strides index_strides = vec_indices_.strides();
index_strides.insert(
index_strides.end(),
mat_indices_.strides().begin(),
mat_indices_.strides().end());
auto matrix_stride_out = static_cast<int64_t>(M) * N; // Determine dispatch kernel
auto batch_size_out = out.size() / matrix_stride_out; int tm = 4, tn = 4;
int sm = 1, sn = 32;
///////////////////////////////////////////////////////////////////////////// int bm = 1, bn = 1;
// Gemv specialization int n_out_per_tgp;
std::ostringstream kname;
// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;
auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;
auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A;
auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B;
auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A;
auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B;
if (!is_b_matrix) {
batch_strides = rhs_indices.strides();
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
}
int batch_ndim = batch_shape.size();
// Determine dispatch kernel
int tm = 4, tn = 4;
int sm = 1, sn = 32;
int bm = 1, bn = 1;
int n_out_per_tgp;
std::ostringstream kname;
if (transpose_mat) {
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
sm = 4;
sn = 8;
} else {
sm = 8;
sn = 4;
}
if (out_vector_len >= 2048) {
bn = 16;
} else if (out_vector_len >= 512) {
bn = 4;
} else {
bn = 2;
}
// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;
n_out_per_tgp = bn * sn * tn;
kname << "gemv_t_gather_" << type_to_name(out);
if (transpose_mat) {
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
sm = 4;
sn = 8;
} else { } else {
bm = out_vector_len >= 4096 ? 8 : 4; sm = 8;
sn = 32; sn = 4;
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
n_out_per_tgp = bm * sm * tm;
kname << "gemv_gather_" << type_to_name(out);
} }
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" if (out_vector_len >= 2048) {
<< tm << "_tn" << tn; bn = 16;
} else if (out_vector_len >= 512) {
bn = 4;
} else {
bn = 2;
}
// Encode and dispatch kernel // Specialized kernel for very small outputs
auto& compute_encoder = d.get_command_encoder(s.index); tn = out_vector_len < tn ? 1 : tn;
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; n_out_per_tgp = bn * sn * tn;
MTL::Size group_dims = MTL::Size(32, bn, bm); kname << "gemv_t_gather_" << type_to_name(out);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
compute_encoder.set_input_array(mat, 0); } else {
compute_encoder.set_input_array(vec, 1); bm = out_vector_len >= 4096 ? 8 : 4;
compute_encoder.set_output_array(out, 3); sn = 32;
compute_encoder.set_bytes(in_vector_len, 4); // Specialized kernel for very small outputs
compute_encoder.set_bytes(out_vector_len, 5); tm = out_vector_len < tm ? 1 : tm;
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder.set_bytes(batch_ndim, 9); n_out_per_tgp = bm * sm * tm;
compute_encoder.set_vector_bytes(batch_shape, 10); kname << "gemv_gather_" << type_to_name(out);
compute_encoder.set_vector_bytes(batch_strides, 11);
int batch_ndim_vec = batch_shape_vec.size();
compute_encoder.set_bytes(batch_ndim_vec, 12);
compute_encoder.set_vector_bytes(batch_shape_vec, 13);
compute_encoder.set_vector_bytes(batch_strides_vec, 14);
int batch_ndim_mat = batch_shape_mat.size();
compute_encoder.set_bytes(batch_ndim_mat, 15);
compute_encoder.set_vector_bytes(batch_shape_mat, 16);
compute_encoder.set_vector_bytes(batch_strides_mat, 17);
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
return;
} }
///////////////////////////////////////////////////////////////////////////// kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
// Regular kernel dispatch << tm << "_tn" << tn;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
compute_encoder.set_input_array(mat, 0);
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(out.shape(), 10);
compute_encoder.set_vector_bytes(index_strides, 11);
compute_encoder.set_bytes(batch_ndim_vec, 12);
compute_encoder.set_vector_bytes(vec.shape(), 13);
compute_encoder.set_vector_bytes(vec.strides(), 14);
compute_encoder.set_bytes(batch_ndim_mat, 15);
compute_encoder.set_vector_bytes(mat.shape(), 16);
compute_encoder.set_vector_bytes(mat.strides(), 17);
compute_encoder.set_input_array(vec_indices_, 18);
compute_encoder.set_input_array(mat_indices_, 19);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void gather_mm(
const array& a_,
const array& b_,
const array& lhs_indices,
const array& rhs_indices,
array& out,
int M,
int N,
int K,
metal::Device& d,
const Stream& s) {
// Copy if needed
std::vector<array> copies;
auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
d.add_temporaries(std::move(copies), s.index);
// Determine dispatch kernel // Determine dispatch kernel
int bm = 64, bn = 64, bk = 16; int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2; int wm = 2, wn = 2;
size_t batch_size_out = out.size() / M / N;
int batch_ndim = out.ndim() - 2;
int batch_ndim_a = a.ndim() - 2;
int batch_ndim_b = b.ndim() - 2;
char devc = d.get_architecture().back(); char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc) GEMM_TPARAM_MACRO(devc)
// Prepare kernel name
std::ostringstream kname;
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn;
std::string base_name = kname.str();
const bool has_batch = batch_ndim > 1; const bool has_batch = batch_ndim > 1;
const bool use_out_source = false;
const bool do_axpby = false;
const bool align_M = (M % bm) == 0; const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0; const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0; const bool align_K = (K % bk) == 0;
const bool do_gather = true;
// Define the kernel name
std::string base_name;
base_name.reserve(128);
concatenate(
base_name,
"steel_gather_mm_",
transpose_a ? 't' : 'n',
transpose_b ? 't' : 'n',
"_",
type_to_name(a),
"_",
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
metal::MTLFCList func_consts = { metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10}, {&has_batch, MTL::DataType::DataTypeBool, 10},
{&use_out_source, MTL::DataType::DataTypeBool, 100},
{&do_axpby, MTL::DataType::DataTypeBool, 110},
{&align_M, MTL::DataType::DataTypeBool, 200}, {&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201}, {&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202}, {&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
}; };
// clang-format off // And the kernel hash that includes the function constants
kname << "_has_batch_" << (has_batch ? 't' : 'n') std::string hash_name;
<< "_use_out_source_" << (use_out_source ? 't' : 'n') hash_name.reserve(128);
<< "_do_axpby_" << (do_axpby ? 't' : 'n') concatenate(
<< "_align_M_" << (align_M ? 't' : 'n') hash_name,
<< "_align_N_" << (align_N ? 't' : 'n') base_name,
<< "_align_K_" << (align_K ? 't' : 'n') "_has_batch_",
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on has_batch ? 't' : 'n',
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n',
"_align_K_",
align_K ? 't' : 'n');
std::string hash_name = kname.str(); // Get and set the kernel
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_fused_kernel( auto kernel = get_steel_gemm_gather_kernel(
d, d,
base_name, base_name,
hash_name, hash_name,
@ -1736,72 +1842,97 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
bn, bn,
bk, bk,
wm, wm,
wn); wn,
false);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Use problem size to determine threadblock swizzle // Prepare the matmul params
int tn = (N + bn - 1) / bn; steel::GEMMParams params{
int tm = (M + bm - 1) / bm;
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
// Prepare steel matmul params
GEMMParams params{
/* const int M = */ M, /* const int M = */ M,
/* const int N = */ N, /* const int N = */ N,
/* const int K = */ K, /* const int K = */ K,
/* const int lda = */ lda, /* const int lda = */ static_cast<int>(lda),
/* const int ldb = */ ldb, /* const int ldb = */ static_cast<int>(ldb),
/* const int ldd = */ N, /* const int ldd = */ N,
/* const int tiles_n = */ tn, /* const int tiles_n = */ (N + bn - 1) / bn,
/* const int tiles_m = */ tm, /* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */ lhs_indices_str, /* const int64_t batch_stride_a = */
/* const int64_t batch_stride_b = */ rhs_indices_str, (batch_ndim > 0) ? lhs_indices.strides()[0] : 0,
/* const int64_t batch_stride_d = */ matrix_stride_out, /* const int64_t batch_stride_b = */
/* const int swizzle_log = */ swizzle_log, (batch_ndim > 0) ? rhs_indices.strides()[0] : 0,
/* const int64_t batch_stride_d = */ M * N,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ (K / bk), /* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ batch_ndim}; /* const int batch_ndim = */ batch_ndim};
// Prepare launch grid params // Prepare the grid
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); MTL::Size grid_dims =
MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
// Launch kernel // Launch kernel
compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3); compute_encoder.set_input_array(lhs_indices, 2);
compute_encoder.set_input_array(rhs_indices, 3);
compute_encoder.set_bytes(params, 4); compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(params, 5);
compute_encoder.set_vector_bytes(batch_shape, 6); compute_encoder.set_vector_bytes(lhs_indices.shape(), 6);
compute_encoder.set_vector_bytes(batch_strides, 7); compute_encoder.set_vector_bytes(lhs_indices.strides(), 7);
compute_encoder.set_vector_bytes(rhs_indices.strides(), 8);
compute_encoder.set_input_array(lhs_indices, 10); compute_encoder.set_bytes(batch_ndim_a, 9);
compute_encoder.set_input_array(rhs_indices, 11); compute_encoder.set_vector_bytes(a.shape(), 10);
compute_encoder.set_vector_bytes(a.strides(), 11);
std::vector operand_shape = batch_shape_A; compute_encoder.set_bytes(batch_ndim_b, 12);
operand_shape.insert( compute_encoder.set_vector_bytes(b.shape(), 13);
operand_shape.end(), batch_shape_B.begin(), batch_shape_B.end()); compute_encoder.set_vector_bytes(b.strides(), 14);
std::vector operand_strides = batch_strides_A;
operand_strides.insert(
operand_strides.end(), batch_strides_B.begin(), batch_strides_B.end());
operand_batch_ndim.push_back(0);
compute_encoder.set_vector_bytes(operand_shape, 13);
compute_encoder.set_vector_bytes(operand_strides, 14);
compute_encoder.set_vector_bytes(operand_batch_ndim, 15);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index); void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
auto& a = inputs[0];
auto& b = inputs[1];
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
// Return 0s if either input is empty
if (a.size() == 0 || b.size() == 0) {
array zero = array(0, a.dtype());
fill_gpu(zero, out, s);
d.add_temporary(std::move(zero), s.index);
return;
}
out.set_data(allocator::malloc(out.nbytes()));
// Extract shapes strides from inputs and copy in case of non-contiguous
// vectors.
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
// We are walking a in order and b is also in order so we can batch up the
// matmuls and reuse reading a and b.
if (M == 1 && right_sorted_ == true) {
gather_mm_rhs(a, b, rhs_indices, out, d, s);
return;
}
// Route to gather gemv if any of a or b are vectors
if (M == 1) {
gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s);
return;
}
if (N == 1) {
gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s);
return;
}
// Route to non specialized gather mm
gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
} }
} // namespace mlx::core } // namespace mlx::core
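For orientation, the routing in the rewritten GatherMM::eval_gpu reduces to roughly the following sketch (plain Python, not the actual C++; the returned names refer to the helper functions defined in this file):

def dispatch_gather_mm(M, N, right_sorted):
    # Sorted rhs indices with single-row A blocks hit the new specialized kernel.
    if M == 1 and right_sorted:
        return "gather_mm_rhs"
    # Vector-sized problems go to the gather gemv kernels.
    if M == 1 or N == 1:
        return "gather_mv"
    # Everything else uses the general steel_gather_mm kernel.
    return "gather_mm"

print(dispatch_gather_mm(1, 1024, True))   # gather_mm_rhs
print(dispatch_gather_mm(1, 1024, False))  # gather_mv
print(dispatch_gather_mm(64, 64, False))   # gather_mm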

View File

@ -193,6 +193,23 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
   return d.get_kernel(kernel_name);
 }
 
+MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const array&,
+    bool,
+    bool,
+    int,
+    int,
+    int,
+    int,
+    int,
+    bool) {
+  return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
+}
+
 MTL::ComputePipelineState* get_gemv_masked_kernel(
     metal::Device& d,
     const std::string& kernel_name,

View File

@ -2,6 +2,8 @@
 #pragma once
 
+#include <type_traits>
+
 #include "mlx/array.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/primitives.h"
@ -58,14 +60,27 @@ inline void debug_set_primitive_buffer_label(
 std::string get_primitive_string(Primitive* primitive);
 
+template <typename T>
+constexpr bool is_numeric_except_char = std::is_arithmetic_v<T> &&
+    !std::is_same_v<T, char> && !std::is_same_v<T, signed char> &&
+    !std::is_same_v<T, unsigned char> && !std::is_same_v<T, wchar_t>;
+
 template <typename T>
 void concatenate(std::string& acc, T first) {
-  acc += first;
+  if constexpr (is_numeric_except_char<T>) {
+    acc += std::to_string(first);
+  } else {
+    acc += first;
+  }
 }
 
 template <typename T, typename... Args>
 void concatenate(std::string& acc, T first, Args... args) {
-  acc += first;
+  if constexpr (is_numeric_except_char<T>) {
+    acc += std::to_string(first);
+  } else {
+    acc += first;
+  }
   concatenate(acc, args...);
 }
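This change lets the kernel-name builders pass integers (tile sizes, buffer indices, and so on) straight to concatenate. A Python analogue of the new behavior (illustrative only):

def concatenate_name(*parts):
    # Numbers are stringified, everything else is appended as-is,
    # mirroring the is_numeric_except_char branch above.
    return "".join(str(p) if isinstance(p, (int, float)) else p for p in parts)

print(concatenate_name("steel_gather_mm_rhs_nt_float16_float16_bm", 16, "_bn", 64))
# -> steel_gather_mm_rhs_nt_float16_float16_bm16_bn64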

View File

@ -4499,6 +4499,7 @@ array gather_mm(
     array b,
     std::optional<array> lhs_indices_ /* = std::nullopt */,
     std::optional<array> rhs_indices_ /* = std::nullopt */,
+    bool sorted_indices /* = false */,
     StreamOrDevice s /* = {} */) {
   // If no indices, fall back to full matmul
   if (!lhs_indices_ && !rhs_indices_) {
@ -4574,12 +4575,18 @@ array gather_mm(
   out_shape.push_back(M);
   out_shape.push_back(N);
 
-  // Caculate array
+  // Make the output array
   auto out = array(
       std::move(out_shape),
       out_type,
-      std::make_shared<GatherMM>(to_stream(s)),
-      {a, b, lhs_indices, rhs_indices});
+      std::make_shared<GatherMM>(
+          to_stream(s),
+          sorted_indices && !rhs_indices_,
+          sorted_indices && !lhs_indices_),
+      {std::move(a),
+       std::move(b),
+       std::move(lhs_indices),
+       std::move(rhs_indices)});
 
   // Remove the possibly inserted singleton dimensions
   std::vector<int> axes;

View File

@ -1399,6 +1399,7 @@ array gather_mm(
     array b,
     std::optional<array> lhs_indices = std::nullopt,
     std::optional<array> rhs_indices = std::nullopt,
+    bool sorted_indices = false,
     StreamOrDevice s = {});
 
 /** Extract a diagonal or construct a diagonal array */

View File

@ -4895,6 +4895,8 @@ std::vector<array> GatherMM::vjp(
   int N = cotan.shape(-1);
   int K = primals[0].shape(-1);
 
+  bool sorted = left_sorted_ || right_sorted_;
+
   for (auto arg : argnums) {
     if (arg == 0) {
       // M X N * (K X N).T -> M X K
@ -4905,7 +4907,8 @@ std::vector<array> GatherMM::vjp(
       base = reshape(base, {-1, M, K}, stream());
 
       // g : (out_batch_shape) + (M, K)
-      auto g = gather_mm(cotan, bt, std::nullopt, rhs_indices, stream());
+      auto g =
+          gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream());
       g = expand_dims(g, -3, stream());
       auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
@ -4920,7 +4923,8 @@ std::vector<array> GatherMM::vjp(
       base = reshape(base, {-1, K, N}, stream());
 
       // g : (out_batch_shape) + (K, N)
-      auto g = gather_mm(at, cotan, lhs_indices, std::nullopt, stream());
+      auto g =
+          gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
       g = expand_dims(g, -3, stream());
       auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
@ -4933,6 +4937,12 @@ std::vector<array> GatherMM::vjp(
   return vjps;
 }
 
+bool GatherMM::is_equivalent(const Primitive& other) const {
+  const GatherMM& g_other = static_cast<const GatherMM&>(other);
+  return left_sorted_ == g_other.left_sorted_ &&
+      right_sorted_ == g_other.right_sorted_;
+}
+
 bool BlockMaskedMM::is_equivalent(const Primitive& other) const {
   const BlockMaskedMM& a_other = static_cast<const BlockMaskedMM&>(other);
   return (block_size_ == a_other.block_size_);

View File

@ -498,7 +498,13 @@ class BlockMaskedMM : public UnaryPrimitive {
 class GatherMM : public UnaryPrimitive {
  public:
-  explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {}
+  explicit GatherMM(
+      Stream stream,
+      bool left_sorted = false,
+      bool right_sorted = false)
+      : UnaryPrimitive(stream),
+        left_sorted_(left_sorted),
+        right_sorted_(right_sorted) {}
 
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -510,7 +516,14 @@ class GatherMM : public UnaryPrimitive {
       const std::vector<array>& outputs) override;
 
   DEFINE_PRINT(GatherMM)
-  DEFINE_DEFAULT_IS_EQUIVALENT()
+  bool is_equivalent(const Primitive& other) const override;
+  auto state() const {
+    return std::make_pair(left_sorted_, right_sorted_);
+  }
+
+ private:
+  bool left_sorted_;
+  bool right_sorted_;
 };
 
 class BroadcastAxes : public UnaryPrimitive {

View File

@ -4464,9 +4464,10 @@ void init_ops(nb::module_& m) {
"lhs_indices"_a = nb::none(), "lhs_indices"_a = nb::none(),
"rhs_indices"_a = nb::none(), "rhs_indices"_a = nb::none(),
nb::kw_only(), nb::kw_only(),
"sorted_indices"_a = false,
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"), "def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Matrix multiplication with matrix-level gather. Matrix multiplication with matrix-level gather.
@ -4485,11 +4486,16 @@ void init_ops(nb::module_& m) {
         For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices``
         contains indices from the range ``[0, B1 * B2 * ... * BS)``
 
+        If only one index is passed and it is sorted, the ``sorted_indices``
+        flag can be passed for a possible faster implementation.
+
         Args:
             a (array): Input array.
             b (array): Input array.
             lhs_indices (array, optional): Integer indices for ``a``. Default: ``None``
             rhs_indices (array, optional): Integer indices for ``b``. Default: ``None``
+            sorted_indices (bool, optional): May allow a faster implementation
+              if the passed indices are sorted. Default: ``False``.
 
         Returns:
             array: The output array.
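A small usage sketch of the two index arrays described above (shapes illustrative): each index addresses the flattened batch of its operand, and the output batch shape is the broadcast of the two index shapes.

import mlx.core as mx

a = mx.random.normal((2, 3, 4, 8))   # batch (2, 3), matrices of shape (4, 8)
b = mx.random.normal((6, 8, 5))      # batch (6,),   matrices of shape (8, 5)

lhs = mx.array([[0, 5]], dtype=mx.uint32)   # picks a[0, 0] and a[1, 2]
rhs = mx.array([[1, 4]], dtype=mx.uint32)   # picks b[1] and b[4]

out = mx.gather_mm(a, b, lhs, rhs)
print(out.shape)  # (1, 2, 4, 5)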

View File

@ -1108,7 +1108,7 @@ class TestBlas(mlx_tests.MLXTestCase):
         lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2))
         rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2))
 
         M = a.shape[-2]
-        N = b.shape[-2]
+        N = b.shape[-1]
         K = a.shape[-1]
 
         a = a.reshape((-1, M, K))