mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-17 06:21:12 +08:00
MoE backward improvements (#2335)
This commit is contained in:
parent
a4fcc893cd
commit
4a9b29a875
@ -6,6 +6,7 @@
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@ -52,6 +53,58 @@ inline void mask_matrix(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void segmented_mm(
|
||||
const T* a,
|
||||
const T* b,
|
||||
const uint32_t* segments,
|
||||
T* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides,
|
||||
size_t num_segments,
|
||||
const Shape& segments_shape,
|
||||
const Strides& segments_strides) {
|
||||
int ndim = a_shape.size();
|
||||
Shape a_copy = a_shape;
|
||||
Shape b_copy = b_shape;
|
||||
int32_t M = a_copy[ndim - 2];
|
||||
int32_t N = b_copy[ndim - 1];
|
||||
for (int i = 0; i < num_segments; i++) {
|
||||
uint32_t k_start =
|
||||
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
|
||||
uint32_t k_end =
|
||||
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
|
||||
if (k_end <= k_start) {
|
||||
std::fill_n(out + i * M * N, M * N, T(0));
|
||||
continue;
|
||||
}
|
||||
a_copy[ndim - 1] = k_end - k_start;
|
||||
b_copy[ndim - 2] = k_end - k_start;
|
||||
matmul<T>(
|
||||
a + k_start * a_strides[ndim - 1],
|
||||
b + k_start * b_strides[ndim - 2],
|
||||
out + i * M * N,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
N,
|
||||
1.0,
|
||||
0.0,
|
||||
1,
|
||||
a_copy,
|
||||
a_strides,
|
||||
b_copy,
|
||||
b_strides);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@ -437,4 +490,121 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
}
|
||||
|
||||
void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
auto check_transpose = [&s, &encoder](const array& x) {
|
||||
auto stx = x.strides()[x.ndim() - 2];
|
||||
auto sty = x.strides()[x.ndim() - 1];
|
||||
if (stx == x.shape(-1) && sty == 1) {
|
||||
return std::make_tuple(false, stx, x);
|
||||
} else if (stx == 1 && sty == x.shape(-2)) {
|
||||
return std::make_tuple(true, sty, x);
|
||||
} else {
|
||||
array xc(x.shape(), x.dtype(), nullptr, {});
|
||||
copy(x, xc, CopyType::General, s);
|
||||
encoder.add_temporary(xc);
|
||||
int64_t stx = x.shape(-1);
|
||||
return std::make_tuple(false, stx, xc);
|
||||
}
|
||||
};
|
||||
|
||||
auto [a_transposed, lda, a] = check_transpose(inputs[0]);
|
||||
auto [b_transposed, ldb, b] = check_transpose(inputs[1]);
|
||||
auto& segments = inputs[2];
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(segments);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
segments = array::unsafe_weak_copy(segments),
|
||||
out_ptr = out.data<void>(),
|
||||
a_transposed = a_transposed,
|
||||
b_transposed = b_transposed,
|
||||
lda = lda,
|
||||
ldb = ldb]() {
|
||||
switch (a.dtype()) {
|
||||
case float64:
|
||||
segmented_mm<double>(
|
||||
a.data<double>(),
|
||||
b.data<double>(),
|
||||
segments.data<uint32_t>(),
|
||||
static_cast<double*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
a.shape(),
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size() / 2,
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
case float32:
|
||||
segmented_mm<float>(
|
||||
a.data<float>(),
|
||||
b.data<float>(),
|
||||
segments.data<uint32_t>(),
|
||||
static_cast<float*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
a.shape(),
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size() / 2,
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
case float16:
|
||||
segmented_mm<float16_t>(
|
||||
a.data<float16_t>(),
|
||||
b.data<float16_t>(),
|
||||
segments.data<uint32_t>(),
|
||||
static_cast<float16_t*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
a.shape(),
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size() / 2,
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
case bfloat16:
|
||||
segmented_mm<bfloat16_t>(
|
||||
a.data<bfloat16_t>(),
|
||||
b.data<bfloat16_t>(),
|
||||
segments.data<uint32_t>(),
|
||||
static_cast<bfloat16_t*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
a.shape(),
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size() / 2,
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"Segmented mm supports only real float types.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -83,6 +83,7 @@ NO_GPU_MULTI(LUF)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(Scan)
|
||||
NO_GPU(SegmentedMM)
|
||||
NO_GPU_MULTI(SVD)
|
||||
NO_GPU(Inverse)
|
||||
NO_GPU(Cholesky)
|
||||
|
@ -63,6 +63,7 @@ if(MLX_METAL_JIT)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_gather)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_segmented)
|
||||
make_jit_source(
|
||||
steel/conv/conv
|
||||
kernels/steel/utils.h
|
||||
|
@ -575,9 +575,17 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set source info
|
||||
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
|
||||
compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
|
||||
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
|
||||
if (ndim > 1) {
|
||||
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
|
||||
compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
|
||||
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
|
||||
} else {
|
||||
// The following will be ignored in the kernel but we still have to set
|
||||
// some value so that metal validation passes.
|
||||
compute_encoder.set_vector_bytes(idx.shape(), 3);
|
||||
compute_encoder.set_vector_bytes(upd.strides(), 4);
|
||||
compute_encoder.set_vector_bytes(idx.strides(), 5);
|
||||
}
|
||||
compute_encoder.set_bytes(ndim - 1, 6);
|
||||
compute_encoder.set_bytes(axis_, 7);
|
||||
compute_encoder.set_bytes(out.shape(axis_), 8);
|
||||
|
@ -34,6 +34,7 @@ const char* steel_gemm_fused();
|
||||
const char* steel_gemm_masked();
|
||||
const char* steel_gemm_splitk();
|
||||
const char* steel_gemm_gather();
|
||||
const char* steel_gemm_segmented();
|
||||
const char* conv();
|
||||
const char* steel_conv();
|
||||
const char* steel_conv_general();
|
||||
|
@ -652,6 +652,43 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source;
|
||||
concatenate(
|
||||
kernel_source,
|
||||
metal::utils(),
|
||||
metal::gemm(),
|
||||
metal::steel_gemm_segmented(),
|
||||
get_template_definition(
|
||||
lib_name,
|
||||
"segmented_mm",
|
||||
get_type_string(out.dtype()),
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
transpose_a,
|
||||
transpose_b));
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
|
@ -175,6 +175,20 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||
int wn,
|
||||
bool rhs);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
|
@ -71,6 +71,7 @@ set(STEEL_HEADERS
|
||||
steel/gemm/kernels/steel_gemm_fused.h
|
||||
steel/gemm/kernels/steel_gemm_gather.h
|
||||
steel/gemm/kernels/steel_gemm_masked.h
|
||||
steel/gemm/kernels/steel_gemm_segmented.h
|
||||
steel/gemm/kernels/steel_gemm_splitk.h
|
||||
steel/utils/type_traits.h
|
||||
steel/utils/integral_constant.h)
|
||||
@ -120,6 +121,7 @@ if(NOT MLX_METAL_JIT)
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS})
|
||||
build_kernel(gemv_masked steel/utils.h)
|
||||
endif()
|
||||
|
||||
|
@ -0,0 +1,266 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
using namespace mlx::steel;
|
||||
|
||||
constant bool segments_contiguous [[function_constant(199)]];
|
||||
constant bool align_M [[function_constant(200)]];
|
||||
constant bool align_N [[function_constant(201)]];
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
typename AccumType = float>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
const device uint32_t* segments [[buffer(2)]],
|
||||
device T* C [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
true,
|
||||
true,
|
||||
AccumType>;
|
||||
|
||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||
using mma_t = typename gemm_kernel::mma_t;
|
||||
|
||||
if (params->tiles_n <= static_cast<int>(tid.x) ||
|
||||
params->tiles_m <= static_cast<int>(tid.y)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Prepare threadgroup memory
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Find the block in A, B, C
|
||||
const int c_row = tid.y * BM;
|
||||
const int c_col = tid.x * BN;
|
||||
const size_t c_row_long = size_t(c_row);
|
||||
const size_t c_col_long = size_t(c_col);
|
||||
|
||||
// Prepare threadgroup bounds
|
||||
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
|
||||
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
|
||||
|
||||
// Move the pointers to the output tile
|
||||
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||
C += c_row_long * params->ldd + c_col_long;
|
||||
|
||||
// Move the pointers to the start of the segment
|
||||
uint32_t k_start, k_end;
|
||||
if (segments_contiguous) {
|
||||
k_start = segments[2 * tid.z];
|
||||
k_end = segments[2 * tid.z + 1];
|
||||
} else {
|
||||
// We accept either contiguous (above) or weird strides where the beginning
|
||||
// of the next one is the previous one. Basically the last two strides are
|
||||
// both 1!
|
||||
k_start = segments[tid.z];
|
||||
k_end = segments[tid.z + 1];
|
||||
}
|
||||
A += transpose_a ? k_start * params->lda : k_start;
|
||||
B += transpose_b ? k_start : k_start * params->ldb;
|
||||
C += tid.z * params->batch_stride_d;
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Matrix level alignment so only check K
|
||||
if (align_M && align_N) {
|
||||
uint32_t k = k_start + BK;
|
||||
for (; k <= k_end; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
short k_remain = BK - short(k - k_end);
|
||||
const short2 tile_dims_A =
|
||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||
const short2 tile_dims_B =
|
||||
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||
if (k_remain > 0) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
mma_op.store_result(C, params->ldd);
|
||||
} else {
|
||||
// Tile aligned do the same as above
|
||||
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
|
||||
uint32_t k = k_start + BK;
|
||||
for (; k <= k_end; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
short k_remain = BK - short(k - k_end);
|
||||
const short2 tile_dims_A =
|
||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||
const short2 tile_dims_B =
|
||||
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||
if (k_remain > 0) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
mma_op.store_result(C, params->ldd);
|
||||
}
|
||||
|
||||
// Tile partially aligned check rows
|
||||
else if (align_N || tgp_bn == BN) {
|
||||
uint32_t k = k_start + BK;
|
||||
for (; k <= k_end; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_safe(
|
||||
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
short k_remain = BK - short(k - k_end);
|
||||
const short2 tile_dims_A =
|
||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||
const short2 tile_dims_B =
|
||||
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||
if (k_remain > 0) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
|
||||
// Tile partially aligned check cols
|
||||
else if (align_M || tgp_bm == BM) {
|
||||
uint32_t k = k_start + BK;
|
||||
for (; k <= k_end; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_safe(
|
||||
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
short k_remain = BK - short(k - k_end);
|
||||
const short2 tile_dims_A =
|
||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||
const short2 tile_dims_B =
|
||||
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||
if (k_remain > 0) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
|
||||
// Nothing aligned so check both rows and cols
|
||||
else {
|
||||
uint32_t k = k_start + BK;
|
||||
for (; k <= k_end; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_safe(
|
||||
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));
|
||||
loader_b.load_safe(
|
||||
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
short k_remain = BK - short(k - k_end);
|
||||
const short2 tile_dims_A =
|
||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||
const short2 tile_dims_B =
|
||||
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||
if (k_remain > 0) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,43 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h"
|
||||
|
||||
#define instantiate_segmented_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_kernel( \
|
||||
"steel_segmented_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
|
||||
"_bk" #bk "_wm" #wm "_wn" #wn, \
|
||||
segmented_mm, \
|
||||
itype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
float)
|
||||
|
||||
#define instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_segmented_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_segmented_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_segmented_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_segmented_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
|
||||
#define instantiate_segmented_mm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
|
||||
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
|
||||
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
|
||||
// clang-format on
|
||||
|
||||
instantiate_segmented_mm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_segmented_mm_shapes_helper(
|
||||
bfloat16,
|
||||
bfloat16_t,
|
||||
bfloat16,
|
||||
bfloat16_t);
|
||||
instantiate_segmented_mm_shapes_helper(float32, float, float32, float);
|
@ -1864,4 +1864,166 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
|
||||
}
|
||||
|
||||
void segmented_mm(
|
||||
const array& a_,
|
||||
const array& b_,
|
||||
const array& segments_,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
auto check_segments_layout = [&d, &s](const array& x) {
|
||||
// Contiguous so return early
|
||||
if (x.flags().row_contiguous) {
|
||||
return std::make_tuple(true, x);
|
||||
}
|
||||
|
||||
bool rc = true;
|
||||
for (int i = 0; i < x.ndim() - 2; i++) {
|
||||
rc &=
|
||||
(x.strides(i + 1) * x.shape(i) == x.strides(i)) || (x.shape(i) == 1);
|
||||
}
|
||||
rc &= x.strides(x.ndim() - 1) == 1;
|
||||
if (x.ndim() > 1) {
|
||||
rc &= x.strides(x.ndim() - 2) == 1;
|
||||
}
|
||||
|
||||
if (rc) {
|
||||
return std::make_tuple(false, x);
|
||||
}
|
||||
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return std::make_tuple(true, x_copy);
|
||||
};
|
||||
|
||||
// Copy if needed
|
||||
std::vector<array> copies;
|
||||
auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
|
||||
auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
|
||||
auto [segments_contiguous, segments] = check_segments_layout(segments_);
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
|
||||
// Determine dispatch kernel
|
||||
int bm = 64, bn = 64, bk = 16;
|
||||
int wm = 2, wn = 2;
|
||||
size_t batch_size_out = out.size() / M / N;
|
||||
|
||||
char devc = d.get_architecture().back();
|
||||
GEMM_TPARAM_MACRO(devc)
|
||||
|
||||
const bool align_M = (M % bm) == 0;
|
||||
const bool align_N = (N % bn) == 0;
|
||||
|
||||
// Define the kernel name
|
||||
std::string base_name;
|
||||
base_name.reserve(128);
|
||||
concatenate(
|
||||
base_name,
|
||||
"steel_segmented_mm_",
|
||||
transpose_a ? 't' : 'n',
|
||||
transpose_b ? 't' : 'n',
|
||||
"_",
|
||||
type_to_name(a),
|
||||
"_",
|
||||
type_to_name(out),
|
||||
"_bm",
|
||||
bm,
|
||||
"_bn",
|
||||
bn,
|
||||
"_bk",
|
||||
bk,
|
||||
"_wm",
|
||||
wm,
|
||||
"_wn",
|
||||
wn);
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&segments_contiguous, MTL::DataType::DataTypeBool, 199},
|
||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||
};
|
||||
|
||||
// And the kernel hash that includes the function constants
|
||||
std::string hash_name;
|
||||
hash_name.reserve(128);
|
||||
concatenate(
|
||||
hash_name,
|
||||
base_name,
|
||||
"_segments_contiguous_",
|
||||
segments_contiguous ? 't' : 'n',
|
||||
"_align_M_",
|
||||
align_M ? 't' : 'n',
|
||||
"_align_N_",
|
||||
align_N ? 't' : 'n');
|
||||
|
||||
// Get and set the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = get_steel_gemm_segmented_kernel(
|
||||
d,
|
||||
base_name,
|
||||
hash_name,
|
||||
func_consts,
|
||||
out,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Prepare the matmul params
|
||||
steel::GEMMParams params{
|
||||
/* const int M = */ M,
|
||||
/* const int N = */ N,
|
||||
/* const int K = */ K,
|
||||
/* const int lda = */ static_cast<int>(lda),
|
||||
/* const int ldb = */ static_cast<int>(ldb),
|
||||
/* const int ldd = */ N,
|
||||
/* const int tiles_n = */ (N + bn - 1) / bn,
|
||||
/* const int tiles_m = */ (M + bm - 1) / bm,
|
||||
/* const int64_t batch_stride_a = */ 0,
|
||||
/* const int64_t batch_stride_b = */ 0,
|
||||
/* const int64_t batch_stride_d = */ M * N,
|
||||
/* const int swizzle_log = */ 0,
|
||||
/* const int gemm_k_iterations_aligned = */ 0,
|
||||
/* const int batch_ndim = */ 0};
|
||||
|
||||
// Prepare the grid
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims =
|
||||
MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
|
||||
|
||||
// Launch kernel
|
||||
compute_encoder.set_input_array(a, 0);
|
||||
compute_encoder.set_input_array(b, 1);
|
||||
compute_encoder.set_input_array(segments, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
compute_encoder.set_bytes(params, 4);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void SegmentedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto& segments = inputs[2];
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
// Extract shapes from inputs.
|
||||
int M = a.shape(-2);
|
||||
int N = b.shape(-1);
|
||||
int K = a.shape(-1);
|
||||
|
||||
segmented_mm(a, b, segments, out, M, N, K, d, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -210,6 +210,22 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||
return d.get_kernel(kernel_name, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array&,
|
||||
bool,
|
||||
bool,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int) {
|
||||
return d.get_kernel(kernel_name, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
|
@ -105,6 +105,7 @@ NO_CPU(Scan)
|
||||
NO_CPU(Scatter)
|
||||
NO_CPU(ScatterAxis)
|
||||
NO_CPU(Select)
|
||||
NO_CPU(SegmentedMM)
|
||||
NO_CPU(Sigmoid)
|
||||
NO_CPU(Sign)
|
||||
NO_CPU(Sin)
|
||||
|
@ -121,6 +121,7 @@ NO_GPU(Scan)
|
||||
NO_GPU(Scatter)
|
||||
NO_GPU(ScatterAxis)
|
||||
NO_GPU(Select)
|
||||
NO_GPU(SegmentedMM)
|
||||
NO_GPU(Sigmoid)
|
||||
NO_GPU(Sign)
|
||||
NO_GPU(Sin)
|
||||
|
48
mlx/ops.cpp
48
mlx/ops.cpp
@ -4649,6 +4649,54 @@ array gather_mm(
|
||||
return axes.empty() ? out : squeeze(out, axes, s);
|
||||
}
|
||||
|
||||
array segmented_mm(
|
||||
array a,
|
||||
array b,
|
||||
array segments,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (a.ndim() != 2 || b.ndim() != 2) {
|
||||
throw std::invalid_argument("[segmented_mm] Batched matmul not supported");
|
||||
}
|
||||
|
||||
if (segments.ndim() < 1 || segments.shape().back() != 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[segmented_mm] The segments should have shape (..., 2) but "
|
||||
<< segments.shape() << " was provided.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Type promotion
|
||||
auto out_type = result_type(a, b);
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[segmented_mm] Only real floating point types are supported but "
|
||||
<< a.dtype() << " and " << b.dtype()
|
||||
<< " were provided which results in " << out_type
|
||||
<< ", which is not a real floating point type.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (!issubdtype(segments.dtype(), integer)) {
|
||||
throw std::invalid_argument(
|
||||
"[segmented_mm] Got segments with invalid dtype. Segments must be integral.");
|
||||
}
|
||||
|
||||
a = astype(a, out_type, s);
|
||||
b = astype(b, out_type, s);
|
||||
segments = astype(segments, uint32, s);
|
||||
|
||||
Shape out_shape = segments.shape();
|
||||
out_shape.pop_back();
|
||||
out_shape.push_back(a.shape(0));
|
||||
out_shape.push_back(b.shape(1));
|
||||
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
out_type,
|
||||
std::make_shared<SegmentedMM>(to_stream(s)),
|
||||
{std::move(a), std::move(b), std::move(segments)});
|
||||
}
|
||||
|
||||
array diagonal(
|
||||
const array& a,
|
||||
int offset /* = 0 */,
|
||||
|
@ -1406,6 +1406,12 @@ array gather_mm(
|
||||
bool sorted_indices = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/**
|
||||
* Compute a matrix product but segment the inner dimension and write the
|
||||
* result separately for each segment.
|
||||
*/
|
||||
array segmented_mm(array a, array b, array segments, StreamOrDevice s = {});
|
||||
|
||||
/** Extract a diagonal or construct a diagonal array */
|
||||
array diagonal(
|
||||
const array& a,
|
||||
|
@ -109,6 +109,70 @@ std::tuple<array, array, array, int> vmap_ternary_op(
|
||||
return {a, b, c, to_ax};
|
||||
}
|
||||
|
||||
// Calculate the gradient wrt to the weights of the following calculation
|
||||
//
|
||||
// y = gather_mm(x, w.T, lhs_indices, rhs_indices, sorted)
|
||||
//
|
||||
// Note the transpose above. This function returns the gradient for w.T so if w
|
||||
// was used instead then one needs to transpose the returned gradient.
|
||||
//
|
||||
// We define it as a separate function to reuse it for gather_mm and
|
||||
// gather_qmm.
|
||||
array gather_mm_grad(
|
||||
const array& x,
|
||||
const array& dy,
|
||||
const array& lhs_indices,
|
||||
const array& rhs_indices,
|
||||
bool sorted,
|
||||
Shape batch_shape,
|
||||
const Stream& s) {
|
||||
int M = x.shape(-2);
|
||||
int K = x.shape(-1);
|
||||
int N = dy.shape(-1);
|
||||
int num_segments = std::accumulate(
|
||||
batch_shape.begin(), batch_shape.end(), 1, std::multiplies<int>());
|
||||
batch_shape.push_back(N);
|
||||
batch_shape.push_back(K);
|
||||
|
||||
// If the indices are sorted then it means that we can do the whole gradient
|
||||
// computation via a segmented matmul. We just need to calculate the segments
|
||||
// using the indices.
|
||||
if (sorted) {
|
||||
auto segments = zeros({num_segments}, uint32, s);
|
||||
segments = scatter_add_axis(segments, rhs_indices, array(M, uint32), 0, s);
|
||||
segments = cumsum(segments, 0, false, true, s);
|
||||
segments = concatenate({array({0}, {1}, uint32), segments}, 0, s);
|
||||
segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, s);
|
||||
|
||||
return reshape(
|
||||
segmented_mm(
|
||||
swapaxes(flatten(dy, 0, -2, s), 0, 1, s),
|
||||
flatten(x, 0, -2, s),
|
||||
segments,
|
||||
s),
|
||||
std::move(batch_shape),
|
||||
s);
|
||||
}
|
||||
|
||||
// Otherwise we need to gather matmul the dy and then scatter add it to the
|
||||
// correct locations.
|
||||
else {
|
||||
// TODO: If the lhs indices wasn't provided, this is always a sorted matmul
|
||||
// so we should add that check.
|
||||
auto dw = gather_mm(
|
||||
swapaxes(dy, -1, -2, s), x, std::nullopt, lhs_indices, false, s);
|
||||
return reshape(
|
||||
scatter_add(
|
||||
zeros({num_segments, N, K}, dw.dtype(), s),
|
||||
rhs_indices,
|
||||
expand_dims(dw, -3, s),
|
||||
0,
|
||||
s),
|
||||
std::move(batch_shape),
|
||||
s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<array> Primitive::jvp(
|
||||
@ -3169,8 +3233,9 @@ std::vector<array> QuantizedMatmul::vjp(
|
||||
"[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
|
||||
} else {
|
||||
if (!dsb) {
|
||||
auto fc = flatten(cotangents[0], 0, -2, stream());
|
||||
auto fx = flatten(primals[0], 0, -2, stream());
|
||||
int ndim = primals[1].ndim();
|
||||
auto fc = flatten(cotangents[0], 0, -ndim, stream());
|
||||
auto fx = flatten(primals[0], 0, -ndim, stream());
|
||||
auto dw = transpose_
|
||||
? matmul(swapaxes(fc, -1, -2, stream()), fx, stream())
|
||||
: matmul(swapaxes(fx, -1, -2, stream()), fc, stream());
|
||||
@ -3181,7 +3246,6 @@ std::vector<array> QuantizedMatmul::vjp(
|
||||
vjps.push_back(sum(*dsb, -1, false, stream()));
|
||||
} else {
|
||||
// scales
|
||||
auto s = stream();
|
||||
auto wq = dequantize(
|
||||
primals[1],
|
||||
ones_like(primals[2], stream()),
|
||||
@ -3253,34 +3317,42 @@ std::vector<array> GatherQMM::vjp(
|
||||
auto& lhs_indices = primals[4];
|
||||
auto& rhs_indices = primals[5];
|
||||
|
||||
int M = cotan.shape(-2);
|
||||
int N = cotan.shape(-1);
|
||||
int K = x.shape(-1);
|
||||
|
||||
bool sorted = left_sorted_ || right_sorted_;
|
||||
bool no_broadcast = rhs_indices.size() * M * K == x.size();
|
||||
std::optional<array> dsb = std::nullopt;
|
||||
|
||||
for (auto arg : argnums) {
|
||||
// gradient wrt to x
|
||||
if (arg == 0) {
|
||||
vjps.push_back(reshape(
|
||||
scatter_add(
|
||||
flatten(zeros_like(x, stream()), 0, -3, stream()),
|
||||
lhs_indices,
|
||||
expand_dims(
|
||||
gather_qmm(
|
||||
cotan,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
std::nullopt,
|
||||
rhs_indices,
|
||||
!transpose_,
|
||||
group_size_,
|
||||
bits_,
|
||||
sorted,
|
||||
stream()),
|
||||
-3,
|
||||
stream()),
|
||||
0,
|
||||
stream()),
|
||||
x.shape(),
|
||||
stream()));
|
||||
auto g = gather_qmm(
|
||||
cotan,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
std::nullopt,
|
||||
rhs_indices,
|
||||
!transpose_,
|
||||
group_size_,
|
||||
bits_,
|
||||
sorted,
|
||||
stream());
|
||||
if (sorted && no_broadcast) {
|
||||
vjps.push_back(g);
|
||||
} else {
|
||||
vjps.push_back(reshape(
|
||||
scatter_add(
|
||||
flatten(zeros_like(x, stream()), 0, -3, stream()),
|
||||
lhs_indices,
|
||||
expand_dims(g, -3, stream()),
|
||||
0,
|
||||
stream()),
|
||||
x.shape(),
|
||||
stream()));
|
||||
}
|
||||
}
|
||||
|
||||
// gradient wrt to the indices is undefined
|
||||
@ -3290,9 +3362,49 @@ std::vector<array> GatherQMM::vjp(
|
||||
}
|
||||
|
||||
// gradient wrt to w_q, scales or biases
|
||||
else {
|
||||
else if (arg == 1) {
|
||||
throw std::runtime_error(
|
||||
"GatherQMM::vjp no gradient wrt the quantized matrix yet.");
|
||||
"GatherQMM::vjp no gradient wrt the quantized weights.");
|
||||
} else {
|
||||
if (!dsb) {
|
||||
auto shape = w.shape();
|
||||
shape.pop_back();
|
||||
shape.pop_back();
|
||||
dsb = unflatten(
|
||||
gather_mm_grad(
|
||||
x,
|
||||
cotan,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
sorted,
|
||||
std::move(shape),
|
||||
stream()),
|
||||
-1,
|
||||
{-1, group_size_},
|
||||
stream());
|
||||
}
|
||||
if (arg == 3) {
|
||||
vjps.push_back(sum(*dsb, -1, false, stream()));
|
||||
} else {
|
||||
vjps.push_back(
|
||||
sum(multiply(
|
||||
*dsb,
|
||||
unflatten(
|
||||
dequantize(
|
||||
w,
|
||||
ones_like(scales, stream()),
|
||||
zeros_like(biases, stream()),
|
||||
group_size_,
|
||||
bits_,
|
||||
stream()),
|
||||
-1,
|
||||
{-1, group_size_},
|
||||
stream()),
|
||||
stream()),
|
||||
-1,
|
||||
false,
|
||||
stream()));
|
||||
}
|
||||
}
|
||||
}
|
||||
return vjps;
|
||||
@ -5064,6 +5176,8 @@ std::vector<array> GatherMM::vjp(
|
||||
std::vector<array> vjps;
|
||||
auto& cotan = cotangents[0];
|
||||
|
||||
auto& a = primals[0];
|
||||
auto& b = primals[1];
|
||||
auto& lhs_indices = primals[2];
|
||||
auto& rhs_indices = primals[3];
|
||||
|
||||
@ -5072,39 +5186,46 @@ std::vector<array> GatherMM::vjp(
|
||||
int K = primals[0].shape(-1);
|
||||
|
||||
bool sorted = left_sorted_ || right_sorted_;
|
||||
bool no_broadcast = rhs_indices.size() * M * K == primals[0].size();
|
||||
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
// M X N * (K X N).T -> M X K
|
||||
auto base = zeros_like(primals[0], stream());
|
||||
auto bt = swapaxes(primals[1], -1, -2, stream());
|
||||
|
||||
auto base_shape = base.shape();
|
||||
base = reshape(base, {-1, M, K}, stream());
|
||||
|
||||
// g : (out_batch_shape) + (M, K)
|
||||
auto g =
|
||||
gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream());
|
||||
g = expand_dims(g, -3, stream());
|
||||
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
|
||||
|
||||
vjps.push_back(reshape(gacc, base_shape, stream()));
|
||||
|
||||
auto g = gather_mm(
|
||||
cotan,
|
||||
swapaxes(b, -1, -2, stream()),
|
||||
std::nullopt,
|
||||
rhs_indices,
|
||||
sorted,
|
||||
stream());
|
||||
if (sorted && no_broadcast) {
|
||||
vjps.push_back(g);
|
||||
} else {
|
||||
vjps.push_back(reshape(
|
||||
scatter_add(
|
||||
flatten(zeros_like(a, stream()), 0, -3, stream()),
|
||||
lhs_indices,
|
||||
expand_dims(g, -3, stream()),
|
||||
0,
|
||||
stream()),
|
||||
a.shape(),
|
||||
stream()));
|
||||
}
|
||||
} else if (arg == 1) {
|
||||
// (M X K).T * M X N -> K X N
|
||||
auto base = zeros_like(primals[1], stream());
|
||||
auto at = swapaxes(primals[0], -1, -2, stream());
|
||||
|
||||
auto base_shape = base.shape();
|
||||
base = reshape(base, {-1, K, N}, stream());
|
||||
|
||||
// g : (out_batch_shape) + (K, N)
|
||||
auto g =
|
||||
gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
|
||||
g = expand_dims(g, -3, stream());
|
||||
auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
|
||||
|
||||
vjps.push_back(reshape(gacc, base_shape, stream()));
|
||||
auto shape = b.shape();
|
||||
shape.pop_back();
|
||||
shape.pop_back();
|
||||
vjps.push_back(swapaxes(
|
||||
gather_mm_grad(
|
||||
a,
|
||||
cotan,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
sorted,
|
||||
std::move(shape),
|
||||
stream()),
|
||||
-1,
|
||||
-2,
|
||||
stream()));
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[GatherMM] Cannot calculate VJP with respect to indices.");
|
||||
|
@ -526,6 +526,16 @@ class GatherMM : public UnaryPrimitive {
|
||||
bool right_sorted_;
|
||||
};
|
||||
|
||||
class SegmentedMM : public UnaryPrimitive {
|
||||
public:
|
||||
explicit SegmentedMM(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_PRINT(SegmentedMM)
|
||||
};
|
||||
|
||||
class BroadcastAxes : public UnaryPrimitive {
|
||||
public:
|
||||
explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {})
|
||||
|
@ -4321,6 +4321,28 @@ void init_ops(nb::module_& m) {
|
||||
array: The result of the multiplication of ``x`` with ``w``
|
||||
after gathering using ``lhs_indices`` and ``rhs_indices``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"segmented_mm",
|
||||
&mx::segmented_mm,
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"segments"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def segmented_mm(a: array, b: array, /, segments: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform a matrix multiplication but segment the inner dimension and
|
||||
save the result for each segment separately.
|
||||
|
||||
Args:
|
||||
a (array): Input array of shape ``MxK``.
|
||||
b (array): Input array of shape ``KxN``.
|
||||
segments (array): The offsets into the inner dimension for each segment.
|
||||
|
||||
Returns:
|
||||
array: The result per segment of shape ``MxN``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"tensordot",
|
||||
[](const mx::array& a,
|
||||
|
@ -8,6 +8,9 @@ cuda_skip = {
|
||||
# Gather matmul NYI
|
||||
"TestBlas.test_gather_matmul",
|
||||
"TestBlas.test_gather_matmul_grad",
|
||||
"TestBlas.test_gather_mm_sorted",
|
||||
# Segmented matmul NYI
|
||||
"TestBlas.test_segmented_mm",
|
||||
# Scan NYI
|
||||
"TestArray.test_api",
|
||||
"TestAutograd.test_cumprod_grad",
|
||||
@ -76,6 +79,7 @@ cuda_skip = {
|
||||
"TestQuantized.test_gather_matmul_grad",
|
||||
"TestQuantized.test_gather_qmm",
|
||||
"TestQuantized.test_gather_qmm_sorted",
|
||||
"TestQuantized.test_gather_qmm_grad",
|
||||
"TestQuantized.test_non_multiples",
|
||||
"TestQuantized.test_qmm",
|
||||
"TestQuantized.test_qmm_jvp",
|
||||
|
@ -1163,6 +1163,99 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(r.shape, t.shape)
|
||||
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
|
||||
|
||||
def test_gather_mm_sorted(self):
|
||||
def gather_mm_ref(a, b, rhs):
|
||||
b = b[rhs]
|
||||
return a @ b
|
||||
|
||||
def gather_mm_test(a, b, rhs):
|
||||
return mx.gather_mm(a, b, rhs_indices=rhs, sorted_indices=True)
|
||||
|
||||
a = mx.random.normal((100, 1, 100))
|
||||
b = mx.random.normal((8, 100, 100))
|
||||
rhs = mx.sort(mx.random.randint(0, 8, shape=(100,)))
|
||||
|
||||
c1 = gather_mm_ref(a, b, rhs)
|
||||
c2 = gather_mm_test(a, b, rhs)
|
||||
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
|
||||
|
||||
cotan = mx.random.normal(c1.shape)
|
||||
c1, dc1 = mx.vjp(
|
||||
lambda a, b: gather_mm_ref(a, b, rhs),
|
||||
[a, b],
|
||||
[cotan],
|
||||
)
|
||||
c2, dc2 = mx.vjp(
|
||||
lambda a, b: gather_mm_test(a, b, rhs),
|
||||
[a, b],
|
||||
[cotan],
|
||||
)
|
||||
self.assertTrue(mx.allclose(c1[0], c2[0], atol=1e-4))
|
||||
self.assertTrue(mx.allclose(dc1[0], dc2[0], atol=1e-4))
|
||||
self.assertTrue(mx.allclose(dc1[1], dc2[1], atol=1e-4))
|
||||
|
||||
def test_segmented_mm(self):
|
||||
def segmented_mm_ref(a, b, s):
|
||||
s = s.tolist()
|
||||
c = []
|
||||
for s1, s2 in s:
|
||||
c.append(a[:, s1:s2] @ b[s1:s2, :])
|
||||
return mx.stack(c, axis=0)
|
||||
|
||||
shapes = [
|
||||
(10, 10, 10),
|
||||
(10, 10, 1000),
|
||||
(1000, 1000, 1000),
|
||||
]
|
||||
all_segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]]
|
||||
|
||||
for M, N, K in shapes:
|
||||
for s in all_segments:
|
||||
segments = []
|
||||
for i in range(len(s) - 1):
|
||||
segments.append([s[i], s[i + 1]])
|
||||
segments = mx.array(segments)
|
||||
segments = mx.minimum(K - 1, (K * segments).astype(mx.uint32))
|
||||
a = mx.random.normal((M, K))
|
||||
b = mx.random.normal((K, N))
|
||||
c1 = segmented_mm_ref(a, b, segments)
|
||||
c2 = mx.segmented_mm(a, b, segments)
|
||||
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
|
||||
|
||||
a = mx.random.normal((K, M))
|
||||
b = mx.random.normal((K, N))
|
||||
c1 = segmented_mm_ref(a.T, b, segments)
|
||||
c2 = mx.segmented_mm(a.T, b, segments)
|
||||
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
|
||||
|
||||
a = mx.random.normal((M, K))
|
||||
b = mx.random.normal((N, K))
|
||||
c1 = segmented_mm_ref(a, b.T, segments)
|
||||
c2 = mx.segmented_mm(a, b.T, segments)
|
||||
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
|
||||
|
||||
a = mx.random.normal((K, M))
|
||||
b = mx.random.normal((N, K))
|
||||
c1 = segmented_mm_ref(a.T, b.T, segments)
|
||||
c2 = mx.segmented_mm(a.T, b.T, segments)
|
||||
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
a = mx.ones((2, 10, 10))
|
||||
s = mx.array([[0, 5], [5, 10]]).astype(mx.uint32)
|
||||
mx.segmented_mm(a, a, s)
|
||||
|
||||
a = mx.ones((10, 1000))
|
||||
s = mx.random.randint(0, 16, shape=(1000,))
|
||||
s = mx.zeros(16, dtype=s.dtype).at[s].add(1)
|
||||
s = mx.sort(s)
|
||||
s = mx.cumsum(s)
|
||||
s = mx.concatenate([mx.array([0]), s])
|
||||
s = mx.as_strided(s, (16, 2), (1, 1))
|
||||
s = mx.reshape(s, (2, 2, 4, 2))
|
||||
c = mx.segmented_mm(a, a.T, s)
|
||||
self.assertEqual(c.shape, (2, 2, 4, 10, 10))
|
||||
|
||||
def test_gemv_gemm_same_precision(self):
|
||||
mx.random.seed(0)
|
||||
N = 256
|
||||
|
@ -549,6 +549,49 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
|
||||
|
||||
def test_gather_qmm_grad(self):
|
||||
def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):
|
||||
if lhs is not None:
|
||||
x = x[lhs]
|
||||
if rhs is not None:
|
||||
w = w[rhs]
|
||||
s = s[rhs]
|
||||
b = b[rhs]
|
||||
return mx.quantized_matmul(x, w, s, b, transpose=trans)
|
||||
|
||||
def gather_qmm(x, w, s, b, lhs, rhs, trans, sort):
|
||||
return mx.gather_qmm(
|
||||
x,
|
||||
w,
|
||||
s,
|
||||
b,
|
||||
transpose=trans,
|
||||
lhs_indices=lhs,
|
||||
rhs_indices=rhs,
|
||||
sorted_indices=sort,
|
||||
)
|
||||
|
||||
x = mx.random.normal((16, 1, 256))
|
||||
w, s, b = mx.quantize(mx.random.normal((4, 256, 256)))
|
||||
indices = mx.sort(mx.random.randint(0, 4, shape=(16,)))
|
||||
cotan = mx.random.normal((16, 1, 256))
|
||||
|
||||
(o1,), (dx1, ds1, db1) = mx.vjp(
|
||||
lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),
|
||||
[x, s, b],
|
||||
[cotan],
|
||||
)
|
||||
(o2,), (dx2, ds2, db2) = mx.vjp(
|
||||
lambda x, s, b: gather_qmm(x, w, s, b, None, indices, True, True),
|
||||
[x, s, b],
|
||||
[cotan],
|
||||
)
|
||||
|
||||
self.assertTrue(mx.allclose(o1, o2, atol=1e-4))
|
||||
self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4))
|
||||
self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3))
|
||||
self.assertTrue(mx.allclose(db1, db2, atol=1e-3))
|
||||
|
||||
def test_vjp_scales_biases(self):
|
||||
mx.random.seed(0)
|
||||
x = mx.random.normal(shape=(2, 2, 512))
|
||||
|
Loading…
Reference in New Issue
Block a user