mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Collection of refactors (#2274)
* Refactor gemv into a function
* Refactor splitk step 1
* Refactor split k axpby
* Rearrange steel_gemm_regular
* Redirect steel_gemm_regular
* Add axpby routing to steel_matmul_regular
* Refactor AddMM step 1
* Redirect steel_gemm
* Update addmm
* Comments and format
* Some cleanup
* Add architecture gen to device
* Update no copy condition in normalization to account for axis size 1
This commit is contained in:
@@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_unfolded, wt_transpose};
|
||||
return steel_matmul_regular(
|
||||
s,
|
||||
d,
|
||||
/* a = */ in_unfolded,
|
||||
/* b = */ wt_transpose,
|
||||
/* c = */ out,
|
||||
/* M = */ implicit_M,
|
||||
/* N = */ implicit_N,
|
||||
/* K = */ implicit_K,
|
||||
/* batch_size_out = */ groups,
|
||||
/* a_cols = */ implicit_K * groups,
|
||||
/* b_cols = */ implicit_K,
|
||||
/* out_cols = */ implicit_N * groups,
|
||||
/* a_transposed = */ false,
|
||||
/* b_transposed = */ true,
|
||||
/* batch_shape = */ {1},
|
||||
/* batch_strides = */ {0},
|
||||
/* A_batch_strides = */ size_t(implicit_K),
|
||||
/* B_batch_strides = */ size_t(implicit_N) * implicit_K,
|
||||
/* matrix_stride_out = */ size_t(implicit_N),
|
||||
/*copies = */ copies);
|
||||
/* const Stream& s = */ s,
|
||||
/* Device& d = */ d,
|
||||
/* const array& a = */ in_unfolded,
|
||||
/* const array& b = */ wt_transpose,
|
||||
/* array& c = */ out,
|
||||
/* int M = */ implicit_M,
|
||||
/* int N = */ implicit_N,
|
||||
/* int K = */ implicit_K,
|
||||
/* int batch_size_out = */ groups,
|
||||
/* int lda = */ implicit_K * groups,
|
||||
/* int ldb = */ implicit_K,
|
||||
/* int ldd = */ implicit_N * groups,
|
||||
/* bool transpose_a = */ false,
|
||||
/* bool transpose_b = */ true,
|
||||
/* std::vector<array>& copies = */ copies,
|
||||
/* Shape batch_shape = */ {1},
|
||||
/* Strides batch_strides = */ {0},
|
||||
/* int64_t A_batch_strides = */ int64_t(implicit_K),
|
||||
/* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
|
||||
/* int64_t matrix_stride_out = */ int64_t(implicit_N));
|
||||
}
|
||||
|
||||
void implicit_gemm_conv_2D_gpu(
|
||||
|
||||
Reference in New Issue
Block a user