mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
No copy gemms (#801)
* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes is contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
This commit is contained in:
@@ -28,10 +28,12 @@ void explicit_gemm_conv_ND_gpu(
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<N>& conv_params) {
|
||||
// Get gemm shapes
|
||||
int implicit_M = out.size() / conv_params.O;
|
||||
int implicit_K = wt.size() / conv_params.O;
|
||||
int implicit_N = conv_params.O;
|
||||
// Prepare unfolding array
|
||||
std::vector<int> unfolded_shape = {
|
||||
static_cast<int>(out.size() / conv_params.O),
|
||||
static_cast<int>(wt.size() / conv_params.O)};
|
||||
std::vector<int> unfolded_shape{implicit_M, implicit_K};
|
||||
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
|
||||
|
||||
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
|
||||
@@ -59,20 +61,29 @@ void explicit_gemm_conv_ND_gpu(
|
||||
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Reshape weight
|
||||
std::vector<int> wt_reshape{implicit_K, implicit_N};
|
||||
std::vector<size_t> wt_restride{1, static_cast<size_t>(implicit_K)};
|
||||
array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
|
||||
auto wt_flags = wt.flags();
|
||||
wt_flags.row_contiguous = false;
|
||||
wt_flags.col_contiguous = true;
|
||||
wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
|
||||
|
||||
// Perform gemm
|
||||
std::vector<array> copies;
|
||||
std::vector<array> copies = {in_unfolded, wt_reshaped};
|
||||
return steel_matmul(
|
||||
s,
|
||||
d,
|
||||
/*a = */ in_unfolded,
|
||||
/*b = */ wt,
|
||||
/*b = */ wt_reshaped,
|
||||
/*c = */ out,
|
||||
/*M = */ unfolded_shape[0],
|
||||
/*N = */ conv_params.O,
|
||||
/*K = */ unfolded_shape[1],
|
||||
/*M = */ implicit_M,
|
||||
/*N = */ implicit_N,
|
||||
/*K = */ implicit_K,
|
||||
/*batch_size_out = */ 1,
|
||||
/*a_cols = */ unfolded_shape[1],
|
||||
/*b_cols = */ unfolded_shape[1],
|
||||
/*a_cols = */ implicit_K,
|
||||
/*b_cols = */ implicit_K,
|
||||
/*a_transposed = */ false,
|
||||
/*b_transposed = */ true,
|
||||
/*copies = */ copies);
|
||||
|
||||
Reference in New Issue
Block a user