mirror of https://github.com/ml-explore/mlx.git

MoE backward improvements (#2335)
parent: a4fcc893cd
commit: 4a9b29a875
@@ -6,6 +6,7 @@
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
+#include "mlx/backend/cpu/gemm.h"
 #include "mlx/backend/cpu/lapack.h"
 #include "mlx/primitives.h"

@@ -52,6 +53,58 @@ inline void mask_matrix(
   }
 }

+template <typename T>
+inline void segmented_mm(
+    const T* a,
+    const T* b,
+    const uint32_t* segments,
+    T* out,
+    bool a_transposed,
+    bool b_transposed,
+    size_t lda,
+    size_t ldb,
+    const Shape& a_shape,
+    const Strides& a_strides,
+    const Shape& b_shape,
+    const Strides& b_strides,
+    size_t num_segments,
+    const Shape& segments_shape,
+    const Strides& segments_strides) {
+  int ndim = a_shape.size();
+  Shape a_copy = a_shape;
+  Shape b_copy = b_shape;
+  int32_t M = a_copy[ndim - 2];
+  int32_t N = b_copy[ndim - 1];
+  for (int i = 0; i < num_segments; i++) {
+    uint32_t k_start =
+        segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
+    uint32_t k_end =
+        segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
+    if (k_end <= k_start) {
+      std::fill_n(out + i * M * N, M * N, T(0));
+      continue;
+    }
+    a_copy[ndim - 1] = k_end - k_start;
+    b_copy[ndim - 2] = k_end - k_start;
+    matmul<T>(
+        a + k_start * a_strides[ndim - 1],
+        b + k_start * b_strides[ndim - 2],
+        out + i * M * N,
+        a_transposed,
+        b_transposed,
+        lda,
+        ldb,
+        N,
+        1.0,
+        0.0,
+        1,
+        a_copy,
+        a_strides,
+        b_copy,
+        b_strides);
+  }
+}
+
 } // namespace

 void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
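For reference, a minimal NumPy sketch of the per-segment semantics the routine above implements (`segmented_mm_ref` is a hypothetical helper, mirroring the Python reference used in the new tests): empty segments produce zeros, all others a partial M x N matmul over their slice of K.

import numpy as np

def segmented_mm_ref(a, b, segments):
    # a: (M, K), b: (K, N), segments: iterable of (k_start, k_end) pairs.
    out = np.zeros((len(segments), a.shape[0], b.shape[1]), dtype=a.dtype)
    for i, (k_start, k_end) in enumerate(segments):
        if k_end > k_start:  # k_end <= k_start is zero-filled, as above
            out[i] = a[:, k_start:k_end] @ b[k_start:k_end, :]
    return out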
@@ -437,4 +490,121 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   encoder.add_temporaries(std::move(temps));
 }

+void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  auto& s = stream();
+  auto& encoder = cpu::get_command_encoder(stream());
+  auto check_transpose = [&s, &encoder](const array& x) {
+    auto stx = x.strides()[x.ndim() - 2];
+    auto sty = x.strides()[x.ndim() - 1];
+    if (stx == x.shape(-1) && sty == 1) {
+      return std::make_tuple(false, stx, x);
+    } else if (stx == 1 && sty == x.shape(-2)) {
+      return std::make_tuple(true, sty, x);
+    } else {
+      array xc(x.shape(), x.dtype(), nullptr, {});
+      copy(x, xc, CopyType::General, s);
+      encoder.add_temporary(xc);
+      int64_t stx = x.shape(-1);
+      return std::make_tuple(false, stx, xc);
+    }
+  };
+
+  auto [a_transposed, lda, a] = check_transpose(inputs[0]);
+  auto [b_transposed, ldb, b] = check_transpose(inputs[1]);
+  auto& segments = inputs[2];
+
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(segments);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    b = array::unsafe_weak_copy(b),
+                    segments = array::unsafe_weak_copy(segments),
+                    out_ptr = out.data<void>(),
+                    a_transposed = a_transposed,
+                    b_transposed = b_transposed,
+                    lda = lda,
+                    ldb = ldb]() {
+    switch (a.dtype()) {
+      case float64:
+        segmented_mm<double>(
+            a.data<double>(),
+            b.data<double>(),
+            segments.data<uint32_t>(),
+            static_cast<double*>(out_ptr),
+            a_transposed,
+            b_transposed,
+            lda,
+            ldb,
+            a.shape(),
+            a.strides(),
+            b.shape(),
+            b.strides(),
+            segments.size() / 2,
+            segments.shape(),
+            segments.strides());
+        break;
+      case float32:
+        segmented_mm<float>(
+            a.data<float>(),
+            b.data<float>(),
+            segments.data<uint32_t>(),
+            static_cast<float*>(out_ptr),
+            a_transposed,
+            b_transposed,
+            lda,
+            ldb,
+            a.shape(),
+            a.strides(),
+            b.shape(),
+            b.strides(),
+            segments.size() / 2,
+            segments.shape(),
+            segments.strides());
+        break;
+      case float16:
+        segmented_mm<float16_t>(
+            a.data<float16_t>(),
+            b.data<float16_t>(),
+            segments.data<uint32_t>(),
+            static_cast<float16_t*>(out_ptr),
+            a_transposed,
+            b_transposed,
+            lda,
+            ldb,
+            a.shape(),
+            a.strides(),
+            b.shape(),
+            b.strides(),
+            segments.size() / 2,
+            segments.shape(),
+            segments.strides());
+        break;
+      case bfloat16:
+        segmented_mm<bfloat16_t>(
+            a.data<bfloat16_t>(),
+            b.data<bfloat16_t>(),
+            segments.data<uint32_t>(),
+            static_cast<bfloat16_t*>(out_ptr),
+            a_transposed,
+            b_transposed,
+            lda,
+            ldb,
+            a.shape(),
+            a.strides(),
+            b.shape(),
+            b.strides(),
+            segments.size() / 2,
+            segments.shape(),
+            segments.strides());
+        break;
+      default:
+        throw std::invalid_argument(
+            "Segmented mm supports only real float types.");
+    }
+  });
+}
+
 } // namespace mlx::core
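A Python sketch of the stride test inside `check_transpose` above (hypothetical helper, same logic): row-contiguous inputs pass through, column-contiguous ones are flagged as transposed so BLAS can consume them without a copy, and anything else is copied first.

def check_transpose(shape, strides):
    (m, n), (stx, sty) = shape[-2:], strides[-2:]
    if stx == n and sty == 1:
        return False, stx   # row-contiguous: not transposed, ld = stx
    if stx == 1 and sty == m:
        return True, sty    # column-contiguous: transposed, ld = sty
    return None             # general strides: copy to row-contiguous first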
@@ -83,6 +83,7 @@ NO_GPU_MULTI(LUF)
 NO_GPU_MULTI(QRF)
 NO_GPU(QuantizedMatmul)
 NO_GPU(Scan)
+NO_GPU(SegmentedMM)
 NO_GPU_MULTI(SVD)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
@@ -63,6 +63,7 @@ if(MLX_METAL_JIT)
   make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
   make_jit_source(steel/gemm/kernels/steel_gemm_gather)
   make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
+  make_jit_source(steel/gemm/kernels/steel_gemm_segmented)
   make_jit_source(
     steel/conv/conv
     kernels/steel/utils.h
@@ -575,9 +575,17 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_output_array(out, 2);

   // Set source info
-  compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
-  compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
-  compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
+  if (ndim > 1) {
+    compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
+    compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
+    compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
+  } else {
+    // The following will be ignored in the kernel but we still have to set
+    // some value so that metal validation passes.
+    compute_encoder.set_vector_bytes(idx.shape(), 3);
+    compute_encoder.set_vector_bytes(upd.strides(), 4);
+    compute_encoder.set_vector_bytes(idx.strides(), 5);
+  }
   compute_encoder.set_bytes(ndim - 1, 6);
   compute_encoder.set_bytes(axis_, 7);
   compute_encoder.set_bytes(out.shape(axis_), 8);
@@ -34,6 +34,7 @@ const char* steel_gemm_fused();
 const char* steel_gemm_masked();
 const char* steel_gemm_splitk();
 const char* steel_gemm_gather();
+const char* steel_gemm_segmented();
 const char* conv();
 const char* steel_conv();
 const char* steel_conv_general();
@@ -652,6 +652,43 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
   return d.get_kernel(kernel_name, lib, hash_name, func_consts);
 }

+MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const array& out,
+    bool transpose_a,
+    bool transpose_b,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name, [&]() {
+    std::string kernel_source;
+    concatenate(
+        kernel_source,
+        metal::utils(),
+        metal::gemm(),
+        metal::steel_gemm_segmented(),
+        get_template_definition(
+            lib_name,
+            "segmented_mm",
+            get_type_string(out.dtype()),
+            bm,
+            bn,
+            bk,
+            wm,
+            wn,
+            transpose_a,
+            transpose_b));
+    return kernel_source;
+  });
+  return d.get_kernel(kernel_name, lib, hash_name, func_consts);
+}
+
 MTL::ComputePipelineState* get_gemv_masked_kernel(
     metal::Device& d,
     const std::string& kernel_name,
@@ -175,6 +175,20 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
     int wn,
     bool rhs);

+MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const array& out,
+    bool transpose_a,
+    bool transpose_b,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn);
+
 MTL::ComputePipelineState* get_steel_conv_kernel(
     metal::Device& d,
     const std::string& kernel_name,
@@ -71,6 +71,7 @@ set(STEEL_HEADERS
     steel/gemm/kernels/steel_gemm_fused.h
     steel/gemm/kernels/steel_gemm_gather.h
     steel/gemm/kernels/steel_gemm_masked.h
+    steel/gemm/kernels/steel_gemm_segmented.h
     steel/gemm/kernels/steel_gemm_splitk.h
     steel/utils/type_traits.h
     steel/utils/integral_constant.h)
@@ -120,6 +121,7 @@ if(NOT MLX_METAL_JIT)
   build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
   build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
   build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
+  build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS})
   build_kernel(gemv_masked steel/utils.h)
 endif()

@@ -0,0 +1,266 @@
+// Copyright © 2025 Apple Inc.
+
+using namespace mlx::steel;
+
+constant bool segments_contiguous [[function_constant(199)]];
+constant bool align_M [[function_constant(200)]];
+constant bool align_N [[function_constant(201)]];
+
+template <
+    typename T,
+    int BM,
+    int BN,
+    int BK,
+    int WM,
+    int WN,
+    bool transpose_a,
+    bool transpose_b,
+    typename AccumType = float>
+[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm(
+    const device T* A [[buffer(0)]],
+    const device T* B [[buffer(1)]],
+    const device uint32_t* segments [[buffer(2)]],
+    device T* C [[buffer(3)]],
+    const constant GEMMParams* params [[buffer(4)]],
+    uint simd_lane_id [[thread_index_in_simdgroup]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]],
+    uint3 tid [[threadgroup_position_in_grid]]) {
+  using gemm_kernel = GEMMKernel<
+      T,
+      T,
+      BM,
+      BN,
+      BK,
+      WM,
+      WN,
+      transpose_a,
+      transpose_b,
+      true,
+      true,
+      AccumType>;
+
+  using loader_a_t = typename gemm_kernel::loader_a_t;
+  using loader_b_t = typename gemm_kernel::loader_b_t;
+  using mma_t = typename gemm_kernel::mma_t;
+
+  if (params->tiles_n <= static_cast<int>(tid.x) ||
+      params->tiles_m <= static_cast<int>(tid.y)) {
+    return;
+  }
+
+  // Prepare threadgroup memory
+  threadgroup T As[gemm_kernel::tgp_mem_size_a];
+  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
+
+  // Find the block in A, B, C
+  const int c_row = tid.y * BM;
+  const int c_col = tid.x * BN;
+  const size_t c_row_long = size_t(c_row);
+  const size_t c_col_long = size_t(c_col);
+
+  // Prepare threadgroup bounds
+  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
+  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
+
+  // Move the pointers to the output tile
+  A += transpose_a ? c_row_long : c_row_long * params->lda;
+  B += transpose_b ? c_col_long * params->ldb : c_col_long;
+  C += c_row_long * params->ldd + c_col_long;
+
+  // Move the pointers to the start of the segment
+  uint32_t k_start, k_end;
+  if (segments_contiguous) {
+    k_start = segments[2 * tid.z];
+    k_end = segments[2 * tid.z + 1];
+  } else {
+    // We accept either contiguous (above) or weird strides where the beginning
+    // of the next one is the previous one. Basically the last two strides are
+    // both 1!
+    k_start = segments[tid.z];
+    k_end = segments[tid.z + 1];
+  }
+  A += transpose_a ? k_start * params->lda : k_start;
+  B += transpose_b ? k_start : k_start * params->ldb;
+  C += tid.z * params->batch_stride_d;
+
+  // Prepare threadgroup mma operation
+  thread mma_t mma_op(simd_group_id, simd_lane_id);
+
+  // Prepare threadgroup loading operations
+  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
+  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
+
+  // Matrix level alignment so only check K
+  if (align_M && align_N) {
+    uint32_t k = k_start + BK;
+    for (; k <= k_end; k += BK) {
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+
+      // Load elements into threadgroup
+      loader_a.load_unsafe();
+      loader_b.load_unsafe();
+
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+
+      // Multiply and accumulate threadgroup elements
+      mma_op.mma(As, Bs);
+
+      // Prepare for next iteration
+      loader_a.next();
+      loader_b.next();
+    }
+    short k_remain = BK - short(k - k_end);
+    const short2 tile_dims_A =
+        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
+    const short2 tile_dims_B =
+        transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+    if (k_remain > 0) {
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      loader_a.load_safe(tile_dims_A);
+      loader_b.load_safe(tile_dims_B);
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      mma_op.mma(As, Bs);
+    }
+    mma_op.store_result(C, params->ldd);
+  } else {
+    // Tile aligned do the same as above
+    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
+      uint32_t k = k_start + BK;
+      for (; k <= k_end; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Load elements into threadgroup
+        loader_a.load_unsafe();
+        loader_b.load_unsafe();
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Multiply and accumulate threadgroup elements
+        mma_op.mma(As, Bs);
+
+        // Prepare for next iteration
+        loader_a.next();
+        loader_b.next();
+      }
+      short k_remain = BK - short(k - k_end);
+      const short2 tile_dims_A =
+          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
+      const short2 tile_dims_B =
+          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+      if (k_remain > 0) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_a.load_safe(tile_dims_A);
+        loader_b.load_safe(tile_dims_B);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(As, Bs);
+      }
+      mma_op.store_result(C, params->ldd);
+    }
+
+    // Tile partially aligned check rows
+    else if (align_N || tgp_bn == BN) {
+      uint32_t k = k_start + BK;
+      for (; k <= k_end; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Load elements into threadgroup
+        loader_a.load_safe(
+            transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));
+        loader_b.load_unsafe();
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Multiply and accumulate threadgroup elements
+        mma_op.mma(As, Bs);
+
+        // Prepare for next iteration
+        loader_a.next();
+        loader_b.next();
+      }
+      short k_remain = BK - short(k - k_end);
+      const short2 tile_dims_A =
+          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
+      const short2 tile_dims_B =
+          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+      if (k_remain > 0) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_a.load_safe(tile_dims_A);
+        loader_b.load_safe(tile_dims_B);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(As, Bs);
+      }
+      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
+    }
+
+    // Tile partially aligned check cols
+    else if (align_M || tgp_bm == BM) {
+      uint32_t k = k_start + BK;
+      for (; k <= k_end; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Load elements into threadgroup
+        loader_a.load_unsafe();
+        loader_b.load_safe(
+            transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Multiply and accumulate threadgroup elements
+        mma_op.mma(As, Bs);
+
+        // Prepare for next iteration
+        loader_a.next();
+        loader_b.next();
+      }
+      short k_remain = BK - short(k - k_end);
+      const short2 tile_dims_A =
+          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
+      const short2 tile_dims_B =
+          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+      if (k_remain > 0) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_a.load_safe(tile_dims_A);
+        loader_b.load_safe(tile_dims_B);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(As, Bs);
+      }
+      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
+    }
+
+    // Nothing aligned so check both rows and cols
+    else {
+      uint32_t k = k_start + BK;
+      for (; k <= k_end; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Load elements into threadgroup
+        loader_a.load_safe(
+            transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));
+        loader_b.load_safe(
+            transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Multiply and accumulate threadgroup elements
+        mma_op.mma(As, Bs);
+
+        // Prepare for next iteration
+        loader_a.next();
+        loader_b.next();
+      }
+      short k_remain = BK - short(k - k_end);
+      const short2 tile_dims_A =
+          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
+      const short2 tile_dims_B =
+          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+      if (k_remain > 0) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_a.load_safe(tile_dims_A);
+        loader_b.load_safe(tile_dims_B);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(As, Bs);
+      }
+      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
+    }
+  }
+}
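The K-loop bookkeeping is identical in all four alignment branches; a small sketch of the arithmetic (hypothetical helper), showing how full BK-wide tiles and the `k_remain` tail partition the segment [k_start, k_end):

def k_tiles(k_start, k_end, BK):
    k = k_start + BK
    full = 0
    while k <= k_end:  # mirrors `for (; k <= k_end; k += BK)`
        full += 1
        k += BK
    k_remain = BK - (k - k_end)  # leftover width; 0 when BK divides evenly
    return full, k_remain

assert k_tiles(0, 70, 16) == (4, 6)   # 4 * 16 + 6 = 70 columns consumed
assert k_tiles(0, 64, 16) == (4, 0)   # aligned case: the safe tail is skipped
assert k_tiles(5, 5, 16) == (0, 0)    # an empty segment does no work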
@@ -0,0 +1,43 @@
+// Copyright © 2024 Apple Inc.
+
+// clang-format off
+#include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
+#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h"
+
+#define instantiate_segmented_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
+  instantiate_kernel( \
+      "steel_segmented_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
+      "_bk" #bk "_wm" #wm "_wn" #wn, \
+      segmented_mm, \
+      itype, \
+      bm, \
+      bn, \
+      bk, \
+      wm, \
+      wn, \
+      trans_a, \
+      trans_b, \
+      float)
+
+#define instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
+  instantiate_segmented_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
+  instantiate_segmented_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
+  instantiate_segmented_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
+  instantiate_segmented_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
+
+#define instantiate_segmented_mm_shapes_helper(iname, itype, oname, otype) \
+  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
+  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
+  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
+  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
+  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
+// clang-format on
+
+instantiate_segmented_mm_shapes_helper(float16, half, float16, half);
+instantiate_segmented_mm_shapes_helper(
+    bfloat16,
+    bfloat16_t,
+    bfloat16,
+    bfloat16_t);
+instantiate_segmented_mm_shapes_helper(float32, float, float32, float);
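The instantiations above produce kernel names following the pattern below (a sketch; the same string is assembled on the host side when dispatching):

def kernel_name(trans_a, trans_b, itype, otype, bm, bn, bk, wm, wn):
    t = ("t" if trans_a else "n") + ("t" if trans_b else "n")
    return (
        f"steel_segmented_mm_{t}_{itype}_{otype}"
        f"_bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}"
    )

print(kernel_name(False, True, "float16", "float16", 64, 64, 16, 2, 2))
# steel_segmented_mm_nt_float16_float16_bm64_bn64_bk16_wm2_wn2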
@@ -1864,4 +1864,166 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
 }

+void segmented_mm(
+    const array& a_,
+    const array& b_,
+    const array& segments_,
+    array& out,
+    int M,
+    int N,
+    int K,
+    metal::Device& d,
+    const Stream& s) {
+  auto check_segments_layout = [&d, &s](const array& x) {
+    // Contiguous so return early
+    if (x.flags().row_contiguous) {
+      return std::make_tuple(true, x);
+    }
+
+    bool rc = true;
+    for (int i = 0; i < x.ndim() - 2; i++) {
+      rc &=
+          (x.strides(i + 1) * x.shape(i) == x.strides(i)) || (x.shape(i) == 1);
+    }
+    rc &= x.strides(x.ndim() - 1) == 1;
+    if (x.ndim() > 1) {
+      rc &= x.strides(x.ndim() - 2) == 1;
+    }
+
+    if (rc) {
+      return std::make_tuple(false, x);
+    }
+
+    array x_copy(x.shape(), x.dtype(), nullptr, {});
+    copy_gpu(x, x_copy, CopyType::General, s);
+    d.add_temporary(x_copy, s.index);
+    return std::make_tuple(true, x_copy);
+  };
+
+  // Copy if needed
+  std::vector<array> copies;
+  auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
+  auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
+  auto [segments_contiguous, segments] = check_segments_layout(segments_);
+  d.add_temporaries(std::move(copies), s.index);
+
+  // Determine dispatch kernel
+  int bm = 64, bn = 64, bk = 16;
+  int wm = 2, wn = 2;
+  size_t batch_size_out = out.size() / M / N;
+
+  char devc = d.get_architecture().back();
+  GEMM_TPARAM_MACRO(devc)
+
+  const bool align_M = (M % bm) == 0;
+  const bool align_N = (N % bn) == 0;
+
+  // Define the kernel name
+  std::string base_name;
+  base_name.reserve(128);
+  concatenate(
+      base_name,
+      "steel_segmented_mm_",
+      transpose_a ? 't' : 'n',
+      transpose_b ? 't' : 'n',
+      "_",
+      type_to_name(a),
+      "_",
+      type_to_name(out),
+      "_bm",
+      bm,
+      "_bn",
+      bn,
+      "_bk",
+      bk,
+      "_wm",
+      wm,
+      "_wn",
+      wn);
+
+  metal::MTLFCList func_consts = {
+      {&segments_contiguous, MTL::DataType::DataTypeBool, 199},
+      {&align_M, MTL::DataType::DataTypeBool, 200},
+      {&align_N, MTL::DataType::DataTypeBool, 201},
+  };
+
+  // And the kernel hash that includes the function constants
+  std::string hash_name;
+  hash_name.reserve(128);
+  concatenate(
+      hash_name,
+      base_name,
+      "_segments_contiguous_",
+      segments_contiguous ? 't' : 'n',
+      "_align_M_",
+      align_M ? 't' : 'n',
+      "_align_N_",
+      align_N ? 't' : 'n');
+
+  // Get and set the kernel
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = get_steel_gemm_segmented_kernel(
+      d,
+      base_name,
+      hash_name,
+      func_consts,
+      out,
+      transpose_a,
+      transpose_b,
+      bm,
+      bn,
+      bk,
+      wm,
+      wn);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  // Prepare the matmul params
+  steel::GEMMParams params{
+      /* const int M = */ M,
+      /* const int N = */ N,
+      /* const int K = */ K,
+      /* const int lda = */ static_cast<int>(lda),
+      /* const int ldb = */ static_cast<int>(ldb),
+      /* const int ldd = */ N,
+      /* const int tiles_n = */ (N + bn - 1) / bn,
+      /* const int tiles_m = */ (M + bm - 1) / bm,
+      /* const int64_t batch_stride_a = */ 0,
+      /* const int64_t batch_stride_b = */ 0,
+      /* const int64_t batch_stride_d = */ M * N,
+      /* const int swizzle_log = */ 0,
+      /* const int gemm_k_iterations_aligned = */ 0,
+      /* const int batch_ndim = */ 0};
+
+  // Prepare the grid
+  MTL::Size group_dims = MTL::Size(32, wn, wm);
+  MTL::Size grid_dims =
+      MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
+
+  // Launch kernel
+  compute_encoder.set_input_array(a, 0);
+  compute_encoder.set_input_array(b, 1);
+  compute_encoder.set_input_array(segments, 2);
+  compute_encoder.set_output_array(out, 3);
+  compute_encoder.set_bytes(params, 4);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+}
+
+void SegmentedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto& segments = inputs[2];
+
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  // Extract shapes from inputs.
+  int M = a.shape(-2);
+  int N = b.shape(-1);
+  int K = a.shape(-1);
+
+  segmented_mm(a, b, segments, out, M, N, K, d, s);
+}
+
 } // namespace mlx::core
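A sketch of the dispatch geometry set up above: one threadgroup per bn x bm output tile, with the grid's z dimension indexing segments, so each segment walks only its own slice of K.

def dispatch_dims(M, N, num_segments, bm=64, bn=64, wm=2, wn=2):
    tiles_n = (N + bn - 1) // bn
    tiles_m = (M + bm - 1) // bm
    group_dims = (32, wn, wm)  # 32 lanes per simdgroup, wn * wm simdgroups
    grid_dims = (tiles_n, tiles_m, num_segments)
    return group_dims, grid_dims

print(dispatch_dims(1000, 1000, 9))  # ((32, 2, 2), (16, 16, 9))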
@@ -210,6 +210,22 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
   return d.get_kernel(kernel_name, hash_name, func_consts);
 }

+MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const array&,
+    bool,
+    bool,
+    int,
+    int,
+    int,
+    int,
+    int) {
+  return d.get_kernel(kernel_name, hash_name, func_consts);
+}
+
 MTL::ComputePipelineState* get_gemv_masked_kernel(
     metal::Device& d,
     const std::string& kernel_name,
@@ -105,6 +105,7 @@ NO_CPU(Scan)
 NO_CPU(Scatter)
 NO_CPU(ScatterAxis)
 NO_CPU(Select)
+NO_CPU(SegmentedMM)
 NO_CPU(Sigmoid)
 NO_CPU(Sign)
 NO_CPU(Sin)
@@ -121,6 +121,7 @@ NO_GPU(Scan)
 NO_GPU(Scatter)
 NO_GPU(ScatterAxis)
 NO_GPU(Select)
+NO_GPU(SegmentedMM)
 NO_GPU(Sigmoid)
 NO_GPU(Sign)
 NO_GPU(Sin)
mlx/ops.cpp
@@ -4649,6 +4649,54 @@ array gather_mm(
   return axes.empty() ? out : squeeze(out, axes, s);
 }

+array segmented_mm(
+    array a,
+    array b,
+    array segments,
+    StreamOrDevice s /* = {} */) {
+  if (a.ndim() != 2 || b.ndim() != 2) {
+    throw std::invalid_argument("[segmented_mm] Batched matmul not supported");
+  }
+
+  if (segments.ndim() < 1 || segments.shape().back() != 2) {
+    std::ostringstream msg;
+    msg << "[segmented_mm] The segments should have shape (..., 2) but "
+        << segments.shape() << " was provided.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  // Type promotion
+  auto out_type = result_type(a, b);
+  if (!issubdtype(out_type, floating)) {
+    std::ostringstream msg;
+    msg << "[segmented_mm] Only real floating point types are supported but "
+        << a.dtype() << " and " << b.dtype()
+        << " were provided which results in " << out_type
+        << ", which is not a real floating point type.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (!issubdtype(segments.dtype(), integer)) {
+    throw std::invalid_argument(
+        "[segmented_mm] Got segments with invalid dtype. Segments must be integral.");
+  }
+
+  a = astype(a, out_type, s);
+  b = astype(b, out_type, s);
+  segments = astype(segments, uint32, s);
+
+  Shape out_shape = segments.shape();
+  out_shape.pop_back();
+  out_shape.push_back(a.shape(0));
+  out_shape.push_back(b.shape(1));
+
+  return array(
+      std::move(out_shape),
+      out_type,
+      std::make_shared<SegmentedMM>(to_stream(s)),
+      {std::move(a), std::move(b), std::move(segments)});
+}
+
 array diagonal(
     const array& a,
     int offset /* = 0 */,
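The output-shape logic above in one line: the trailing 2 of `segments` is replaced by (M, N). For example (matching the shape check in the new tests):

import mlx.core as mx

a = mx.ones((10, 1000))
b = mx.ones((1000, 10))
segments = mx.zeros((2, 2, 4, 2), dtype=mx.uint32)  # batched (start, end) pairs
print(mx.segmented_mm(a, b, segments).shape)        # (2, 2, 4, 10, 10)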
@@ -1406,6 +1406,12 @@ array gather_mm(
     bool sorted_indices = false,
     StreamOrDevice s = {});

+/**
+ * Compute a matrix product but segment the inner dimension and write the
+ * result separately for each segment.
+ */
+array segmented_mm(array a, array b, array segments, StreamOrDevice s = {});
+
 /** Extract a diagonal or construct a diagonal array */
 array diagonal(
     const array& a,
@@ -109,6 +109,70 @@ std::tuple<array, array, array, int> vmap_ternary_op(
   return {a, b, c, to_ax};
 }

+// Calculate the gradient wrt to the weights of the following calculation
+//
+//   y = gather_mm(x, w.T, lhs_indices, rhs_indices, sorted)
+//
+// Note the transpose above. This function returns the gradient for w.T so if w
+// was used instead then one needs to transpose the returned gradient.
+//
+// We define it as a separate function to reuse it for gather_mm and
+// gather_qmm.
+array gather_mm_grad(
+    const array& x,
+    const array& dy,
+    const array& lhs_indices,
+    const array& rhs_indices,
+    bool sorted,
+    Shape batch_shape,
+    const Stream& s) {
+  int M = x.shape(-2);
+  int K = x.shape(-1);
+  int N = dy.shape(-1);
+  int num_segments = std::accumulate(
+      batch_shape.begin(), batch_shape.end(), 1, std::multiplies<int>());
+  batch_shape.push_back(N);
+  batch_shape.push_back(K);
+
+  // If the indices are sorted then it means that we can do the whole gradient
+  // computation via a segmented matmul. We just need to calculate the segments
+  // using the indices.
+  if (sorted) {
+    auto segments = zeros({num_segments}, uint32, s);
+    segments = scatter_add_axis(segments, rhs_indices, array(M, uint32), 0, s);
+    segments = cumsum(segments, 0, false, true, s);
+    segments = concatenate({array({0}, {1}, uint32), segments}, 0, s);
+    segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, s);
+
+    return reshape(
+        segmented_mm(
+            swapaxes(flatten(dy, 0, -2, s), 0, 1, s),
+            flatten(x, 0, -2, s),
+            segments,
+            s),
+        std::move(batch_shape),
+        s);
+  }
+
+  // Otherwise we need to gather matmul the dy and then scatter add it to the
+  // correct locations.
+  else {
+    // TODO: If the lhs indices wasn't provided, this is always a sorted matmul
+    // so we should add that check.
+    auto dw = gather_mm(
+        swapaxes(dy, -1, -2, s), x, std::nullopt, lhs_indices, false, s);
+    return reshape(
+        scatter_add(
+            zeros({num_segments, N, K}, dw.dtype(), s),
+            rhs_indices,
+            expand_dims(dw, -3, s),
+            0,
+            s),
+        std::move(batch_shape),
+        s);
+  }
+}
+
 } // namespace

 std::vector<array> Primitive::jvp(
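A small worked example of the sorted branch's segment construction (`.at[...].add` stands in for `scatter_add_axis`; `as_strided` is used exactly as in the C++ above): occurrence counts of each expert are scaled by M, cumulatively summed, prefixed with 0, and turned into overlapping (start, end) pairs.

import mlx.core as mx

rhs_indices = mx.array([0, 0, 1, 3], dtype=mx.uint32)  # sorted expert ids
M, num_segments = 2, 4                                 # M rows per index
counts = mx.zeros((num_segments,), dtype=mx.uint32).at[rhs_indices].add(M)
offsets = mx.concatenate([mx.array([0], dtype=mx.uint32), mx.cumsum(counts)])
segments = mx.as_strided(offsets, (num_segments, 2), (1, 1))
print(segments)  # [[0, 4], [4, 6], [6, 6], [6, 8]]; expert 2 gets an empty segment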
@@ -3169,8 +3233,9 @@ std::vector<array> QuantizedMatmul::vjp(
             "[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
       } else {
         if (!dsb) {
-          auto fc = flatten(cotangents[0], 0, -2, stream());
-          auto fx = flatten(primals[0], 0, -2, stream());
+          int ndim = primals[1].ndim();
+          auto fc = flatten(cotangents[0], 0, -ndim, stream());
+          auto fx = flatten(primals[0], 0, -ndim, stream());
           auto dw = transpose_
               ? matmul(swapaxes(fc, -1, -2, stream()), fx, stream())
               : matmul(swapaxes(fx, -1, -2, stream()), fc, stream());
@@ -3181,7 +3246,6 @@ std::vector<array> QuantizedMatmul::vjp(
           vjps.push_back(sum(*dsb, -1, false, stream()));
         } else {
           // scales
-          auto s = stream();
           auto wq = dequantize(
               primals[1],
               ones_like(primals[2], stream()),
@@ -3253,34 +3317,42 @@ std::vector<array> GatherQMM::vjp(
   auto& lhs_indices = primals[4];
   auto& rhs_indices = primals[5];

+  int M = cotan.shape(-2);
+  int N = cotan.shape(-1);
+  int K = x.shape(-1);
+
   bool sorted = left_sorted_ || right_sorted_;
+  bool no_broadcast = rhs_indices.size() * M * K == x.size();
+  std::optional<array> dsb = std::nullopt;
+
   for (auto arg : argnums) {
     // gradient wrt to x
     if (arg == 0) {
-      vjps.push_back(reshape(
-          scatter_add(
-              flatten(zeros_like(x, stream()), 0, -3, stream()),
-              lhs_indices,
-              expand_dims(
-                  gather_qmm(
-                      cotan,
-                      w,
-                      scales,
-                      biases,
-                      std::nullopt,
-                      rhs_indices,
-                      !transpose_,
-                      group_size_,
-                      bits_,
-                      sorted,
-                      stream()),
-                  -3,
-                  stream()),
-              0,
-              stream()),
-          x.shape(),
-          stream()));
+      auto g = gather_qmm(
+          cotan,
+          w,
+          scales,
+          biases,
+          std::nullopt,
+          rhs_indices,
+          !transpose_,
+          group_size_,
+          bits_,
+          sorted,
+          stream());
+      if (sorted && no_broadcast) {
+        vjps.push_back(g);
+      } else {
+        vjps.push_back(reshape(
+            scatter_add(
+                flatten(zeros_like(x, stream()), 0, -3, stream()),
+                lhs_indices,
+                expand_dims(g, -3, stream()),
+                0,
+                stream()),
+            x.shape(),
+            stream()));
+      }
     }

     // gradient wrt to the indices is undefined
@@ -3290,9 +3362,49 @@ std::vector<array> GatherQMM::vjp(
     }

     // gradient wrt to w_q, scales or biases
-    else {
+    else if (arg == 1) {
       throw std::runtime_error(
-          "GatherQMM::vjp no gradient wrt the quantized matrix yet.");
+          "GatherQMM::vjp no gradient wrt the quantized weights.");
+    } else {
+      if (!dsb) {
+        auto shape = w.shape();
+        shape.pop_back();
+        shape.pop_back();
+        dsb = unflatten(
+            gather_mm_grad(
+                x,
+                cotan,
+                lhs_indices,
+                rhs_indices,
+                sorted,
+                std::move(shape),
+                stream()),
+            -1,
+            {-1, group_size_},
+            stream());
+      }
+      if (arg == 3) {
+        vjps.push_back(sum(*dsb, -1, false, stream()));
+      } else {
+        vjps.push_back(
+            sum(multiply(
+                    *dsb,
+                    unflatten(
+                        dequantize(
+                            w,
+                            ones_like(scales, stream()),
+                            zeros_like(biases, stream()),
+                            group_size_,
+                            bits_,
+                            stream()),
+                        -1,
+                        {-1, group_size_},
+                        stream()),
+                    stream()),
+                -1,
+                false,
+                stream()));
+      }
     }
   }
   return vjps;
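For reference, the scales and biases branches follow from dequantization being affine per group of `group_size_` elements: writing $w = s\,\hat{w} + \beta$ for each group,

\[
\frac{\partial L}{\partial s} = \sum_{j \in \text{group}} \frac{\partial L}{\partial w_j}\,\hat{w}_j,
\qquad
\frac{\partial L}{\partial \beta} = \sum_{j \in \text{group}} \frac{\partial L}{\partial w_j},
\]

which is why the code multiplies the grouped weight gradient `dsb` by `dequantize(w, ones, zeros, ...)` (recovering $\hat{w}$) before the group sum for the scales, and uses the plain group sum for the biases (arg == 3).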
@@ -5064,6 +5176,8 @@ std::vector<array> GatherMM::vjp(
   std::vector<array> vjps;
   auto& cotan = cotangents[0];

+  auto& a = primals[0];
+  auto& b = primals[1];
   auto& lhs_indices = primals[2];
   auto& rhs_indices = primals[3];

@@ -5072,39 +5186,46 @@ std::vector<array> GatherMM::vjp(
   int K = primals[0].shape(-1);

   bool sorted = left_sorted_ || right_sorted_;
+  bool no_broadcast = rhs_indices.size() * M * K == primals[0].size();

   for (auto arg : argnums) {
     if (arg == 0) {
-      // M X N * (K X N).T -> M X K
-      auto base = zeros_like(primals[0], stream());
-      auto bt = swapaxes(primals[1], -1, -2, stream());
-
-      auto base_shape = base.shape();
-      base = reshape(base, {-1, M, K}, stream());
-
-      // g : (out_batch_shape) + (M, K)
-      auto g =
-          gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream());
-      g = expand_dims(g, -3, stream());
-      auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
-
-      vjps.push_back(reshape(gacc, base_shape, stream()));
+      auto g = gather_mm(
+          cotan,
+          swapaxes(b, -1, -2, stream()),
+          std::nullopt,
+          rhs_indices,
+          sorted,
+          stream());
+      if (sorted && no_broadcast) {
+        vjps.push_back(g);
+      } else {
+        vjps.push_back(reshape(
+            scatter_add(
+                flatten(zeros_like(a, stream()), 0, -3, stream()),
+                lhs_indices,
+                expand_dims(g, -3, stream()),
+                0,
+                stream()),
+            a.shape(),
+            stream()));
+      }
     } else if (arg == 1) {
-      // (M X K).T * M X N -> K X N
-      auto base = zeros_like(primals[1], stream());
-      auto at = swapaxes(primals[0], -1, -2, stream());
-
-      auto base_shape = base.shape();
-      base = reshape(base, {-1, K, N}, stream());
-
-      // g : (out_batch_shape) + (K, N)
-      auto g =
-          gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
-      g = expand_dims(g, -3, stream());
-      auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
-
-      vjps.push_back(reshape(gacc, base_shape, stream()));
+      auto shape = b.shape();
+      shape.pop_back();
+      shape.pop_back();
+      vjps.push_back(swapaxes(
+          gather_mm_grad(
+              a,
+              cotan,
+              lhs_indices,
+              rhs_indices,
+              sorted,
+              std::move(shape),
+              stream()),
+          -1,
+          -2,
+          stream()));
     } else {
       throw std::invalid_argument(
           "[GatherMM] Cannot calculate VJP with respect to indices.");
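A sketch of the matmul gradient identities both branches implement, per gathered expert $e$: with $y_i = x_i\, b_{e(i)}$,

\[
\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y_i}\, b_{e(i)}^{\top},
\qquad
\frac{\partial L}{\partial b_e} = \sum_{i\,:\,e(i)=e} x_i^{\top}\, \frac{\partial L}{\partial y_i}.
\]

When the indices are sorted, the per-expert sum is contiguous, so `gather_mm_grad` computes it with a single segmented matmul; otherwise it falls back to a gather matmul plus `scatter_add`. The `sorted && no_broadcast` fast path for the input gradient skips the scatter entirely, since each gathered row-block then appears exactly once and in order.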
@@ -526,6 +526,16 @@ class GatherMM : public UnaryPrimitive {
   bool right_sorted_;
 };

+class SegmentedMM : public UnaryPrimitive {
+ public:
+  explicit SegmentedMM(Stream stream) : UnaryPrimitive(stream) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_PRINT(SegmentedMM)
+};
+
 class BroadcastAxes : public UnaryPrimitive {
  public:
  explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {})
@@ -4321,6 +4321,28 @@ void init_ops(nb::module_& m) {
         array: The result of the multiplication of ``x`` with ``w``
         after gathering using ``lhs_indices`` and ``rhs_indices``.
       )pbdoc");
+  m.def(
+      "segmented_mm",
+      &mx::segmented_mm,
+      nb::arg(),
+      nb::arg(),
+      "segments"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def segmented_mm(a: array, b: array, /, segments: array, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Perform a matrix multiplication but segment the inner dimension and
+        save the result for each segment separately.
+
+        Args:
+          a (array): Input array of shape ``MxK``.
+          b (array): Input array of shape ``KxN``.
+          segments (array): The offsets into the inner dimension for each segment.
+
+        Returns:
+          array: The result per segment of shape ``MxN``.
+      )pbdoc");
   m.def(
       "tensordot",
       [](const mx::array& a,
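As a usage sketch of the binding above (shapes per the docstring; the reference loop mirrors the one in the new tests):

import mlx.core as mx

# One (M, N) product per (k_start, k_end) row of segments; here K = 1024 is
# split into four 256-wide segments, so the output has shape (4, M, N).
a = mx.random.normal((8, 1024))
b = mx.random.normal((1024, 16))
segments = mx.array([[0, 256], [256, 512], [512, 768], [768, 1024]], dtype=mx.uint32)

c = mx.segmented_mm(a, b, segments)
ref = mx.stack([a[:, s:e] @ b[s:e, :] for s, e in segments.tolist()])
assert mx.allclose(c, ref, atol=1e-4)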
@@ -8,6 +8,9 @@ cuda_skip = {
     # Gather matmul NYI
     "TestBlas.test_gather_matmul",
     "TestBlas.test_gather_matmul_grad",
+    "TestBlas.test_gather_mm_sorted",
+    # Segmented matmul NYI
+    "TestBlas.test_segmented_mm",
     # Scan NYI
     "TestArray.test_api",
     "TestAutograd.test_cumprod_grad",
@@ -76,6 +79,7 @@ cuda_skip = {
     "TestQuantized.test_gather_matmul_grad",
     "TestQuantized.test_gather_qmm",
     "TestQuantized.test_gather_qmm_sorted",
+    "TestQuantized.test_gather_qmm_grad",
     "TestQuantized.test_non_multiples",
     "TestQuantized.test_qmm",
     "TestQuantized.test_qmm_jvp",
@@ -1163,6 +1163,99 @@ class TestBlas(mlx_tests.MLXTestCase):
         self.assertEqual(r.shape, t.shape)
         self.assertTrue(mx.allclose(r, t, atol=1e-4).item())

+    def test_gather_mm_sorted(self):
+        def gather_mm_ref(a, b, rhs):
+            b = b[rhs]
+            return a @ b
+
+        def gather_mm_test(a, b, rhs):
+            return mx.gather_mm(a, b, rhs_indices=rhs, sorted_indices=True)
+
+        a = mx.random.normal((100, 1, 100))
+        b = mx.random.normal((8, 100, 100))
+        rhs = mx.sort(mx.random.randint(0, 8, shape=(100,)))
+
+        c1 = gather_mm_ref(a, b, rhs)
+        c2 = gather_mm_test(a, b, rhs)
+        self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
+
+        cotan = mx.random.normal(c1.shape)
+        c1, dc1 = mx.vjp(
+            lambda a, b: gather_mm_ref(a, b, rhs),
+            [a, b],
+            [cotan],
+        )
+        c2, dc2 = mx.vjp(
+            lambda a, b: gather_mm_test(a, b, rhs),
+            [a, b],
+            [cotan],
+        )
+        self.assertTrue(mx.allclose(c1[0], c2[0], atol=1e-4))
+        self.assertTrue(mx.allclose(dc1[0], dc2[0], atol=1e-4))
+        self.assertTrue(mx.allclose(dc1[1], dc2[1], atol=1e-4))
+
+    def test_segmented_mm(self):
+        def segmented_mm_ref(a, b, s):
+            s = s.tolist()
+            c = []
+            for s1, s2 in s:
+                c.append(a[:, s1:s2] @ b[s1:s2, :])
+            return mx.stack(c, axis=0)
+
+        shapes = [
+            (10, 10, 10),
+            (10, 10, 1000),
+            (1000, 1000, 1000),
+        ]
+        all_segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]]
+
+        for M, N, K in shapes:
+            for s in all_segments:
+                segments = []
+                for i in range(len(s) - 1):
+                    segments.append([s[i], s[i + 1]])
+                segments = mx.array(segments)
+                segments = mx.minimum(K - 1, (K * segments).astype(mx.uint32))
+                a = mx.random.normal((M, K))
+                b = mx.random.normal((K, N))
+                c1 = segmented_mm_ref(a, b, segments)
+                c2 = mx.segmented_mm(a, b, segments)
+                self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
+
+                a = mx.random.normal((K, M))
+                b = mx.random.normal((K, N))
+                c1 = segmented_mm_ref(a.T, b, segments)
+                c2 = mx.segmented_mm(a.T, b, segments)
+                self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
+
+                a = mx.random.normal((M, K))
+                b = mx.random.normal((N, K))
+                c1 = segmented_mm_ref(a, b.T, segments)
+                c2 = mx.segmented_mm(a, b.T, segments)
+                self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
+
+                a = mx.random.normal((K, M))
+                b = mx.random.normal((N, K))
+                c1 = segmented_mm_ref(a.T, b.T, segments)
+                c2 = mx.segmented_mm(a.T, b.T, segments)
+                self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
+
+        with self.assertRaises(ValueError):
+            a = mx.ones((2, 10, 10))
+            s = mx.array([[0, 5], [5, 10]]).astype(mx.uint32)
+            mx.segmented_mm(a, a, s)
+
+        a = mx.ones((10, 1000))
+        s = mx.random.randint(0, 16, shape=(1000,))
+        s = mx.zeros(16, dtype=s.dtype).at[s].add(1)
+        s = mx.sort(s)
+        s = mx.cumsum(s)
+        s = mx.concatenate([mx.array([0]), s])
+        s = mx.as_strided(s, (16, 2), (1, 1))
+        s = mx.reshape(s, (2, 2, 4, 2))
+        c = mx.segmented_mm(a, a.T, s)
+        self.assertEqual(c.shape, (2, 2, 4, 10, 10))
+
     def test_gemv_gemm_same_precision(self):
         mx.random.seed(0)
         N = 256
@@ -549,6 +549,49 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
         self.assertTrue(mx.allclose(y1, y4, atol=1e-5))

+    def test_gather_qmm_grad(self):
+        def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):
+            if lhs is not None:
+                x = x[lhs]
+            if rhs is not None:
+                w = w[rhs]
+                s = s[rhs]
+                b = b[rhs]
+            return mx.quantized_matmul(x, w, s, b, transpose=trans)
+
+        def gather_qmm(x, w, s, b, lhs, rhs, trans, sort):
+            return mx.gather_qmm(
+                x,
+                w,
+                s,
+                b,
+                transpose=trans,
+                lhs_indices=lhs,
+                rhs_indices=rhs,
+                sorted_indices=sort,
+            )
+
+        x = mx.random.normal((16, 1, 256))
+        w, s, b = mx.quantize(mx.random.normal((4, 256, 256)))
+        indices = mx.sort(mx.random.randint(0, 4, shape=(16,)))
+        cotan = mx.random.normal((16, 1, 256))
+
+        (o1,), (dx1, ds1, db1) = mx.vjp(
+            lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),
+            [x, s, b],
+            [cotan],
+        )
+        (o2,), (dx2, ds2, db2) = mx.vjp(
+            lambda x, s, b: gather_qmm(x, w, s, b, None, indices, True, True),
+            [x, s, b],
+            [cotan],
+        )
+
+        self.assertTrue(mx.allclose(o1, o2, atol=1e-4))
+        self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4))
+        self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3))
+        self.assertTrue(mx.allclose(db1, db2, atol=1e-3))
+
     def test_vjp_scales_biases(self):
         mx.random.seed(0)
         x = mx.random.normal(shape=(2, 2, 512))