mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Rearrange steel_gemm_regular
This commit is contained in:
parent
c2f1c2a338
commit
96a7017442
@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu(
|
|||||||
// Perform gemm
|
// Perform gemm
|
||||||
std::vector<array> copies = {in_unfolded, wt_transpose};
|
std::vector<array> copies = {in_unfolded, wt_transpose};
|
||||||
return steel_matmul_regular(
|
return steel_matmul_regular(
|
||||||
s,
|
/* const Stream& s = */ s,
|
||||||
d,
|
/* Device& d = */ d,
|
||||||
/* a = */ in_unfolded,
|
/* const array& a = */ in_unfolded,
|
||||||
/* b = */ wt_transpose,
|
/* const array& b = */ wt_transpose,
|
||||||
/* c = */ out,
|
/* array& c = */ out,
|
||||||
/* M = */ implicit_M,
|
/* int M = */ implicit_M,
|
||||||
/* N = */ implicit_N,
|
/* int N = */ implicit_N,
|
||||||
/* K = */ implicit_K,
|
/* int K = */ implicit_K,
|
||||||
/* batch_size_out = */ groups,
|
/* int batch_size_out = */ groups,
|
||||||
/* a_cols = */ implicit_K * groups,
|
/* int lda = */ implicit_K * groups,
|
||||||
/* b_cols = */ implicit_K,
|
/* int ldb = */ implicit_K,
|
||||||
/* out_cols = */ implicit_N * groups,
|
/* int ldd = */ implicit_N * groups,
|
||||||
/* a_transposed = */ false,
|
/* bool transpose_a = */ false,
|
||||||
/* b_transposed = */ true,
|
/* bool transpose_b = */ true,
|
||||||
/* batch_shape = */ {1},
|
/* std::vector<array>& copies = */ copies,
|
||||||
/* batch_strides = */ {0},
|
/* Shape batch_shape = */ {1},
|
||||||
/* A_batch_strides = */ size_t(implicit_K),
|
/* Strides batch_strides = */ {0},
|
||||||
/* B_batch_strides = */ size_t(implicit_N) * implicit_K,
|
/* int64_t A_batch_strides = */ int64_t(implicit_K),
|
||||||
/* matrix_stride_out = */ size_t(implicit_N),
|
/* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
|
||||||
/*copies = */ copies);
|
/* int64_t matrix_stride_out = */ int64_t(implicit_N));
|
||||||
}
|
}
|
||||||
|
|
||||||
void implicit_gemm_conv_2D_gpu(
|
void implicit_gemm_conv_2D_gpu(
|
||||||
|
@ -179,12 +179,12 @@ void steel_matmul_regular(
|
|||||||
int ldd,
|
int ldd,
|
||||||
bool transpose_a,
|
bool transpose_a,
|
||||||
bool transpose_b,
|
bool transpose_b,
|
||||||
|
std::vector<array>& copies,
|
||||||
Shape batch_shape,
|
Shape batch_shape,
|
||||||
Strides batch_strides,
|
Strides batch_strides,
|
||||||
int64_t A_batch_stride,
|
int64_t A_batch_stride,
|
||||||
int64_t B_batch_stride,
|
int64_t B_batch_stride,
|
||||||
int64_t matrix_stride_out,
|
int64_t matrix_stride_out) {
|
||||||
std::vector<array>& copies) {
|
|
||||||
using namespace mlx::steel;
|
using namespace mlx::steel;
|
||||||
|
|
||||||
// Determine dispatch kernel
|
// Determine dispatch kernel
|
||||||
@ -563,12 +563,12 @@ void steel_matmul(
|
|||||||
/* int ldd = */ N,
|
/* int ldd = */ N,
|
||||||
/* bool transpose_a = */ transpose_a,
|
/* bool transpose_a = */ transpose_a,
|
||||||
/* bool transpose_b = */ transpose_b,
|
/* bool transpose_b = */ transpose_b,
|
||||||
|
/* std::vector<array>& copies = */ copies,
|
||||||
/* Shape batch_shape = */ std::move(batch_shape),
|
/* Shape batch_shape = */ std::move(batch_shape),
|
||||||
/* Strides batch_strides = */ std::move(batch_strides),
|
/* Strides batch_strides = */ std::move(batch_strides),
|
||||||
/* int64_t A_batch_stride = */ A_batch_stride.back(),
|
/* int64_t A_batch_stride = */ A_batch_stride.back(),
|
||||||
/* int64_t B_batch_stride = */ B_batch_stride.back(),
|
/* int64_t B_batch_stride = */ B_batch_stride.back(),
|
||||||
/* int64_t matrix_stride_out = */ matrix_stride_out,
|
/* int64_t matrix_stride_out = */ matrix_stride_out);
|
||||||
/* std::vector<array>& copies = */ copies);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <bool CHECK_AB = true>
|
template <bool CHECK_AB = true>
|
||||||
|
@ -21,12 +21,12 @@ void steel_matmul_regular(
|
|||||||
int ldd,
|
int ldd,
|
||||||
bool transpose_a,
|
bool transpose_a,
|
||||||
bool transpose_b,
|
bool transpose_b,
|
||||||
|
std::vector<array>& copies,
|
||||||
Shape batch_shape,
|
Shape batch_shape,
|
||||||
Strides batch_strides,
|
Strides batch_strides,
|
||||||
int64_t A_batch_stride,
|
int64_t A_batch_stride,
|
||||||
int64_t B_batch_stride,
|
int64_t B_batch_stride,
|
||||||
int64_t matrix_stride_out,
|
int64_t matrix_stride_out);
|
||||||
std::vector<array>& copies);
|
|
||||||
|
|
||||||
void steel_matmul(
|
void steel_matmul(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
|
Loading…
Reference in New Issue
Block a user