mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Gather mm new kernel and small refactoring (#2040)
This commit is contained in:
parent
e9e268336b
commit
99eefd2ec0
74
benchmarks/python/gather_mm_bench.py
Normal file
74
benchmarks/python/gather_mm_bench.py
Normal file
@ -0,0 +1,74 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
N = 1024
|
||||
D = 1024
|
||||
M = 1024
|
||||
E = 32
|
||||
I = 4
|
||||
|
||||
|
||||
def gather_sort(x, indices):
|
||||
N, M = indices.shape
|
||||
indices = indices.flatten()
|
||||
order = mx.argsort(indices)
|
||||
inv_order = mx.argsort(order)
|
||||
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||
|
||||
|
||||
def scatter_unsort(x, inv_order, shape=None):
|
||||
x = x[inv_order]
|
||||
if shape is not None:
|
||||
x = mx.unflatten(x, 0, shape)
|
||||
return x
|
||||
|
||||
|
||||
def gather_mm_simulate(x, w, indices):
|
||||
x, idx, inv_order = gather_sort(x, indices)
|
||||
for i in range(2):
|
||||
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
|
||||
x = y[:, None]
|
||||
x = scatter_unsort(x, inv_order, indices.shape)
|
||||
return x
|
||||
|
||||
|
||||
def time_gather_mm():
|
||||
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
|
||||
w1 = mx.random.normal((E, M, D)) / 1024**0.5
|
||||
w2 = mx.random.normal((E, D, M)) / 1024**0.5
|
||||
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
|
||||
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
|
||||
mx.eval(x, w1, w2, indices, sorted_indices)
|
||||
|
||||
def gather_mm(x, w1, w2, indices, sort):
|
||||
idx = indices
|
||||
inv_order = None
|
||||
if sort:
|
||||
x, idx, inv_order = gather_sort(x, indices)
|
||||
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||
if sort:
|
||||
x = scatter_unsort(x, inv_order, indices.shape)
|
||||
return x
|
||||
|
||||
time_fn(gather_mm, x, w1, w2, indices, False)
|
||||
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
|
||||
time_fn(gather_mm, x, w1, w2, indices, True)
|
||||
|
||||
x = mx.random.normal((N * I, D)) / 1024**0.5
|
||||
w1 = mx.random.normal((M, D)) / 1024**0.5
|
||||
w2 = mx.random.normal((D, M)) / 1024**0.5
|
||||
mx.eval(x, w1, w2)
|
||||
|
||||
def equivalent_matmul(x, w1, w2):
|
||||
x = x @ w1.T
|
||||
x = x @ w2.T
|
||||
return x
|
||||
|
||||
time_fn(equivalent_matmul, x, w1, w2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_gather_mm()
|
@ -1,6 +1,7 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
|
24
mlx/backend/common/broadcasting.cpp
Normal file
24
mlx/backend/common/broadcasting.cpp
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void broadcast(const array& in, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
Strides strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
11
mlx/backend/common/broadcasting.h
Normal file
11
mlx/backend/common/broadcasting.h
Normal file
@ -0,0 +1,11 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void broadcast(const array& in, array& out);
|
||||
|
||||
} // namespace mlx::core
|
@ -1,6 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/broadcasting.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@ -42,23 +43,6 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
||||
}
|
||||
|
||||
void broadcast(const array& in, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
Strides strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
broadcast(inputs[0], out);
|
||||
}
|
||||
|
@ -61,6 +61,7 @@ if(MLX_METAL_JIT)
|
||||
kernels/steel/gemm/transforms.h)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_gather)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
|
||||
make_jit_source(
|
||||
steel/conv/conv
|
||||
|
@ -33,6 +33,7 @@ const char* gemm();
|
||||
const char* steel_gemm_fused();
|
||||
const char* steel_gemm_masked();
|
||||
const char* steel_gemm_splitk();
|
||||
const char* steel_gemm_gather();
|
||||
const char* conv();
|
||||
const char* steel_conv();
|
||||
const char* steel_conv_general();
|
||||
|
@ -584,6 +584,44 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
bool rhs) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source;
|
||||
concatenate(
|
||||
kernel_source,
|
||||
metal::utils(),
|
||||
metal::gemm(),
|
||||
metal::steel_gemm_gather(),
|
||||
get_template_definition(
|
||||
lib_name,
|
||||
rhs ? "gather_mm_rhs" : "gather_mm",
|
||||
get_type_string(out.dtype()),
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
transpose_a,
|
||||
transpose_b));
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
|
@ -160,6 +160,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
bool mn_aligned,
|
||||
bool k_aligned);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
bool rhs);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
|
@ -69,6 +69,7 @@ set(STEEL_HEADERS
|
||||
steel/gemm/loader.h
|
||||
steel/gemm/transforms.h
|
||||
steel/gemm/kernels/steel_gemm_fused.h
|
||||
steel/gemm/kernels/steel_gemm_gather.h
|
||||
steel/gemm/kernels/steel_gemm_masked.h
|
||||
steel/gemm/kernels/steel_gemm_splitk.h
|
||||
steel/utils/type_traits.h
|
||||
@ -116,6 +117,7 @@ if(NOT MLX_METAL_JIT)
|
||||
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
|
||||
build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
|
||||
build_kernel(gemv_masked steel/utils.h)
|
||||
|
@ -15,10 +15,6 @@ constant bool align_M [[function_constant(200)]];
|
||||
constant bool align_N [[function_constant(201)]];
|
||||
constant bool align_K [[function_constant(202)]];
|
||||
|
||||
constant bool do_gather [[function_constant(300)]];
|
||||
|
||||
constant bool gather_bias = do_gather && use_out_source;
|
||||
|
||||
// clang-format off
|
||||
template <
|
||||
typename T,
|
||||
@ -39,12 +35,6 @@ template <
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant int64_t* batch_strides [[buffer(7)]],
|
||||
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
|
||||
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
|
||||
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
|
||||
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
|
||||
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
||||
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@ -81,84 +71,26 @@ template <
|
||||
}
|
||||
|
||||
// Adjust for batch
|
||||
if (has_batch) {
|
||||
const constant auto* A_bstrides = batch_strides;
|
||||
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
// Handle gather
|
||||
if (do_gather) {
|
||||
// Read indices
|
||||
uint32_t indx_A, indx_B, indx_C;
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
|
||||
if (has_batch) {
|
||||
const constant auto* indx_A_bstrides = batch_strides;
|
||||
const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 indx_offsets = elem_to_loc_broadcast(
|
||||
tid.z,
|
||||
batch_shape,
|
||||
indx_A_bstrides,
|
||||
indx_B_bstrides,
|
||||
params->batch_ndim);
|
||||
indx_A = lhs_indices[indx_offsets.x];
|
||||
indx_B = rhs_indices[indx_offsets.y];
|
||||
|
||||
if (use_out_source) {
|
||||
const constant auto* indx_C_bstrides =
|
||||
indx_B_bstrides + params->batch_ndim;
|
||||
auto indx_offset_C = elem_to_loc(
|
||||
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
|
||||
indx_C = C_indices[indx_offset_C];
|
||||
}
|
||||
} else {
|
||||
indx_A = lhs_indices[params->batch_stride_a * tid.z];
|
||||
indx_B = rhs_indices[params->batch_stride_b * tid.z];
|
||||
|
||||
if (use_out_source) {
|
||||
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
|
||||
}
|
||||
}
|
||||
|
||||
// Translate indices to offsets
|
||||
int batch_ndim_A = operand_batch_ndim.x;
|
||||
const constant int* batch_shape_A = operand_shape;
|
||||
const constant auto* batch_strides_A = operand_strides;
|
||||
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
|
||||
|
||||
int batch_ndim_B = operand_batch_ndim.y;
|
||||
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
|
||||
const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A;
|
||||
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
|
||||
if (use_out_source) {
|
||||
int batch_ndim_C = operand_batch_ndim.z;
|
||||
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
|
||||
const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B;
|
||||
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
|
||||
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
|
||||
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
|
||||
}
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
|
||||
}
|
||||
|
||||
// Handle regular batch
|
||||
else {
|
||||
if (has_batch) {
|
||||
const constant auto* A_bstrides = batch_strides;
|
||||
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
|
||||
if (use_out_source) {
|
||||
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
|
||||
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
|
||||
}
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
|
||||
if (use_out_source) {
|
||||
C += addmm_params->batch_stride_c * tid.z;
|
||||
}
|
||||
if (use_out_source) {
|
||||
C += addmm_params->batch_stride_c * tid.z;
|
||||
}
|
||||
}
|
||||
|
||||
|
459
mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h
Normal file
459
mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h
Normal file
@ -0,0 +1,459 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
using namespace mlx::steel;
|
||||
|
||||
constant bool has_batch [[function_constant(10)]];
|
||||
constant bool align_M [[function_constant(200)]];
|
||||
constant bool align_N [[function_constant(201)]];
|
||||
constant bool align_K [[function_constant(202)]];
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
typename AccumType = float>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm_rhs(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
const device uint32_t* rhs_indices [[buffer(2)]],
|
||||
device T* C [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
true,
|
||||
true,
|
||||
AccumType>;
|
||||
|
||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||
using mma_t = typename gemm_kernel::mma_t;
|
||||
|
||||
if (params->tiles_n <= static_cast<int>(tid.x) ||
|
||||
params->tiles_m <= static_cast<int>(tid.y)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Prepare threadgroup memory
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Find the block in A, B, C
|
||||
const int c_row = tid.y * BM;
|
||||
const int c_col = tid.x * BN;
|
||||
const size_t c_row_long = size_t(c_row);
|
||||
const size_t c_col_long = size_t(c_col);
|
||||
|
||||
// Prepare threadgroup bounds
|
||||
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
|
||||
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
|
||||
|
||||
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||
C += c_row_long * params->ldd + c_col_long;
|
||||
|
||||
// Do as many matmuls as necessary
|
||||
uint32_t index;
|
||||
short offset;
|
||||
uint32_t index_next = rhs_indices[c_row];
|
||||
short offset_next = 0;
|
||||
int n = 0;
|
||||
while (n < tgp_bm) {
|
||||
n++;
|
||||
offset = offset_next;
|
||||
index = index_next;
|
||||
offset_next = tgp_bm;
|
||||
for (; n < tgp_bm; n++) {
|
||||
if (rhs_indices[c_row + n] != index) {
|
||||
offset_next = n;
|
||||
index_next = rhs_indices[c_row + n];
|
||||
break;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(
|
||||
B + index * params->batch_stride_b,
|
||||
params->ldb,
|
||||
Bs,
|
||||
simd_group_id,
|
||||
simd_lane_id);
|
||||
|
||||
// Prepare iterations
|
||||
const int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
// Do unaligned K iterations first
|
||||
if (!align_K) {
|
||||
const int k_last = params->gemm_k_iterations_aligned * BK;
|
||||
const int k_remain = params->K - k_last;
|
||||
const size_t k_jump_a =
|
||||
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
|
||||
const size_t k_jump_b =
|
||||
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
|
||||
|
||||
// Move loader source ahead to end
|
||||
loader_a.src += k_jump_a;
|
||||
loader_b.src += k_jump_b;
|
||||
|
||||
// Load tile
|
||||
const short2 tile_dims_A =
|
||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||
const short2 tile_dims_B =
|
||||
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Do matmul
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Reset source back to start
|
||||
loader_a.src -= k_jump_a;
|
||||
loader_b.src -= k_jump_b;
|
||||
}
|
||||
|
||||
// Matrix level aligned never check
|
||||
if (align_M && align_N) {
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
if (offset_next - offset == BM) {
|
||||
mma_op.store_result(C, params->ldd);
|
||||
} else {
|
||||
mma_op.store_result_slice(
|
||||
C, params->ldd, short2(0, offset), short2(BN, offset_next));
|
||||
}
|
||||
} else {
|
||||
const short lbk = 0;
|
||||
|
||||
// Tile aligned don't check
|
||||
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
lbk,
|
||||
LoopAlignment<true, true, true>{});
|
||||
if (offset_next - offset == BM) {
|
||||
mma_op.store_result(C, params->ldd);
|
||||
} else {
|
||||
mma_op.store_result_slice(
|
||||
C, params->ldd, short2(0, offset), short2(BN, offset_next));
|
||||
}
|
||||
}
|
||||
|
||||
// Tile partially aligned check rows
|
||||
else if (align_N || tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
lbk,
|
||||
LoopAlignment<false, true, true>{});
|
||||
mma_op.store_result_slice(
|
||||
C, params->ldd, short2(0, offset), short2(BN, offset_next));
|
||||
}
|
||||
|
||||
// Tile partially aligned check cols
|
||||
else if (align_M || tgp_bm == BM) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
lbk,
|
||||
LoopAlignment<true, false, true>{});
|
||||
mma_op.store_result_slice(
|
||||
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
|
||||
}
|
||||
|
||||
// Nothing aligned so check both rows and cols
|
||||
else {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
lbk,
|
||||
LoopAlignment<false, false, true>{});
|
||||
mma_op.store_result_slice(
|
||||
C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
typename AccumType = float>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
const device uint32_t* lhs_indices [[buffer(2)]],
|
||||
const device uint32_t* rhs_indices [[buffer(3)]],
|
||||
device T* C [[buffer(4)]],
|
||||
const constant GEMMParams* params [[buffer(5)]],
|
||||
const constant int* indices_shape [[buffer(6)]],
|
||||
const constant int64_t* lhs_strides [[buffer(7)]],
|
||||
const constant int64_t* rhs_strides [[buffer(8)]],
|
||||
const constant int& batch_ndim_a [[buffer(9)]],
|
||||
const constant int* batch_shape_a [[buffer(10)]],
|
||||
const constant int64_t* batch_strides_a [[buffer(11)]],
|
||||
const constant int& batch_ndim_b [[buffer(12)]],
|
||||
const constant int* batch_shape_b [[buffer(13)]],
|
||||
const constant int64_t* batch_strides_b [[buffer(14)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
true,
|
||||
true,
|
||||
AccumType>;
|
||||
|
||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||
using mma_t = typename gemm_kernel::mma_t;
|
||||
|
||||
if (params->tiles_n <= static_cast<int>(tid.x) ||
|
||||
params->tiles_m <= static_cast<int>(tid.y)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Move A and B to the locations pointed by lhs_indices and rhs_indices.
|
||||
uint32_t indx_A, indx_B;
|
||||
if (has_batch) {
|
||||
ulong2 indices_offsets = elem_to_loc_broadcast(
|
||||
tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim);
|
||||
indx_A = lhs_indices[indices_offsets.x];
|
||||
indx_B = rhs_indices[indices_offsets.y];
|
||||
} else {
|
||||
indx_A = lhs_indices[params->batch_stride_a * tid.z];
|
||||
indx_B = rhs_indices[params->batch_stride_b * tid.z];
|
||||
}
|
||||
A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a);
|
||||
B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b);
|
||||
C += params->batch_stride_d * tid.z;
|
||||
|
||||
// Prepare threadgroup memory
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Just make sure everybody's finished with the indexing math above.
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid.y * BM;
|
||||
const int c_col = tid.x * BN;
|
||||
const size_t c_row_long = size_t(c_row);
|
||||
const size_t c_col_long = size_t(c_col);
|
||||
|
||||
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||
C += c_row_long * params->ldd + c_col_long;
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup bounds
|
||||
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
|
||||
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
|
||||
|
||||
// Prepare iterations
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
// Do unaligned K iterations first
|
||||
if (!align_K) {
|
||||
const int k_last = params->gemm_k_iterations_aligned * BK;
|
||||
const int k_remain = params->K - k_last;
|
||||
const size_t k_jump_a =
|
||||
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
|
||||
const size_t k_jump_b =
|
||||
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
|
||||
|
||||
// Move loader source ahead to end
|
||||
loader_a.src += k_jump_a;
|
||||
loader_b.src += k_jump_b;
|
||||
|
||||
// Load tile
|
||||
const short2 tile_dims_A =
|
||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||
const short2 tile_dims_B =
|
||||
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Do matmul
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Reset source back to start
|
||||
loader_a.src -= k_jump_a;
|
||||
loader_b.src -= k_jump_b;
|
||||
}
|
||||
|
||||
// Matrix level aligned never check
|
||||
if (align_M && align_N) {
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(C, params->ldd);
|
||||
} else {
|
||||
const short lbk = 0;
|
||||
|
||||
// Tile aligned don't check
|
||||
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
lbk,
|
||||
LoopAlignment<true, true, true>{});
|
||||
mma_op.store_result(C, params->ldd);
|
||||
}
|
||||
|
||||
// Tile partially aligned check rows
|
||||
else if (align_N || tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
lbk,
|
||||
LoopAlignment<false, true, true>{});
|
||||
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
|
||||
// Tile partially aligned check cols
|
||||
else if (align_M || tgp_bm == BM) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
lbk,
|
||||
LoopAlignment<true, false, true>{});
|
||||
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
|
||||
// Nothing aligned so check both rows and cols
|
||||
else {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
lbk,
|
||||
LoopAlignment<false, false, true>{});
|
||||
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,59 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h"
|
||||
|
||||
#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_kernel( \
|
||||
"steel_gather_mm_rhs_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
|
||||
"_bk" #bk "_wm" #wm "_wn" #wn, \
|
||||
gather_mm_rhs, \
|
||||
itype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
float)
|
||||
|
||||
#define instantiate_gather_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_kernel( \
|
||||
"steel_gather_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
|
||||
"_bk" #bk "_wm" #wm "_wn" #wn, \
|
||||
gather_mm, \
|
||||
itype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
float)
|
||||
|
||||
#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
|
||||
#define instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gather_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gather_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gather_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gather_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
|
||||
#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 64, 16, 1, 2) \
|
||||
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
|
||||
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
|
||||
instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
|
||||
// clang-format on
|
||||
|
||||
instantiate_gather_mm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gather_mm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
instantiate_gather_mm_shapes_helper(float32, float, float32, float);
|
@ -142,6 +142,42 @@ struct BaseMMAFrag<T, 8, 8> {
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename DstPtrType,
|
||||
typename StrX,
|
||||
typename StrY,
|
||||
typename StartX,
|
||||
typename StopX,
|
||||
typename StartY,
|
||||
typename StopY,
|
||||
typename OffX,
|
||||
typename OffY>
|
||||
METAL_FUNC static constexpr void store_slice(
|
||||
const thread frag_type& src,
|
||||
DstPtrType dst,
|
||||
StrX str_x,
|
||||
StrY str_y,
|
||||
StartX start_x,
|
||||
StopX stop_x,
|
||||
StartY start_y,
|
||||
StopY stop_y,
|
||||
OffX off_x = Int<0>{},
|
||||
OffY off_y = Int<0>{}) {
|
||||
using U = pointer_element_t<DstPtrType>;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
if ((off_x + i) < stop_x && (off_x + i) >= start_x &&
|
||||
(off_y + j) < stop_y && (off_y + j) >= start_y) {
|
||||
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
|
||||
static_cast<U>(src[i * kElemCols + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC static constexpr void mma(
|
||||
thread frag_type& D,
|
||||
thread frag_type& A,
|
||||
@ -335,6 +371,31 @@ struct MMATile {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void store_slice(
|
||||
device U* dst,
|
||||
const int ld,
|
||||
const short2 start,
|
||||
const short2 stop) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::store_slice(
|
||||
frag_at(i, j),
|
||||
dst,
|
||||
ld,
|
||||
Int<1>{},
|
||||
start.y,
|
||||
stop.y,
|
||||
start.x,
|
||||
stop.x,
|
||||
(i * kFragRows) * w_x,
|
||||
(j * kFragCols) * w_y);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, int M, int N, int K>
|
||||
@ -474,6 +535,26 @@ struct BlockMMA {
|
||||
Ctile.template store<U, WM, WN>(D, ldd);
|
||||
}
|
||||
|
||||
METAL_FUNC void
|
||||
store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
|
||||
// Apply epilogue
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
||||
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
|
||||
}
|
||||
|
||||
D += sm * ldd + sn;
|
||||
start -= short2(sn, sm);
|
||||
stop -= short2(sn, sm);
|
||||
|
||||
// TODO: Check the start as well
|
||||
if (stop.y <= 0 || stop.x <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
Ctile.template store_slice<U, WM, WN>(D, ldd, start, stop);
|
||||
}
|
||||
|
||||
METAL_FUNC void
|
||||
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
|
||||
// Apply epilogue
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/broadcasting.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
@ -102,6 +103,47 @@ std::tuple<bool, int64_t, array> check_transpose(
|
||||
}
|
||||
};
|
||||
|
||||
inline array
|
||||
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||
if (!x.flags().row_contiguous) {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return x_copy;
|
||||
} else {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
inline std::tuple<bool, int64_t, array>
|
||||
ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||
if (x.flags().row_contiguous) {
|
||||
return std::make_tuple(false, x.strides()[x.ndim() - 2], x);
|
||||
}
|
||||
|
||||
bool rc = true;
|
||||
for (int i = 0; i < x.ndim() - 3; i++) {
|
||||
rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i];
|
||||
}
|
||||
if (rc) {
|
||||
auto stx = x.strides()[x.ndim() - 2];
|
||||
auto sty = x.strides()[x.ndim() - 1];
|
||||
auto K = x.shape(-2);
|
||||
auto N = x.shape(-1);
|
||||
if (sty == 1 && (N != 1 || stx == N)) {
|
||||
return std::make_tuple(false, stx, x);
|
||||
}
|
||||
if (stx == 1 && (N != 1 || sty == K)) {
|
||||
return std::make_tuple(true, sty, x);
|
||||
}
|
||||
}
|
||||
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -230,7 +272,6 @@ void steel_matmul_regular(
|
||||
const bool align_M = (M % bm) == 0;
|
||||
const bool align_N = (N % bn) == 0;
|
||||
const bool align_K = (K % bk) == 0;
|
||||
const bool do_gather = false;
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
||||
@ -239,7 +280,6 @@ void steel_matmul_regular(
|
||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||
{&do_gather, MTL::DataType::DataTypeBool, 300},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
@ -248,8 +288,7 @@ void steel_matmul_regular(
|
||||
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
||||
<< "_align_M_" << (align_M ? 't' : 'n')
|
||||
<< "_align_N_" << (align_N ? 't' : 'n')
|
||||
<< "_align_K_" << (align_K ? 't' : 'n')
|
||||
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
|
||||
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
@ -975,7 +1014,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
const bool align_M = (M % bm) == 0;
|
||||
const bool align_N = (N % bn) == 0;
|
||||
const bool align_K = (K % bk) == 0;
|
||||
const bool do_gather = false;
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
||||
@ -984,7 +1022,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||
{&do_gather, MTL::DataType::DataTypeBool, 300},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
@ -993,8 +1030,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
||||
<< "_align_M_" << (align_M ? 't' : 'n')
|
||||
<< "_align_N_" << (align_N ? 't' : 'n')
|
||||
<< "_align_K_" << (align_K ? 't' : 'n')
|
||||
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
|
||||
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
@ -1464,267 +1500,337 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
using namespace mlx::steel;
|
||||
// assert(inputs.size() == 2);
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[GatherMM] Does not yet support non-floating point types.");
|
||||
}
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
void gather_mm_rhs(
|
||||
const array& a_,
|
||||
const array& b_,
|
||||
const array& indices_,
|
||||
array& out,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
array indices = ensure_row_contiguous(indices_, d, s);
|
||||
auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
// Return 0s if either input is empty
|
||||
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
||||
array zero = array(0, a_pre.dtype());
|
||||
fill_gpu(zero, out, s);
|
||||
d.add_temporary(std::move(zero), s.index);
|
||||
return;
|
||||
}
|
||||
// Broadcast a with indices. If we are here that means lhs_indices were not
|
||||
// provided so the lhs_indices are implied to be the shape of a broadcasted
|
||||
// with rhs_indices. We need only broadcast a and copy it as if applying the
|
||||
// lhs_indices.
|
||||
auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
|
||||
if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
|
||||
return ensure_row_contiguous(x, d, s);
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto x_shape = indices.shape();
|
||||
x_shape.push_back(x.shape(-2));
|
||||
x_shape.push_back(x.shape(-1));
|
||||
array new_x(std::move(x_shape), x.dtype(), nullptr, {});
|
||||
broadcast(x, new_x);
|
||||
return ensure_row_contiguous(new_x, d, s);
|
||||
};
|
||||
array a = broadcast_with_indices(a_);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
// Extract the matmul shapes
|
||||
int K = a.shape(-1);
|
||||
int M = a.size() / K;
|
||||
int N = b.shape(-1);
|
||||
int lda = a.strides()[a.ndim() - 2]; // should be K
|
||||
|
||||
int M = a_pre.shape(-2);
|
||||
int N = b_pre.shape(-1);
|
||||
int K = a_pre.shape(-1);
|
||||
// Define the dispatch blocks
|
||||
int bm = 16, bn = 64, bk = 16;
|
||||
int wm = 1, wn = 2;
|
||||
|
||||
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||
// the arrays
|
||||
std::vector<array> copies;
|
||||
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
|
||||
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
|
||||
const bool align_M = (M % bm) == 0;
|
||||
const bool align_N = (N % bn) == 0;
|
||||
const bool align_K = (K % bk) == 0;
|
||||
|
||||
int lda = a_cols;
|
||||
int ldb = b_cols;
|
||||
// Define the kernel name
|
||||
std::string base_name;
|
||||
base_name.reserve(64);
|
||||
concatenate(
|
||||
base_name,
|
||||
"steel_gather_mm_rhs_n",
|
||||
transpose_b ? 't' : 'n',
|
||||
'_',
|
||||
type_to_name(a),
|
||||
'_',
|
||||
type_to_name(out),
|
||||
"_bm",
|
||||
bm,
|
||||
"_bn",
|
||||
bn,
|
||||
"_bk",
|
||||
bk,
|
||||
"_wm",
|
||||
wm,
|
||||
"_wn",
|
||||
wn);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
|
||||
auto get_batch_dims = [](const auto& v) {
|
||||
return decltype(v){v.begin(), v.end() - 2};
|
||||
metal::MTLFCList func_consts = {
|
||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||
};
|
||||
|
||||
auto& lhs_indices = inputs[2];
|
||||
auto& rhs_indices = inputs[3];
|
||||
// And the kernel hash that includes the function constants
|
||||
std::string hash_name;
|
||||
hash_name.reserve(128);
|
||||
concatenate(
|
||||
hash_name,
|
||||
base_name,
|
||||
"_align_M_",
|
||||
align_M ? 't' : 'n',
|
||||
"_align_N_",
|
||||
align_N ? 't' : 'n',
|
||||
"_align_K_",
|
||||
align_K ? 't' : 'n');
|
||||
|
||||
Shape batch_shape = get_batch_dims(out.shape());
|
||||
Strides batch_strides;
|
||||
// Get and set the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = get_steel_gemm_gather_kernel(
|
||||
d,
|
||||
base_name,
|
||||
hash_name,
|
||||
func_consts,
|
||||
out,
|
||||
false,
|
||||
transpose_b,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
true);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
batch_strides.insert(
|
||||
batch_strides.end(),
|
||||
lhs_indices.strides().begin(),
|
||||
lhs_indices.strides().end());
|
||||
auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
|
||||
// Prepare the matmul params
|
||||
auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
|
||||
steel::GEMMParams params{
|
||||
/* const int M = */ M,
|
||||
/* const int N = */ N,
|
||||
/* const int K = */ K,
|
||||
/* const int lda = */ lda,
|
||||
/* const int ldb = */ static_cast<int>(ldb),
|
||||
/* const int ldd = */ N,
|
||||
/* const int tiles_n = */ (N + bn - 1) / bn,
|
||||
/* const int tiles_m = */ (M + bm - 1) / bm,
|
||||
/* const int64_t batch_stride_a = */ 0,
|
||||
/* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
|
||||
/* const int64_t batch_stride_d = */ 0,
|
||||
/* const int swizzle_log = */ 0,
|
||||
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
||||
/* const int batch_ndim = */ 0};
|
||||
|
||||
batch_strides.insert(
|
||||
batch_strides.end(),
|
||||
rhs_indices.strides().begin(),
|
||||
rhs_indices.strides().end());
|
||||
auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
|
||||
// Prepare the grid
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);
|
||||
|
||||
int batch_ndim = batch_shape.size();
|
||||
// Launch kernel
|
||||
compute_encoder.set_input_array(a, 0);
|
||||
compute_encoder.set_input_array(b, 1);
|
||||
compute_encoder.set_input_array(indices, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
compute_encoder.set_bytes(params, 4);
|
||||
|
||||
if (batch_ndim == 0) {
|
||||
batch_shape = {1};
|
||||
batch_strides = {0};
|
||||
}
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
int batch_ndim_A = a.ndim() - 2;
|
||||
int batch_ndim_B = b.ndim() - 2;
|
||||
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B};
|
||||
void gather_mv(
|
||||
const array& mat_,
|
||||
const array& vec_,
|
||||
const array& mat_indices_,
|
||||
const array& vec_indices_,
|
||||
array& out,
|
||||
int N,
|
||||
int K,
|
||||
bool is_mv,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Copy if needed
|
||||
std::vector<array> copies;
|
||||
auto [transpose_mat, mat_cols, mat] =
|
||||
check_transpose(copies, s, mat_, N == 1);
|
||||
auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true);
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
|
||||
Shape batch_shape_A = get_batch_dims(a.shape());
|
||||
Strides batch_strides_A = get_batch_dims(a.strides());
|
||||
Shape batch_shape_B = get_batch_dims(b.shape());
|
||||
Strides batch_strides_B = get_batch_dims(b.strides());
|
||||
// If we are doing vector matrix instead of matrix vector we need to flip the
|
||||
// matrix transposition. Basically m @ v = v @ m.T assuming that v is treated
|
||||
// as a one dimensional array.
|
||||
transpose_mat = (!is_mv) ^ transpose_mat;
|
||||
|
||||
if (batch_ndim_A == 0) {
|
||||
batch_shape_A = {1};
|
||||
batch_strides_A = {0};
|
||||
}
|
||||
// Define some shapes
|
||||
int in_vector_len = K;
|
||||
int out_vector_len = N;
|
||||
int mat_ld = mat_cols;
|
||||
|
||||
if (batch_ndim_B == 0) {
|
||||
batch_shape_B = {1};
|
||||
batch_strides_B = {0};
|
||||
}
|
||||
int batch_size_out = out.size() / N;
|
||||
int batch_ndim = out.ndim() - 2;
|
||||
int batch_ndim_mat = mat.ndim() - 2;
|
||||
int batch_ndim_vec = vec.ndim() - 2;
|
||||
Strides index_strides = vec_indices_.strides();
|
||||
index_strides.insert(
|
||||
index_strides.end(),
|
||||
mat_indices_.strides().begin(),
|
||||
mat_indices_.strides().end());
|
||||
|
||||
auto matrix_stride_out = static_cast<int64_t>(M) * N;
|
||||
auto batch_size_out = out.size() / matrix_stride_out;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Gemv specialization
|
||||
|
||||
// Route to gemv if needed
|
||||
if (std::min(M, N) == 1) {
|
||||
// Collect problem info
|
||||
bool is_b_matrix = N != 1;
|
||||
|
||||
auto& mat = is_b_matrix ? b : a;
|
||||
auto& vec = is_b_matrix ? a : b;
|
||||
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
|
||||
int in_vector_len = K;
|
||||
int out_vector_len = is_b_matrix ? N : M;
|
||||
|
||||
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
|
||||
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
|
||||
int mat_ld = is_b_matrix ? b_cols : a_cols;
|
||||
|
||||
auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A;
|
||||
auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B;
|
||||
|
||||
auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A;
|
||||
auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B;
|
||||
|
||||
if (!is_b_matrix) {
|
||||
batch_strides = rhs_indices.strides();
|
||||
batch_strides.insert(
|
||||
batch_strides.end(),
|
||||
lhs_indices.strides().begin(),
|
||||
lhs_indices.strides().end());
|
||||
}
|
||||
|
||||
int batch_ndim = batch_shape.size();
|
||||
|
||||
// Determine dispatch kernel
|
||||
int tm = 4, tn = 4;
|
||||
int sm = 1, sn = 32;
|
||||
int bm = 1, bn = 1;
|
||||
int n_out_per_tgp;
|
||||
std::ostringstream kname;
|
||||
|
||||
if (transpose_mat) {
|
||||
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
|
||||
sm = 4;
|
||||
sn = 8;
|
||||
} else {
|
||||
sm = 8;
|
||||
sn = 4;
|
||||
}
|
||||
|
||||
if (out_vector_len >= 2048) {
|
||||
bn = 16;
|
||||
} else if (out_vector_len >= 512) {
|
||||
bn = 4;
|
||||
} else {
|
||||
bn = 2;
|
||||
}
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tn = out_vector_len < tn ? 1 : tn;
|
||||
|
||||
n_out_per_tgp = bn * sn * tn;
|
||||
kname << "gemv_t_gather_" << type_to_name(out);
|
||||
// Determine dispatch kernel
|
||||
int tm = 4, tn = 4;
|
||||
int sm = 1, sn = 32;
|
||||
int bm = 1, bn = 1;
|
||||
int n_out_per_tgp;
|
||||
std::ostringstream kname;
|
||||
|
||||
if (transpose_mat) {
|
||||
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
|
||||
sm = 4;
|
||||
sn = 8;
|
||||
} else {
|
||||
bm = out_vector_len >= 4096 ? 8 : 4;
|
||||
sn = 32;
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tm = out_vector_len < tm ? 1 : tm;
|
||||
|
||||
n_out_per_tgp = bm * sm * tm;
|
||||
kname << "gemv_gather_" << type_to_name(out);
|
||||
sm = 8;
|
||||
sn = 4;
|
||||
}
|
||||
|
||||
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
|
||||
<< tm << "_tn" << tn;
|
||||
if (out_vector_len >= 2048) {
|
||||
bn = 16;
|
||||
} else if (out_vector_len >= 512) {
|
||||
bn = 4;
|
||||
} else {
|
||||
bn = 2;
|
||||
}
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
// Specialized kernel for very small outputs
|
||||
tn = out_vector_len < tn ? 1 : tn;
|
||||
|
||||
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
||||
MTL::Size group_dims = MTL::Size(32, bn, bm);
|
||||
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
||||
n_out_per_tgp = bn * sn * tn;
|
||||
kname << "gemv_t_gather_" << type_to_name(out);
|
||||
|
||||
compute_encoder.set_input_array(mat, 0);
|
||||
compute_encoder.set_input_array(vec, 1);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
} else {
|
||||
bm = out_vector_len >= 4096 ? 8 : 4;
|
||||
sn = 32;
|
||||
|
||||
compute_encoder.set_bytes(in_vector_len, 4);
|
||||
compute_encoder.set_bytes(out_vector_len, 5);
|
||||
compute_encoder.set_bytes(mat_ld, 6);
|
||||
// Specialized kernel for very small outputs
|
||||
tm = out_vector_len < tm ? 1 : tm;
|
||||
|
||||
compute_encoder.set_bytes(batch_ndim, 9);
|
||||
compute_encoder.set_vector_bytes(batch_shape, 10);
|
||||
compute_encoder.set_vector_bytes(batch_strides, 11);
|
||||
|
||||
int batch_ndim_vec = batch_shape_vec.size();
|
||||
compute_encoder.set_bytes(batch_ndim_vec, 12);
|
||||
compute_encoder.set_vector_bytes(batch_shape_vec, 13);
|
||||
compute_encoder.set_vector_bytes(batch_strides_vec, 14);
|
||||
|
||||
int batch_ndim_mat = batch_shape_mat.size();
|
||||
compute_encoder.set_bytes(batch_ndim_mat, 15);
|
||||
compute_encoder.set_vector_bytes(batch_shape_mat, 16);
|
||||
compute_encoder.set_vector_bytes(batch_strides_mat, 17);
|
||||
|
||||
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
|
||||
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
return;
|
||||
n_out_per_tgp = bm * sm * tm;
|
||||
kname << "gemv_gather_" << type_to_name(out);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Regular kernel dispatch
|
||||
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
|
||||
<< tm << "_tn" << tn;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
||||
MTL::Size group_dims = MTL::Size(32, bn, bm);
|
||||
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
||||
|
||||
compute_encoder.set_input_array(mat, 0);
|
||||
compute_encoder.set_input_array(vec, 1);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder.set_bytes(in_vector_len, 4);
|
||||
compute_encoder.set_bytes(out_vector_len, 5);
|
||||
compute_encoder.set_bytes(mat_ld, 6);
|
||||
|
||||
compute_encoder.set_bytes(batch_ndim, 9);
|
||||
compute_encoder.set_vector_bytes(out.shape(), 10);
|
||||
compute_encoder.set_vector_bytes(index_strides, 11);
|
||||
|
||||
compute_encoder.set_bytes(batch_ndim_vec, 12);
|
||||
compute_encoder.set_vector_bytes(vec.shape(), 13);
|
||||
compute_encoder.set_vector_bytes(vec.strides(), 14);
|
||||
|
||||
compute_encoder.set_bytes(batch_ndim_mat, 15);
|
||||
compute_encoder.set_vector_bytes(mat.shape(), 16);
|
||||
compute_encoder.set_vector_bytes(mat.strides(), 17);
|
||||
|
||||
compute_encoder.set_input_array(vec_indices_, 18);
|
||||
compute_encoder.set_input_array(mat_indices_, 19);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void gather_mm(
|
||||
const array& a_,
|
||||
const array& b_,
|
||||
const array& lhs_indices,
|
||||
const array& rhs_indices,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Copy if needed
|
||||
std::vector<array> copies;
|
||||
auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
|
||||
auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
|
||||
// Determine dispatch kernel
|
||||
int bm = 64, bn = 64, bk = 16;
|
||||
int wm = 2, wn = 2;
|
||||
size_t batch_size_out = out.size() / M / N;
|
||||
int batch_ndim = out.ndim() - 2;
|
||||
int batch_ndim_a = a.ndim() - 2;
|
||||
int batch_ndim_b = b.ndim() - 2;
|
||||
|
||||
char devc = d.get_architecture().back();
|
||||
GEMM_TPARAM_MACRO(devc)
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn;
|
||||
|
||||
std::string base_name = kname.str();
|
||||
|
||||
const bool has_batch = batch_ndim > 1;
|
||||
const bool use_out_source = false;
|
||||
const bool do_axpby = false;
|
||||
const bool align_M = (M % bm) == 0;
|
||||
const bool align_N = (N % bn) == 0;
|
||||
const bool align_K = (K % bk) == 0;
|
||||
const bool do_gather = true;
|
||||
|
||||
// Define the kernel name
|
||||
std::string base_name;
|
||||
base_name.reserve(128);
|
||||
concatenate(
|
||||
base_name,
|
||||
"steel_gather_mm_",
|
||||
transpose_a ? 't' : 'n',
|
||||
transpose_b ? 't' : 'n',
|
||||
"_",
|
||||
type_to_name(a),
|
||||
"_",
|
||||
type_to_name(out),
|
||||
"_bm",
|
||||
bm,
|
||||
"_bn",
|
||||
bn,
|
||||
"_bk",
|
||||
bk,
|
||||
"_wm",
|
||||
wm,
|
||||
"_wn",
|
||||
wn);
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
||||
{&use_out_source, MTL::DataType::DataTypeBool, 100},
|
||||
{&do_axpby, MTL::DataType::DataTypeBool, 110},
|
||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||
{&do_gather, MTL::DataType::DataTypeBool, 300},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
kname << "_has_batch_" << (has_batch ? 't' : 'n')
|
||||
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
|
||||
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
||||
<< "_align_M_" << (align_M ? 't' : 'n')
|
||||
<< "_align_N_" << (align_N ? 't' : 'n')
|
||||
<< "_align_K_" << (align_K ? 't' : 'n')
|
||||
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
|
||||
// And the kernel hash that includes the function constants
|
||||
std::string hash_name;
|
||||
hash_name.reserve(128);
|
||||
concatenate(
|
||||
hash_name,
|
||||
base_name,
|
||||
"_has_batch_",
|
||||
has_batch ? 't' : 'n',
|
||||
"_align_M_",
|
||||
align_M ? 't' : 'n',
|
||||
"_align_N_",
|
||||
align_N ? 't' : 'n',
|
||||
"_align_K_",
|
||||
align_K ? 't' : 'n');
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
// Encode and dispatch kernel
|
||||
// Get and set the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = get_steel_gemm_fused_kernel(
|
||||
auto kernel = get_steel_gemm_gather_kernel(
|
||||
d,
|
||||
base_name,
|
||||
hash_name,
|
||||
@ -1736,72 +1842,97 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn);
|
||||
|
||||
wn,
|
||||
false);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Use problem size to determine threadblock swizzle
|
||||
int tn = (N + bn - 1) / bn;
|
||||
int tm = (M + bm - 1) / bm;
|
||||
|
||||
// TODO: Explore device-based tuning for swizzle
|
||||
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
|
||||
|
||||
// Prepare steel matmul params
|
||||
GEMMParams params{
|
||||
// Prepare the matmul params
|
||||
steel::GEMMParams params{
|
||||
/* const int M = */ M,
|
||||
/* const int N = */ N,
|
||||
/* const int K = */ K,
|
||||
/* const int lda = */ lda,
|
||||
/* const int ldb = */ ldb,
|
||||
/* const int lda = */ static_cast<int>(lda),
|
||||
/* const int ldb = */ static_cast<int>(ldb),
|
||||
/* const int ldd = */ N,
|
||||
/* const int tiles_n = */ tn,
|
||||
/* const int tiles_m = */ tm,
|
||||
/* const int64_t batch_stride_a = */ lhs_indices_str,
|
||||
/* const int64_t batch_stride_b = */ rhs_indices_str,
|
||||
/* const int64_t batch_stride_d = */ matrix_stride_out,
|
||||
/* const int swizzle_log = */ swizzle_log,
|
||||
/* const int tiles_n = */ (N + bn - 1) / bn,
|
||||
/* const int tiles_m = */ (M + bm - 1) / bm,
|
||||
/* const int64_t batch_stride_a = */
|
||||
(batch_ndim > 0) ? lhs_indices.strides()[0] : 0,
|
||||
/* const int64_t batch_stride_b = */
|
||||
(batch_ndim > 0) ? rhs_indices.strides()[0] : 0,
|
||||
/* const int64_t batch_stride_d = */ M * N,
|
||||
/* const int swizzle_log = */ 0,
|
||||
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
||||
/* const int batch_ndim = */ batch_ndim};
|
||||
|
||||
// Prepare launch grid params
|
||||
int tile = 1 << swizzle_log;
|
||||
tm = (tm + tile - 1) / tile;
|
||||
tn = tn * tile;
|
||||
|
||||
// Prepare the grid
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
|
||||
MTL::Size grid_dims =
|
||||
MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
|
||||
|
||||
// Launch kernel
|
||||
compute_encoder.set_input_array(a, 0);
|
||||
compute_encoder.set_input_array(b, 1);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder.set_bytes(params, 4);
|
||||
|
||||
compute_encoder.set_vector_bytes(batch_shape, 6);
|
||||
compute_encoder.set_vector_bytes(batch_strides, 7);
|
||||
|
||||
compute_encoder.set_input_array(lhs_indices, 10);
|
||||
compute_encoder.set_input_array(rhs_indices, 11);
|
||||
|
||||
std::vector operand_shape = batch_shape_A;
|
||||
operand_shape.insert(
|
||||
operand_shape.end(), batch_shape_B.begin(), batch_shape_B.end());
|
||||
|
||||
std::vector operand_strides = batch_strides_A;
|
||||
operand_strides.insert(
|
||||
operand_strides.end(), batch_strides_B.begin(), batch_strides_B.end());
|
||||
|
||||
operand_batch_ndim.push_back(0);
|
||||
|
||||
compute_encoder.set_vector_bytes(operand_shape, 13);
|
||||
compute_encoder.set_vector_bytes(operand_strides, 14);
|
||||
compute_encoder.set_vector_bytes(operand_batch_ndim, 15);
|
||||
|
||||
compute_encoder.set_input_array(lhs_indices, 2);
|
||||
compute_encoder.set_input_array(rhs_indices, 3);
|
||||
compute_encoder.set_output_array(out, 4);
|
||||
compute_encoder.set_bytes(params, 5);
|
||||
compute_encoder.set_vector_bytes(lhs_indices.shape(), 6);
|
||||
compute_encoder.set_vector_bytes(lhs_indices.strides(), 7);
|
||||
compute_encoder.set_vector_bytes(rhs_indices.strides(), 8);
|
||||
compute_encoder.set_bytes(batch_ndim_a, 9);
|
||||
compute_encoder.set_vector_bytes(a.shape(), 10);
|
||||
compute_encoder.set_vector_bytes(a.strides(), 11);
|
||||
compute_encoder.set_bytes(batch_ndim_b, 12);
|
||||
compute_encoder.set_vector_bytes(b.shape(), 13);
|
||||
compute_encoder.set_vector_bytes(b.strides(), 14);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto& lhs_indices = inputs[2];
|
||||
auto& rhs_indices = inputs[3];
|
||||
|
||||
// Return 0s if either input is empty
|
||||
if (a.size() == 0 || b.size() == 0) {
|
||||
array zero = array(0, a.dtype());
|
||||
fill_gpu(zero, out, s);
|
||||
d.add_temporary(std::move(zero), s.index);
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
// Extract shapes strides from inputs and copy in case of non-contiguous
|
||||
// vectors.
|
||||
int M = a.shape(-2);
|
||||
int N = b.shape(-1);
|
||||
int K = a.shape(-1);
|
||||
|
||||
// We are walking a in order and b is also in order so we can batch up the
|
||||
// matmuls and reuse reading a and b.
|
||||
if (M == 1 && right_sorted_ == true) {
|
||||
gather_mm_rhs(a, b, rhs_indices, out, d, s);
|
||||
return;
|
||||
}
|
||||
|
||||
// Route to gather gemv if any of a or b are vectors
|
||||
if (M == 1) {
|
||||
gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s);
|
||||
return;
|
||||
}
|
||||
if (N == 1) {
|
||||
gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s);
|
||||
return;
|
||||
}
|
||||
|
||||
// Route to non specialized gather mm
|
||||
gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -193,6 +193,23 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array&,
|
||||
bool,
|
||||
bool,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
bool) {
|
||||
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/primitives.h"
|
||||
@ -58,14 +60,27 @@ inline void debug_set_primitive_buffer_label(
|
||||
|
||||
std::string get_primitive_string(Primitive* primitive);
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_numeric_except_char = std::is_arithmetic_v<T> &&
|
||||
!std::is_same_v<T, char> && !std::is_same_v<T, signed char> &&
|
||||
!std::is_same_v<T, unsigned char> && !std::is_same_v<T, wchar_t>;
|
||||
|
||||
template <typename T>
|
||||
void concatenate(std::string& acc, T first) {
|
||||
acc += first;
|
||||
if constexpr (is_numeric_except_char<T>) {
|
||||
acc += std::to_string(first);
|
||||
} else {
|
||||
acc += first;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
void concatenate(std::string& acc, T first, Args... args) {
|
||||
acc += first;
|
||||
if constexpr (is_numeric_except_char<T>) {
|
||||
acc += std::to_string(first);
|
||||
} else {
|
||||
acc += first;
|
||||
}
|
||||
concatenate(acc, args...);
|
||||
}
|
||||
|
||||
|
13
mlx/ops.cpp
13
mlx/ops.cpp
@ -4499,6 +4499,7 @@ array gather_mm(
|
||||
array b,
|
||||
std::optional<array> lhs_indices_ /* = std::nullopt */,
|
||||
std::optional<array> rhs_indices_ /* = std::nullopt */,
|
||||
bool sorted_indices /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
// If no indices, fall back to full matmul
|
||||
if (!lhs_indices_ && !rhs_indices_) {
|
||||
@ -4574,12 +4575,18 @@ array gather_mm(
|
||||
out_shape.push_back(M);
|
||||
out_shape.push_back(N);
|
||||
|
||||
// Caculate array
|
||||
// Make the output array
|
||||
auto out = array(
|
||||
std::move(out_shape),
|
||||
out_type,
|
||||
std::make_shared<GatherMM>(to_stream(s)),
|
||||
{a, b, lhs_indices, rhs_indices});
|
||||
std::make_shared<GatherMM>(
|
||||
to_stream(s),
|
||||
sorted_indices && !rhs_indices_,
|
||||
sorted_indices && !lhs_indices_),
|
||||
{std::move(a),
|
||||
std::move(b),
|
||||
std::move(lhs_indices),
|
||||
std::move(rhs_indices)});
|
||||
|
||||
// Remove the possibly inserted singleton dimensions
|
||||
std::vector<int> axes;
|
||||
|
@ -1399,6 +1399,7 @@ array gather_mm(
|
||||
array b,
|
||||
std::optional<array> lhs_indices = std::nullopt,
|
||||
std::optional<array> rhs_indices = std::nullopt,
|
||||
bool sorted_indices = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Extract a diagonal or construct a diagonal array */
|
||||
|
@ -4895,6 +4895,8 @@ std::vector<array> GatherMM::vjp(
|
||||
int N = cotan.shape(-1);
|
||||
int K = primals[0].shape(-1);
|
||||
|
||||
bool sorted = left_sorted_ || right_sorted_;
|
||||
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
// M X N * (K X N).T -> M X K
|
||||
@ -4905,7 +4907,8 @@ std::vector<array> GatherMM::vjp(
|
||||
base = reshape(base, {-1, M, K}, stream());
|
||||
|
||||
// g : (out_batch_shape) + (M, K)
|
||||
auto g = gather_mm(cotan, bt, std::nullopt, rhs_indices, stream());
|
||||
auto g =
|
||||
gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream());
|
||||
g = expand_dims(g, -3, stream());
|
||||
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
|
||||
|
||||
@ -4920,7 +4923,8 @@ std::vector<array> GatherMM::vjp(
|
||||
base = reshape(base, {-1, K, N}, stream());
|
||||
|
||||
// g : (out_batch_shape) + (K, N)
|
||||
auto g = gather_mm(at, cotan, lhs_indices, std::nullopt, stream());
|
||||
auto g =
|
||||
gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
|
||||
g = expand_dims(g, -3, stream());
|
||||
auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
|
||||
|
||||
@ -4933,6 +4937,12 @@ std::vector<array> GatherMM::vjp(
|
||||
return vjps;
|
||||
}
|
||||
|
||||
bool GatherMM::is_equivalent(const Primitive& other) const {
|
||||
const GatherMM& g_other = static_cast<const GatherMM&>(other);
|
||||
return left_sorted_ == g_other.left_sorted_ &&
|
||||
right_sorted_ == g_other.right_sorted_;
|
||||
}
|
||||
|
||||
bool BlockMaskedMM::is_equivalent(const Primitive& other) const {
|
||||
const BlockMaskedMM& a_other = static_cast<const BlockMaskedMM&>(other);
|
||||
return (block_size_ == a_other.block_size_);
|
||||
|
@ -498,7 +498,13 @@ class BlockMaskedMM : public UnaryPrimitive {
|
||||
|
||||
class GatherMM : public UnaryPrimitive {
|
||||
public:
|
||||
explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {}
|
||||
explicit GatherMM(
|
||||
Stream stream,
|
||||
bool left_sorted = false,
|
||||
bool right_sorted = false)
|
||||
: UnaryPrimitive(stream),
|
||||
left_sorted_(left_sorted),
|
||||
right_sorted_(right_sorted) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -510,7 +516,14 @@ class GatherMM : public UnaryPrimitive {
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_PRINT(GatherMM)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_pair(left_sorted_, right_sorted_);
|
||||
}
|
||||
|
||||
private:
|
||||
bool left_sorted_;
|
||||
bool right_sorted_;
|
||||
};
|
||||
|
||||
class BroadcastAxes : public UnaryPrimitive {
|
||||
|
@ -4464,9 +4464,10 @@ void init_ops(nb::module_& m) {
|
||||
"lhs_indices"_a = nb::none(),
|
||||
"rhs_indices"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"sorted_indices"_a = false,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Matrix multiplication with matrix-level gather.
|
||||
|
||||
@ -4485,11 +4486,16 @@ void init_ops(nb::module_& m) {
|
||||
For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices``
|
||||
contains indices from the range ``[0, B1 * B2 * ... * BS)``
|
||||
|
||||
If only one index is passed and it is sorted, the ``sorted_indices``
|
||||
flag can be passed for a possible faster implementation.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
b (array): Input array.
|
||||
lhs_indices (array, optional): Integer indices for ``a``. Default: ``None``
|
||||
rhs_indices (array, optional): Integer indices for ``b``. Default: ``None``
|
||||
sorted_indices (bool, optional): May allow a faster implementation
|
||||
if the passed indices are sorted. Default: ``False``.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
|
@ -1108,7 +1108,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2))
|
||||
rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2))
|
||||
M = a.shape[-2]
|
||||
N = b.shape[-2]
|
||||
N = b.shape[-1]
|
||||
K = a.shape[-1]
|
||||
|
||||
a = a.reshape((-1, M, K))
|
||||
|
Loading…
Reference in New Issue
Block a user