Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-24 17:31:16 +08:00
Refactor AddMM step 1

commit fc2f6bc51c (parent 9dbaa35be3)
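AddMM fuses a matrix multiply with a scaled accumulate, out = alpha * (a @ b) + beta * c; the removed code below reflects this in the do_axpby flag (!(alpha_ == 1. && beta_ == 1.)) and the alpha/beta fields of GEMMAddMMParams. Judging from the hunk, this first refactoring step replaces the inline steel GEMM kernel selection, parameter packing, and dispatch in AddMM::eval_gpu with a single call to the shared helper steel_matmul_regular_axpby, which is why the hunk is mostly deletions (124 lines down to 37).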
@@ -1000,124 +1000,37 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   /////////////////////////////////////////////////////////////////////////////
   // Regular addmm dispatch
 
-  // Determine dispatch kernel
-  int bm = 64, bn = 64, bk = 16;
-  int wm = 2, wn = 2;
-
-  char devc = d.get_architecture().back();
-  GEMM_TPARAM_MACRO(devc)
-
-  // Prepare kernel name
-  std::ostringstream kname;
-  kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
-        << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
-        << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
-        << "_wm" << wm << "_wn" << wn;
-
-  std::string base_name = kname.str();
-
-  const bool has_batch = (batch_shape.size() > 1);
-  const bool use_out_source = true;
-  const bool do_axpby = !(alpha_ == 1. && beta_ == 1.);
-  const bool align_M = (M % bm) == 0;
-  const bool align_N = (N % bn) == 0;
-  const bool align_K = (K % bk) == 0;
-
-  metal::MTLFCList func_consts = {
-      {&has_batch, MTL::DataType::DataTypeBool, 10},
-      {&use_out_source, MTL::DataType::DataTypeBool, 100},
-      {&do_axpby, MTL::DataType::DataTypeBool, 110},
-      {&align_M, MTL::DataType::DataTypeBool, 200},
-      {&align_N, MTL::DataType::DataTypeBool, 201},
-      {&align_K, MTL::DataType::DataTypeBool, 202},
-  };
-
-  // clang-format off
-  kname << "_has_batch_" << (has_batch ? 't' : 'n')
-        << "_use_out_source_" << (use_out_source ? 't' : 'n')
-        << "_do_axpby_" << (do_axpby ? 't' : 'n')
-        << "_align_M_" << (align_M ? 't' : 'n')
-        << "_align_N_" << (align_N ? 't' : 'n')
-        << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
-
-  std::string hash_name = kname.str();
-
-  // Encode and dispatch kernel
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = get_steel_gemm_fused_kernel(
-      d,
-      base_name,
-      hash_name,
-      func_consts,
-      out,
-      transpose_a,
-      transpose_b,
-      bm,
-      bn,
-      bk,
-      wm,
-      wn);
-
-  compute_encoder.set_compute_pipeline_state(kernel);
-
-  int tn = (N + bn - 1) / bn;
-  int tm = (M + bm - 1) / bm;
-
-  // TODO: Explore device-based tuning for swizzle
-  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
-
-  // Prepare steel matmul params
-  GEMMParams gemm_params{
-      /* const int M = */ M,
-      /* const int N = */ N,
-      /* const int K = */ K,
-      /* const int lda = */ lda,
-      /* const int ldb = */ ldb,
-      /* const int ldd = */ N,
-      /* const int tiles_n = */ tn,
-      /* const int tiles_m = */ tm,
-      /* const int64_t batch_stride_a = */ A_batch_stride.back(),
-      /* const int64_t batch_stride_b = */ B_batch_stride.back(),
-      /* const int64_t batch_stride_d = */ matrix_stride_out,
-      /* const int swizzle_log = */ swizzle_log,
-      /* const int gemm_k_iterations_aligned = */ (K / bk),
-      /* const int batch_ndim = */ int(batch_shape.size())};
-
-  GEMMAddMMParams params{
-      /* const int ldc = */ ldc,
-      /* const int fdc = */ fdc,
-      /* const int64_t batch_stride_c = */ C_batch_stride.back(),
-      /* const float alpha = */ alpha_,
-      /* const float beta = */ beta_};
-
-  int tile = 1 << swizzle_log;
-  tm = (tm + tile - 1) / tile;
-  tn = tn * tile;
-
-  MTL::Size group_dims = MTL::Size(32, wn, wm);
-  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
-
   Strides batch_strides = A_batch_stride;
   batch_strides.insert(
       batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
   batch_strides.insert(
       batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());
 
-  // Launch kernel
-  compute_encoder.set_input_array(a, 0);
-  compute_encoder.set_input_array(b, 1);
-  compute_encoder.set_input_array(c, 2);
-  compute_encoder.set_output_array(out, 3);
-
-  compute_encoder.set_bytes(gemm_params, 4);
-  compute_encoder.set_bytes(params, 5);
-
-  compute_encoder.set_vector_bytes(batch_shape, 6);
-  compute_encoder.set_vector_bytes(batch_strides, 7);
-
-  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
-
-  d.add_temporaries(std::move(copies), s.index);
+  return steel_matmul_regular_axpby(
+      /* const Stream& s = */ s,
+      /* metal::Device& d = */ d,
+      /* const array& a = */ a,
+      /* const array& b = */ b,
+      /* const array& c = */ c,
+      /* array& out = */ out,
+      /* int M = */ M,
+      /* int N = */ N,
+      /* int K = */ K,
+      /* int batch_size_out = */ batch_size_out,
+      /* int lda = */ lda,
+      /* int ldb = */ ldb,
+      /* int ldd = */ ldd,
+      /* bool transpose_a = */ transpose_a,
+      /* bool transpose_b = */ transpose_b,
+      /* std::vector<array>& copies = */ copies,
+      /* Shape batch_shape = */ std::move(batch_shape),
+      /* Strides batch_strides = */ std::move(batch_strides),
+      /* int64_t A_batch_stride = */ A_batch_stride.back(),
+      /* int64_t B_batch_stride = */ B_batch_stride.back(),
+      /* int64_t matrix_stride_out = */ int64_t(M) * ldd,
+      /* int64_t C_batch_stride = */ C_batch_stride.back(),
+      /* float alpha = */ alpha_,
+      /* float beta = */ beta_);
 }
 
 void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
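For readers tracing the call, here is a minimal sketch of the declaration the new call site implies. It is reconstructed purely from the /* ... */ annotations in the hunk above, not copied from the mlx headers; the actual prototype may differ (for example in const qualifiers, default arguments, or template parameters).

// Hypothetical declaration of the shared helper, inferred from the
// annotated arguments at the call site in AddMM::eval_gpu above.
void steel_matmul_regular_axpby(
    const Stream& s,
    metal::Device& d,
    const array& a,
    const array& b,
    const array& c,
    array& out,
    int M,
    int N,
    int K,
    int batch_size_out,
    int lda,
    int ldb,
    int ldd,
    bool transpose_a,
    bool transpose_b,
    std::vector<array>& copies,
    Shape batch_shape,
    Strides batch_strides,
    int64_t A_batch_stride,
    int64_t B_batch_stride,
    int64_t matrix_stride_out,
    int64_t C_batch_stride,
    float alpha,
    float beta);

Taking batch_shape and batch_strides by value (and receiving them via std::move at the call site) lets the helper own the batching metadata, which is consistent with the caller building batch_strides just before the call and never touching it again.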