Collection of refactors (#2274)

* Refactor gemv into a function

* Refactor splitk step 1

* Refactor split k axpby

* Rearrange steel_gemm_regular

* Redirect steel_gemm_regular

* Add axpby routing to steel_matmul_regular

* Refactor AddMM step 1

* Redirect steel_gemm

* Update addmm

* Comments and format

* Some cleanup

* Add architecture gen to device

* Update no copy condition in normalization to account for axis size 1
This commit is contained in:
Jagrit Digani
2025-06-13 10:44:56 -07:00
committed by GitHub
parent c8b4787e4e
commit fddb6933e1
7 changed files with 720 additions and 642 deletions

View File

@@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu(
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_transpose};
return steel_matmul_regular(
s,
d,
/* a = */ in_unfolded,
/* b = */ wt_transpose,
/* c = */ out,
/* M = */ implicit_M,
/* N = */ implicit_N,
/* K = */ implicit_K,
/* batch_size_out = */ groups,
/* a_cols = */ implicit_K * groups,
/* b_cols = */ implicit_K,
/* out_cols = */ implicit_N * groups,
/* a_transposed = */ false,
/* b_transposed = */ true,
/* batch_shape = */ {1},
/* batch_strides = */ {0},
/* A_batch_strides = */ size_t(implicit_K),
/* B_batch_strides = */ size_t(implicit_N) * implicit_K,
/* matrix_stride_out = */ size_t(implicit_N),
/*copies = */ copies);
/* const Stream& s = */ s,
/* Device& d = */ d,
/* const array& a = */ in_unfolded,
/* const array& b = */ wt_transpose,
/* array& c = */ out,
/* int M = */ implicit_M,
/* int N = */ implicit_N,
/* int K = */ implicit_K,
/* int batch_size_out = */ groups,
/* int lda = */ implicit_K * groups,
/* int ldb = */ implicit_K,
/* int ldd = */ implicit_N * groups,
/* bool transpose_a = */ false,
/* bool transpose_b = */ true,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ {1},
/* Strides batch_strides = */ {0},
/* int64_t A_batch_strides = */ int64_t(implicit_K),
/* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
/* int64_t matrix_stride_out = */ int64_t(implicit_N));
}
void implicit_gemm_conv_2D_gpu(