Refactor AddMM step 1

This commit is contained in:
Jagrit Digani 2025-06-11 09:01:45 -07:00
parent 9dbaa35be3
commit fc2f6bc51c

View File

@@ -1000,124 +1000,37 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Regular addmm dispatch // Regular addmm dispatch
// Determine dispatch kernel
int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2;
char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc)
// Prepare kernel name
std::ostringstream kname;
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn;
std::string base_name = kname.str();
const bool has_batch = (batch_shape.size() > 1);
const bool use_out_source = true;
const bool do_axpby = !(alpha_ == 1. && beta_ == 1.);
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10},
{&use_out_source, MTL::DataType::DataTypeBool, 100},
{&do_axpby, MTL::DataType::DataTypeBool, 110},
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
};
// clang-format off
kname << "_has_batch_" << (has_batch ? 't' : 'n')
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_fused_kernel(
d,
base_name,
hash_name,
func_consts,
out,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn);
compute_encoder.set_compute_pipeline_state(kernel);
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
// Prepare steel matmul params
GEMMParams gemm_params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ ldb,
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int64_t batch_stride_a = */ A_batch_stride.back(),
/* const int64_t batch_stride_b = */ B_batch_stride.back(),
/* const int64_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
GEMMAddMMParams params{
/* const int ldc = */ ldc,
/* const int fdc = */ fdc,
/* const int64_t batch_stride_c = */ C_batch_stride.back(),
/* const float alpha = */ alpha_,
/* const float beta = */ beta_};
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
Strides batch_strides = A_batch_stride; Strides batch_strides = A_batch_stride;
batch_strides.insert( batch_strides.insert(
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
batch_strides.insert( batch_strides.insert(
batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());
// Launch kernel return steel_matmul_regular_axpby(
compute_encoder.set_input_array(a, 0); /* const Stream& s = */ s,
compute_encoder.set_input_array(b, 1); /* metal::Device& d = */ d,
compute_encoder.set_input_array(c, 2); /* const array& a = */ a,
compute_encoder.set_output_array(out, 3); /* const array& b = */ b,
/* const array& c = */ c,
compute_encoder.set_bytes(gemm_params, 4); /* array& out = */ out,
compute_encoder.set_bytes(params, 5); /* int M = */ M,
/* int N = */ N,
compute_encoder.set_vector_bytes(batch_shape, 6); /* int K = */ K,
compute_encoder.set_vector_bytes(batch_strides, 7); /* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); /* int ldb = */ ldb,
/* int ldd = */ ldd,
d.add_temporaries(std::move(copies), s.index); /* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ std::move(batch_shape),
/* Strides batch_strides = */ std::move(batch_strides),
/* int64_t A_batch_stride = */ A_batch_stride.back(),
/* int64_t B_batch_stride = */ B_batch_stride.back(),
/* int64_t matrix_stride_out = */ int64_t(M) * ldd,
/* int64_t C_batch_stride = */ C_batch_stride.back(),
/* float alpha = */ alpha_,
/* float beta = */ beta_);
} }
void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) { void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {