mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Gemm update (#1518)
This commit is contained in:
@@ -88,6 +88,83 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
|
||||
// Steel matmul fallback
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Selects the steel GEMM tile parameters (bm, bn, bk) and the warp layout
// (wm, wn) from the device class, the output dtype, the transpose flags and
// the problem size.
//
// `devc` is a character identifying the device class: 'g' is handled as a
// small device, 'd' as a large device, and any other value as a medium
// device. The argument is evaluated exactly once and may be an arbitrary
// expression.
//
// The macro reads these names from the scope it is expanded in:
//   transpose_a, transpose_b, out (via out.dtype()), float32,
//   batch_size_out, M, N, K
// and assigns to these names, which must already be declared there:
//   bm, bn, bk, wm, wn
// Branches that assign nothing leave the caller's initial values in place
// (noted as "default" below).
//
// The expansion is a single braced block, so it can be used with or without
// a trailing semicolon.
#define GEMM_TPARAM_MACRO(devc)                                              \
  {                                                                          \
    /* Bind once: keeps operator precedence and side effects of the */      \
    /* argument expression under control. */                                 \
    const auto gemm_devc_ = (devc);                                          \
    if (gemm_devc_ == 'g') { /* Small device */                              \
      if (!transpose_a && transpose_b) { /* nt */                            \
        bm = 64;                                                             \
        bn = 32;                                                             \
        bk = 32;                                                             \
        wm = 2;                                                              \
        wn = 2;                                                              \
      } else if (out.dtype() != float32) { /* half and bfloat */             \
        bm = 64;                                                             \
        bn = 64;                                                             \
        bk = 16;                                                             \
        wm = 1;                                                              \
        wn = 2;                                                              \
      } /* remaining float cases take default */                             \
    } else if (gemm_devc_ == 'd') { /* Large device */                       \
      if ((size_t)batch_size_out * M * N >= 1ul << 20) { /* large matmul */  \
        if (out.dtype() != float32) { /* half and bfloat */                  \
          if (2 * std::max(M, N) > K) { /* Reasonable K */                   \
            bm = 64;                                                         \
            bn = 64;                                                         \
            bk = 16;                                                         \
            wm = 1;                                                          \
            wn = 2;                                                          \
          } else if (!transpose_a && transpose_b) { /* nt with large k */    \
            bm = 64;                                                         \
            bn = 32;                                                         \
            bk = 32;                                                         \
            wm = 2;                                                          \
            wn = 2;                                                          \
          } else { /* nn with large K */                                     \
            bm = 32;                                                         \
            bn = 64;                                                         \
            bk = 16;                                                         \
            wm = 1;                                                          \
            wn = 2;                                                          \
          }                                                                  \
        } /* float takes default */                                          \
      } else { /* smaller matmul */                                          \
        if (out.dtype() != float32) { /* half and bfloat */                  \
          if (!transpose_a && transpose_b) { /* nt */                        \
            bm = 64;                                                         \
            bn = 32;                                                         \
            bk = 32;                                                         \
            wm = 2;                                                          \
            wn = 2;                                                          \
          } else { /* nn */                                                  \
            bm = 64;                                                         \
            bn = 64;                                                         \
            bk = 16;                                                         \
            wm = 1;                                                          \
            wn = 2;                                                          \
          }                                                                  \
        } else { /* floats */                                                \
          if (!transpose_a && transpose_b) { /* nt */                        \
            bm = 32;                                                         \
            bn = 64;                                                         \
            bk = 16;                                                         \
            wm = 1;                                                          \
            wn = 2;                                                          \
          } else { /* nn */                                                  \
            bm = 64;                                                         \
            bn = 32;                                                         \
            bk = 32;                                                         \
            wm = 2;                                                          \
            wn = 2;                                                          \
          }                                                                  \
        }                                                                    \
      }                                                                      \
    } else { /* Medium device */                                             \
      bm = 64;                                                               \
      bn = 64;                                                               \
      bk = 16;                                                               \
      wm = 2;                                                                \
      wn = 2;                                                                \
    }                                                                        \
  }
|
||||
|
||||
void steel_matmul_regular(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
@@ -112,19 +189,11 @@ void steel_matmul_regular(
|
||||
using namespace mlx::steel;
|
||||
|
||||
// Determine dispatch kernel
|
||||
int bm = 32, bn = 32, bk = 16;
|
||||
int bm = 64, bn = 64, bk = 16;
|
||||
int wm = 2, wn = 2;
|
||||
|
||||
if ((size_t)batch_size_out * M * N >= 1ul << 20) {
|
||||
if (!transpose_a && transpose_b) {
|
||||
bm = 64;
|
||||
bn = (out.dtype() == float32) ? 64 : 32;
|
||||
bk = (out.dtype() == float32) ? 16 : 32;
|
||||
} else {
|
||||
bm = 64;
|
||||
bn = 64;
|
||||
}
|
||||
}
|
||||
char devc = d.get_architecture().back();
|
||||
GEMM_TPARAM_MACRO(devc)
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
@@ -903,19 +972,11 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Regular addmm dispatch
|
||||
|
||||
// Determine dispatch kernel
|
||||
int bm = 32, bn = 32, bk = 16;
|
||||
int bm = 64, bn = 64, bk = 16;
|
||||
int wm = 2, wn = 2;
|
||||
|
||||
if ((size_t)batch_size_out * M * N >= 1ul << 20) {
|
||||
if (!transpose_a && transpose_b) {
|
||||
bm = 64;
|
||||
bn = (out.dtype() == float32) ? 64 : 32;
|
||||
bk = (out.dtype() == float32) ? 16 : 32;
|
||||
} else {
|
||||
bm = 64;
|
||||
bn = 64;
|
||||
}
|
||||
}
|
||||
char devc = d.get_architecture().back();
|
||||
GEMM_TPARAM_MACRO(devc)
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
@@ -1667,19 +1728,11 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Regular kernel dispatch
|
||||
|
||||
// Determine dispatch kernel
|
||||
int bm = 32, bn = 32, bk = 16;
|
||||
int bm = 64, bn = 64, bk = 16;
|
||||
int wm = 2, wn = 2;
|
||||
|
||||
if ((size_t)batch_size_out * M * N >= 1ul << 20) {
|
||||
if (!transpose_a && transpose_b) {
|
||||
bm = 64;
|
||||
bn = (out.dtype() == float32) ? 64 : 32;
|
||||
bk = (out.dtype() == float32) ? 16 : 32;
|
||||
} else {
|
||||
bm = 64;
|
||||
bn = 64;
|
||||
}
|
||||
}
|
||||
char devc = d.get_architecture().back();
|
||||
GEMM_TPARAM_MACRO(devc)
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
|
||||
Reference in New Issue
Block a user