Gemm update (#1518)

Jagrit Digani
2024-10-30 19:30:28 -07:00
committed by GitHub
parent 884af42da2
commit 960e3f0f05
9 changed files with 701 additions and 196 deletions


@@ -88,6 +88,83 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
 // Steel matmul fallback
 ///////////////////////////////////////////////////////////////////////////////
 
+#define GEMM_TPARAM_MACRO(devc) \
+  if (devc == 'g') { /* Small device */ \
+    if (!transpose_a && transpose_b) { /* nt */ \
+      bm = 64; \
+      bn = 32; \
+      bk = 32; \
+      wm = 2; \
+      wn = 2; \
+    } else if (out.dtype() != float32) { /* half and bfloat */ \
+      bm = 64; \
+      bn = 64; \
+      bk = 16; \
+      wm = 1; \
+      wn = 2; \
+    } \
+  } else if (devc == 'd') { /* Large device */ \
+    if ((size_t)batch_size_out * M * N >= 1ul << 20) { /* large matmul */ \
+      if (out.dtype() != float32) { /* half and bfloat */ \
+        if (2 * std::max(M, N) > K) { /* Reasonable K */ \
+          bm = 64; \
+          bn = 64; \
+          bk = 16; \
+          wm = 1; \
+          wn = 2; \
+        } else if (!transpose_a && transpose_b) { /* nt with large k */ \
+          bm = 64; \
+          bn = 32; \
+          bk = 32; \
+          wm = 2; \
+          wn = 2; \
+        } else { /* nn with large K */ \
+          bm = 32; \
+          bn = 64; \
+          bk = 16; \
+          wm = 1; \
+          wn = 2; \
+        } \
+      } /* float takes default */ \
+    } else { /* smaller matmul */ \
+      if (out.dtype() != float32) { /* half and bfloat */ \
+        if (!transpose_a && transpose_b) { /* nt */ \
+          bm = 64; \
+          bn = 32; \
+          bk = 32; \
+          wm = 2; \
+          wn = 2; \
+        } else { /* nn */ \
+          bm = 64; \
+          bn = 64; \
+          bk = 16; \
+          wm = 1; \
+          wn = 2; \
+        } \
+      } else { /* floats */ \
+        if (!transpose_a && transpose_b) { /* nt */ \
+          bm = 32; \
+          bn = 64; \
+          bk = 16; \
+          wm = 1; \
+          wn = 2; \
+        } else { /* nn */ \
+          bm = 64; \
+          bn = 32; \
+          bk = 32; \
+          wm = 2; \
+          wn = 2; \
+        } \
+      } \
+    } \
+  } else { /* Medium device */ \
+    bm = 64; \
+    bn = 64; \
+    bk = 16; \
+    wm = 2; \
+    wn = 2; \
+  }
+
 void steel_matmul_regular(
     const Stream& s,
     metal::Device& d,
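
For readability, the selection logic above can also be read as a plain function. The sketch below is illustrative only: `TileParams` and `select_gemm_tiles` are hypothetical names (not part of this change), `is_float32` stands in for `out.dtype() == float32`, and the starting values mirror the call sites' new defaults of bm = 64, bn = 64, bk = 16, wm = 2, wn = 2.

```cpp
#include <algorithm>
#include <cstddef>

// Hypothetical restatement of GEMM_TPARAM_MACRO as a function; these names
// do not exist in the codebase.
struct TileParams {
  int bm = 64, bn = 64, bk = 16, wm = 2, wn = 2; // call-site defaults
};

inline TileParams select_gemm_tiles(
    char devc, // last character of the device architecture string
    bool transpose_a,
    bool transpose_b,
    bool is_float32, // out.dtype() == float32
    size_t batch_size_out,
    int M,
    int N,
    int K) {
  TileParams t;
  const bool nt = !transpose_a && transpose_b;
  if (devc == 'g') { // small device
    if (nt) {
      t = {64, 32, 32, 2, 2};
    } else if (!is_float32) { // half and bfloat
      t = {64, 64, 16, 1, 2};
    }
  } else if (devc == 'd') { // large device
    if (batch_size_out * M * N >= (1ul << 20)) { // large matmul
      if (!is_float32) { // half and bfloat
        if (2 * std::max(M, N) > K) { // reasonable K
          t = {64, 64, 16, 1, 2};
        } else if (nt) { // nt with large K
          t = {64, 32, 32, 2, 2};
        } else { // nn with large K
          t = {32, 64, 16, 1, 2};
        }
      } // float32 keeps the defaults
    } else { // smaller matmul
      if (!is_float32) { // half and bfloat
        t = nt ? TileParams{64, 32, 32, 2, 2} : TileParams{64, 64, 16, 1, 2};
      } else { // floats
        t = nt ? TileParams{32, 64, 16, 1, 2} : TileParams{64, 32, 32, 2, 2};
      }
    }
  } // any other device class keeps the medium-device defaults
  return t;
}
```

The change itself keeps this as a macro so it can assign directly to the existing `bm`/`bn`/`bk`/`wm`/`wn` locals at each call site, as shown in the hunks below.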
@@ -112,19 +189,11 @@ void steel_matmul_regular(
   using namespace mlx::steel;
 
   // Determine dispatch kernel
-  int bm = 32, bn = 32, bk = 16;
+  int bm = 64, bn = 64, bk = 16;
   int wm = 2, wn = 2;
 
-  if ((size_t)batch_size_out * M * N >= 1ul << 20) {
-    if (!transpose_a && transpose_b) {
-      bm = 64;
-      bn = (out.dtype() == float32) ? 64 : 32;
-      bk = (out.dtype() == float32) ? 16 : 32;
-    } else {
-      bm = 64;
-      bn = 64;
-    }
-  }
+  char devc = d.get_architecture().back();
+  GEMM_TPARAM_MACRO(devc)
 
   // Prepare kernel name
   std::ostringstream kname;
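
Both the removed heuristic and the new macro key off `(size_t)batch_size_out * M * N >= 1ul << 20`, i.e. at least about one million output elements across the batch; the macro additionally checks `2 * std::max(M, N) > K` to detect a disproportionately large K. A small self-contained example with made-up shapes (none of these numbers come from the diff) shows how the two conditions evaluate:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative shapes only; nothing here comes from the diff.
  size_t batch_size_out = 1;
  int M = 4096, N = 4096, K = 64;

  // "large matmul": at least 2^20 (~1M) output elements across the batch.
  bool large_matmul = batch_size_out * M * N >= (1ul << 20); // 16.8M >= 1.05M

  // "Reasonable K": K is small relative to the output dimensions.
  bool reasonable_k = 2 * std::max(M, N) > K; // 8192 > 64

  std::printf("large_matmul=%d reasonable_k=%d\n", large_matmul, reasonable_k);
  // On a 'd' (large) device with a half/bfloat output, this combination picks
  // bm = 64, bn = 64, bk = 16, wm = 1, wn = 2 in GEMM_TPARAM_MACRO above.
}
```

The same two-line replacement (bumped defaults plus the architecture-keyed macro) is applied verbatim in the AddMM and GatherMM dispatch paths in the next two hunks.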
@@ -903,19 +972,11 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Regular addmm dispatch
 
   // Determine dispatch kernel
-  int bm = 32, bn = 32, bk = 16;
+  int bm = 64, bn = 64, bk = 16;
   int wm = 2, wn = 2;
 
-  if ((size_t)batch_size_out * M * N >= 1ul << 20) {
-    if (!transpose_a && transpose_b) {
-      bm = 64;
-      bn = (out.dtype() == float32) ? 64 : 32;
-      bk = (out.dtype() == float32) ? 16 : 32;
-    } else {
-      bm = 64;
-      bn = 64;
-    }
-  }
+  char devc = d.get_architecture().back();
+  GEMM_TPARAM_MACRO(devc)
 
   // Prepare kernel name
   std::ostringstream kname;
@@ -1667,19 +1728,11 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Regular kernel dispatch
 
   // Determine dispatch kernel
-  int bm = 32, bn = 32, bk = 16;
+  int bm = 64, bn = 64, bk = 16;
   int wm = 2, wn = 2;
 
-  if ((size_t)batch_size_out * M * N >= 1ul << 20) {
-    if (!transpose_a && transpose_b) {
-      bm = 64;
-      bn = (out.dtype() == float32) ? 64 : 32;
-      bk = (out.dtype() == float32) ? 16 : 32;
-    } else {
-      bm = 64;
-      bn = 64;
-    }
-  }
+  char devc = d.get_architecture().back();
+  GEMM_TPARAM_MACRO(devc)
 
   // Prepare kernel name
   std::ostringstream kname;