Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
No copy gems (#801)
* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
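The first bullet is the heart of the change: adjacent batch dimensions whose strides already line up contiguously can be merged into a single dimension, so the GEMM kernels only have to walk one flat batch index (with stride 0 marking a broadcast operand). Below is a minimal, self-contained sketch of that merge for a single operand; the helper name collapse_batch_dims is hypothetical, while the diff itself uses MLX's collapse_contiguous_dims, which applies the same rule across all operands at once.

#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Hypothetical helper (sketch, single operand): merge adjacent batch dims
// whose strides are contiguous with each other, i.e. the outer dim's stride
// equals the inner dim's stride times the inner dim's size. Broadcast dims
// (stride 0) merge with neighbouring stride-0 dims the same way.
std::pair<std::vector<int>, std::vector<size_t>> collapse_batch_dims(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  std::vector<int> out_shape;
  std::vector<size_t> out_strides;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (!out_shape.empty() &&
        out_strides.back() == strides[i] * static_cast<size_t>(shape[i])) {
      // Fold this dim into the previous one.
      out_shape.back() *= shape[i];
      out_strides.back() = strides[i];
    } else {
      out_shape.push_back(shape[i]);
      out_strides.push_back(strides[i]);
    }
  }
  return {out_shape, out_strides};
}

int main() {
  // Batch dims of a row-contiguous (2, 3, M, K) tensor with M * K == 1024:
  // the batch strides are {3 * 1024, 1024} and collapse to a single dim.
  auto [shape, strides] = collapse_batch_dims({2, 3}, {3072, 1024});
  std::cout << shape[0] << " @ stride " << strides[0] << "\n"; // 6 @ stride 1024
}

MLX's collapse_contiguous_dims (used in the diff below) performs the same merge, but checks the condition against every operand's strides at once, so a dimension is only merged when the merge is valid for A and B (and C for addmm) simultaneously.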
@@ -191,6 +191,70 @@ inline void mps_matmul(
  });
}

inline auto collapse_batches(const array& a, const array& b) {
  // Get and check the shape for the batched dims
  std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
  std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
  if (A_bshape != B_bshape) {
    std::ostringstream msg;
    msg << "[matmul] Got matrices with incorrectly broadcasted shapes: "
        << "A " << a.shape() << ", B " << b.shape() << ".";
    throw std::runtime_error(msg.str());
  }

  std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
  std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};

  auto [batch_shape, batch_strides] =
      collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride});

  auto A_batch_stride = batch_strides[0];
  auto B_batch_stride = batch_strides[1];

  if (batch_shape.empty()) {
    batch_shape.push_back(1);
    A_batch_stride.push_back(0);
    B_batch_stride.push_back(0);
  }

  return std::make_tuple(batch_shape, A_batch_stride, B_batch_stride);
}
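// Example of what the collapse above produces (illustrative, not from the
// diff): for a row-contiguous A with shape (2, 3, M, K) multiplied against a
// B broadcast over the same batch, the two batch dims merge into one:
//   batch_shape    == {6}
//   A_batch_stride == {M * K}
//   B_batch_stride == {0}   // broadcast dims keep stride 0
// so the kernels below only ever index a flat batch dimension.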

inline auto collapse_batches(const array& a, const array& b, const array& c) {
  // Get and check the shape for the batched dims
  std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
  std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
  std::vector<int> C_bshape{c.shape().begin(), c.shape().end() - 2};
  if (A_bshape != B_bshape || A_bshape != C_bshape) {
    std::ostringstream msg;
    msg << "[addmm] Got matrices with incorrectly broadcasted shapes: "
        << "A " << a.shape() << ", B " << b.shape() << ", C " << c.shape()
        << ".";
    throw std::runtime_error(msg.str());
  }

  std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
  std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
  std::vector<size_t> C_bstride{c.strides().begin(), c.strides().end() - 2};

  auto [batch_shape, batch_strides] =
      collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride, C_bstride});

  auto A_batch_stride = batch_strides[0];
  auto B_batch_stride = batch_strides[1];
  auto C_batch_stride = batch_strides[2];

  if (batch_shape.empty()) {
    batch_shape.push_back(1);
    A_batch_stride.push_back(0);
    B_batch_stride.push_back(0);
    C_batch_stride.push_back(0);
  }

  return std::make_tuple(
      batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
}

} // namespace

///////////////////////////////////////////////////////////////////////////////

@@ -211,22 +275,33 @@ void steel_matmul(
    int ldb,
    bool transpose_a,
    bool transpose_b,
    std::vector<array>& copies) {
    std::vector<array>& copies,
    std::vector<int> batch_shape /* = {} */,
    std::vector<size_t> A_batch_stride /* = {} */,
    std::vector<size_t> B_batch_stride /* = {} */) {
  using namespace mlx::steel;

  // Coalesce (B, M, K) X (K, N) to (B*M, K) X (K, N)
  if (batch_size_out > 1 && !transpose_a &&
      a.data_size() == batch_size_out * M * K && b.size() == K * N) {
    M = M * batch_size_out;
    batch_size_out = 1;
  if (batch_shape.empty()) {
    /////////////////////////////////////////////////////////////////////////////
    // Check and collapse batch dimensions
    auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b);

    batch_shape = batch_shape_;
    A_batch_stride = A_bstride_;
    B_batch_stride = B_bstride_;
    // Collapse batches into M if needed
    if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
        a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
        B_batch_stride.back() == 0) {
      M *= batch_shape.back();
      batch_size_out = 1;

      A_batch_stride = {0};
      B_batch_stride = {0};
      batch_shape = {1};
    }
  }
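  // Note on the branch above: when A is laid out row-contiguously over
  // (batch, M, K) and B is broadcast along the batch (stride 0), the batched
  // product (batch, M, K) @ (K, N) touches memory exactly like a single
  // (batch * M, K) @ (K, N) GEMM, so the batch is folded into M and only one
  // threadgroup grid is launched (batch_size_out == 1).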

  // Account for batch sizes and basic broadcasting
  int batch_size_a = a.data_size() / (M * K);
  int batch_size_b = b.data_size() / (K * N);

  int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K;
  int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
  int matrix_stride_out = M * N;

  /////////////////////////////////////////////////////////////////////////////

@@ -269,18 +344,18 @@ void steel_matmul(
    int tm = (M + bm - 1) / bm;

    GEMMSpiltKParams params{
        M,
        N,
        K,
        lda,
        ldb,
        N,
        tn,
        tm,
        split_k_partitions,
        split_k_partition_stride,
        split_k_partition_size,
        gemm_k_iterations};
        /* const int M = */ M,
        /* const int N = */ N,
        /* const int K = */ K,
        /* const int lda = */ lda,
        /* const int ldb = */ ldb,
        /* const int ldc = */ N,
        /* const int tiles_n = */ tn,
        /* const int tiles_m = */ tm,
        /* const int split_k_partitions = */ split_k_partitions,
        /* const int split_k_partition_stride = */ split_k_partition_stride,
        /* const int split_k_partition_size = */ split_k_partition_size,
        /* const int gemm_k_iterations_aligned = */ gemm_k_iterations};

    MTL::Size group_dims = MTL::Size(32, wn, wm);
    MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);

@@ -364,19 +439,20 @@ void steel_matmul(

  // Prepare steel matmul params
  GEMMParams params{
      M,
      N,
      K,
      lda,
      ldb,
      N,
      tn,
      tm,
      matrix_stride_a,
      matrix_stride_b,
      matrix_stride_out,
      swizzle_log,
      (K / bk)};
      /* const int M = */ M,
      /* const int N = */ N,
      /* const int K = */ K,
      /* const int lda = */ lda,
      /* const int ldb = */ ldb,
      /* const int ldd = */ N,
      /* const int tiles_n = */ tn,
      /* const int tiles_m = */ tm,
      /* const int batch_stride_a = */ int(A_batch_stride.back()),
      /* const int batch_stride_b = */ int(B_batch_stride.back()),
      /* const int batch_stride_d = */ matrix_stride_out,
      /* const int swizzle_log = */ swizzle_log,
      /* const int gemm_k_iterations_aligned = */ (K / bk),
      /* const int batch_ndim = */ int(batch_shape.size())};

  // Prepare launch grid params
  int tile = 1 << swizzle_log;

@@ -386,37 +462,25 @@ void steel_matmul(
  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);

  // Launch only 1 kernel in the case of simple batching / broadcasting
  if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
      (batch_size_a == batch_size_b ||
       std::min(batch_size_a, batch_size_b) == 1)) {
    set_array_buffer(compute_encoder, a, 0);
    set_array_buffer(compute_encoder, b, 1);
    set_array_buffer(compute_encoder, out, 2);
  std::vector<size_t> batch_strides = A_batch_stride;
  batch_strides.insert(
      batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());

    compute_encoder->setBytes(&params, sizeof(GEMMParams), 3);
    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
  } else { // Otherwise launch kernels with set offsets
  // Launch kernel
    set_array_buffer(compute_encoder, a, 0);
    set_array_buffer(compute_encoder, b, 1);
    set_array_buffer(compute_encoder, out, 3);

    MTL::Size grid_dims_single = MTL::Size(tn, tm, 1);
    compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);

    for (int i = 0; i < batch_size_out; ++i) {
      auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
      auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
  compute_encoder->setBytes(
      batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
  compute_encoder->setBytes(
      batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);

      auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
      auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
      auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());

      compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0);
      compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
      compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 2);

      compute_encoder->setBytes(&params, sizeof(GEMMParams), 3);
      compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims);
    }
  }
  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);

  // Clear copies
  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
  return;
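  // With batch_shape (buffer 6) and the concatenated A/B batch strides
  // (buffer 7) passed to the kernel, a single dispatchThreadgroups over
  // (tn, tm, batch_size_out) replaces the old host-side loop that computed
  // per-batch offsets with elem_to_loc and re-bound the buffers each launch.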

@@ -453,9 +517,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto check_transpose = [&copies, &s](const array& arr) {
    auto stx = arr.strides()[arr.ndim() - 2];
    auto sty = arr.strides()[arr.ndim() - 1];
    if (stx == arr.shape(-1) && sty == 1) {
    if (sty == 1) {
      return std::make_tuple(false, stx, arr);
    } else if (stx == 1 && sty == arr.shape(-2)) {
    } else if (stx == 1) {
      return std::make_tuple(true, sty, arr);
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});

@@ -473,8 +537,25 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  int N = b.shape(-1);
  int K = a.shape(-1);

  /////////////////////////////////////////////////////////////////////////////
  // Check and collapse batch dimensions

  auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);

  auto batch_size_out = out.size() / (M * N);

  // Collapse batches into M if needed
  if (batch_size_out > 1 && !a_transposed && batch_shape.size() == 1 &&
      a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
      B_batch_stride.back() == 0) {
    M *= batch_shape.back();
    batch_size_out = 1;

    A_batch_stride = {0};
    B_batch_stride = {0};
    batch_shape = {1};
  }

  /////////////////////////////////////////////////////////////////////////////
  // Gemv specialization

@@ -491,20 +572,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {

    int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
    int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
    int mat_ld = is_b_matrix ? b_cols : a_cols;

    int batch_size_mat = mat.data_size() / (mat_cols * mat_rows);
    int stride_mat = batch_size_mat == 1 ? 0 : mat_cols * mat_rows;
    auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
    auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;

    int batch_size_vec = vec.data_size() / in_vector_len;
    int stride_vec = batch_size_vec == 1 ? 0 : in_vector_len;
    int stride_mat = batch_strides_mat.back();
    int stride_vec = batch_strides_vec.back();

    // Determine if inputs have simple batching / broadcasting
    bool contiguous_kernel =
        (batch_size_out == std::max(batch_size_mat, batch_size_vec) &&
         (batch_size_mat == batch_size_vec ||
          std::min(batch_size_mat, batch_size_vec) == 1));
    bool contiguous_kernel = (batch_shape.size() == 1);

    int nc_dim = out.ndim() - 2;
    int batch_ndim = batch_shape.size();

    // Determine dispatch kernel
    int tm = 4, tn = 4;

@@ -540,10 +619,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
    }

    kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;

    if (!contiguous_kernel) {
      kname << "_nc";
    }
    kname << "_nc" << !contiguous_kernel << "_axpby0";
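    // The gemv kernel name now always encodes both specializations: "_nc0" /
    // "_nc1" for contiguous vs. general batch indexing, and "_axpby0" /
    // "_axpby1" for whether the epilogue adds the scaled bias
    // (AddMM::eval_gpu below picks the "_axpby1" variant).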

    // Encode and dispatch kernel
    auto compute_encoder = d.get_command_encoder(s.index);

@@ -556,25 +632,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {

    set_array_buffer(compute_encoder, mat, 0);
    set_array_buffer(compute_encoder, vec, 1);
    set_array_buffer(compute_encoder, out, 2);
    set_array_buffer(compute_encoder, out, 3);

    compute_encoder->setBytes(&in_vector_len, sizeof(int), 3);
    compute_encoder->setBytes(&out_vector_len, sizeof(int), 4);
    compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
    compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
    compute_encoder->setBytes(&mat_ld, sizeof(int), 6);

    if (contiguous_kernel) {
      compute_encoder->setBytes(&stride_vec, sizeof(int), 5);
      compute_encoder->setBytes(&stride_mat, sizeof(int), 6);
    } else {
      // In case of complex broadcasting, we consider the shape[:-2] and
      // strides [:-2] to determine the location of a batch
      // nc_dim = out.ndim() - 2
      compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
      compute_encoder->setBytes(out.shape().data(), nc_dim * sizeof(int), 6);
      compute_encoder->setBytes(
          vec.strides().data(), nc_dim * sizeof(size_t), 7);
      compute_encoder->setBytes(
          mat.strides().data(), nc_dim * sizeof(size_t), 8);
    }
    compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
    compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10);
    compute_encoder->setBytes(
        batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11);
    compute_encoder->setBytes(
        batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);

    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);

@@ -606,20 +675,23 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  }

  return steel_matmul(
      s,
      d,
      a,
      b,
      out,
      M,
      N,
      K,
      batch_size_out,
      a_cols,
      b_cols,
      a_transposed,
      b_transposed,
      copies);
      /* const Stream& s = */ s,
      /* metal::Device& d = */ d,
      /* const array& a = */ a,
      /* const array& b = */ b,
      /* array& out = */ out,
      /* int M = */ M,
      /* int N = */ N,
      /* int K = */ K,
      /* int batch_size_out = */ batch_size_out,
      /* int lda = */ a_cols,
      /* int ldb = */ b_cols,
      /* bool transpose_a = */ a_transposed,
      /* bool transpose_b = */ b_transposed,
      /* std::vector<array>& = */ copies,
      /* std::vector<int> batch_shape = */ batch_shape,
      /* std::vector<size_t> A_batch_stride = */ A_batch_stride,
      /* std::vector<size_t> B_batch_stride = */ B_batch_stride);
}

void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {

@@ -645,9 +717,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto check_transpose = [&copies, &s](const array& arr) {
    auto stx = arr.strides()[arr.ndim() - 2];
    auto sty = arr.strides()[arr.ndim() - 1];
    if (stx == arr.shape(-1) && sty == 1) {
    if (sty == 1) {
      return std::make_tuple(false, stx, arr);
    } else if (stx == 1 && sty == arr.shape(-2)) {
    } else if (stx == 1) {
      return std::make_tuple(true, sty, arr);
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});

@@ -665,33 +737,151 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  int N = b.shape(-1);
  int K = a.shape(-1);

  auto batch_size_out = out.size() / (M * N);

  array c = c_pre;
  int ldc = c.strides()[c.ndim() - 2];
  int fdc = c.strides()[c.ndim() - 1];
  int matrix_stride_c = c.ndim() <= 2 ? 0 : c.strides()[c.ndim() - 3];

  int lda = a_cols;
  int ldb = b_cols;
  int ldd = N;

  /////////////////////////////////////////////////////////////////////////////
  // Check and collapse batch dimensions
  auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =
      collapse_batches(a, b, c);

  auto batch_size_out = out.size() / (M * N);

  // Collapse batches into M if needed
  if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
      a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
      C_batch_stride.back() == M * c.strides()[c.ndim() - 2] &&
      B_batch_stride.back() == 0) {
    M *= batch_shape.back();
    batch_size_out = 1;

    A_batch_stride = {0};
    B_batch_stride = {0};
    C_batch_stride = {0};
    batch_shape = {1};
  }

  int matrix_stride_out = M * N;

  /////////////////////////////////////////////////////////////////////////////
  // Gemv specialization

  // Route to gemv if needed
  if (std::min(M, N) == 1) {
    // Collect problem info
    bool is_b_matrix = N != 1;

    auto& mat = is_b_matrix ? b : a;
    auto& vec = is_b_matrix ? a : b;
    bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
    int in_vector_len = K;
    int out_vector_len = is_b_matrix ? N : M;

    int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
    int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
    int mat_ld = is_b_matrix ? b_cols : a_cols;

    auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
    auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;

    int stride_mat = batch_strides_mat.back();
    int stride_vec = batch_strides_vec.back();

    // Determine if inputs have simple batching / broadcasting
    bool contiguous_kernel = (batch_shape.size() == 1);

    int batch_ndim = batch_shape.size();

    // Determine dispatch kernel
    int tm = 4, tn = 4;
    int bm, bn, n_out_per_tgp;
    std::ostringstream kname;

    if (transpose_mat) {
      bm = 8;
      bn = 8;
      if (out_vector_len >= 24576) {
        bn = 128;
      } else if (out_vector_len >= 16384) {
        bn = 64;
      } else if (out_vector_len >= 8192) {
        bn = 16;
      }

      // Specialized kernel for very small outputs
      tn = out_vector_len < tn ? 1 : tn;

      n_out_per_tgp = bn * tn;
      kname << "gemv_t_" << type_to_name(out);

    } else {
      bm = out_vector_len >= 4096 ? 8 : 4;
      bn = 32;

      // Specialized kernel for very small outputs
      tm = out_vector_len < tm ? 1 : tm;

      n_out_per_tgp = bm * tm;
      kname << "gemv_" << type_to_name(out);
    }

    kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
    kname << "_nc" << !contiguous_kernel << "_axpby1";

    // Encode and dispatch kernel
    auto compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);

    int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
    MTL::Size group_dims = MTL::Size(bn, bm, 1);
    MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);

    set_array_buffer(compute_encoder, mat, 0);
    set_array_buffer(compute_encoder, vec, 1);
    set_array_buffer(compute_encoder, c, 2);
    set_array_buffer(compute_encoder, out, 3);

    compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
    compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
    compute_encoder->setBytes(&mat_ld, sizeof(int), 6);

    compute_encoder->setBytes(&alpha_, sizeof(float), 7);
    compute_encoder->setBytes(&beta_, sizeof(float), 8);

    compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
    compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10);
    compute_encoder->setBytes(
        batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11);
    compute_encoder->setBytes(
        batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
    compute_encoder->setBytes(
        C_batch_stride.data(), batch_ndim * sizeof(size_t), 13);

    int bias_stride = c.strides()[c.ndim() - 1];
    compute_encoder->setBytes(&bias_stride, sizeof(int), 14);

    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);

    d.get_command_buffer(s.index)->addCompletedHandler(
        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
    return;
  }

  using namespace mlx::steel;

  // Account for batch sizes and basic broadcasting
  int batch_size_a = a.data_size() / (M * K);
  int batch_size_b = b.data_size() / (K * N);

  int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K;
  int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
  int matrix_stride_out = M * N;
  /////////////////////////////////////////////////////////////////////////////
  // Split K specialization

  int _tm = M / 16;
  int _tn = N / 16;
  int _tk = K / 16;

  /////////////////////////////////////////////////////////////////////////////
  // Split K specialization

  if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
    int bm = M < 40 ? 16 : 32;
    int bn = N < 40 ? 16 : 32;

@@ -817,25 +1007,29 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  // TODO: Explore device-based tuning for swizzle
  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);

  // Prepare steel matmul params
  GEMMParams gemm_params{
      /* const int M = */ M,
      /* const int N = */ N,
      /* const int K = */ K,
      /* const int lda = */ lda,
      /* const int ldb = */ ldb,
      /* const int ldd = */ N,
      /* const int tiles_n = */ tn,
      /* const int tiles_m = */ tm,
      /* const int batch_stride_a = */ int(A_batch_stride.back()),
      /* const int batch_stride_b = */ int(B_batch_stride.back()),
      /* const int batch_stride_d = */ matrix_stride_out,
      /* const int swizzle_log = */ swizzle_log,
      /* const int gemm_k_iterations_aligned = */ (K / bk),
      /* const int batch_ndim = */ int(batch_shape.size())};

  GEMMAddMMParams params{
      M,
      N,
      K,
      lda,
      ldb,
      ldc,
      N,
      tn,
      tm,
      matrix_stride_a,
      matrix_stride_b,
      matrix_stride_c,
      matrix_stride_out,
      swizzle_log,
      (K / bk),
      alpha_,
      beta_,
      fdc};
      /* const int ldc = */ ldc,
      /* const int fdc = */ fdc,
      /* const int batch_stride_c = */ int(C_batch_stride.back()),
      /* const float alpha = */ alpha_,
      /* const float beta = */ beta_};
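  // The shared GEMM geometry (shapes, leading dims, batch strides, swizzle)
  // now lives in GEMMParams, while GEMMAddMMParams carries only the
  // addmm-specific epilogue fields: ldc, fdc, the C batch stride, and the
  // alpha / beta scalars. Both structs are bound separately below
  // (indices 4 and 5).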

  int tile = 1 << swizzle_log;
  tm = (tm + tile - 1) / tile;

@@ -844,40 +1038,27 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);

  // Launch only 1 kernel in the case of simple batching / broadcasting
  if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
      (batch_size_a == batch_size_b ||
       std::min(batch_size_a, batch_size_b) == 1)) {
    set_array_buffer(compute_encoder, a, 0);
    set_array_buffer(compute_encoder, b, 1);
    set_array_buffer(compute_encoder, c, 2);
    set_array_buffer(compute_encoder, out, 3);
  std::vector<size_t> batch_strides = A_batch_stride;
  batch_strides.insert(
      batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
  batch_strides.insert(
      batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());

    compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 4);
    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
  } else { // Otherwise launch kernels with set offsets
  // Launch kernel
    set_array_buffer(compute_encoder, a, 0);
    set_array_buffer(compute_encoder, b, 1);
    set_array_buffer(compute_encoder, c, 2);
    set_array_buffer(compute_encoder, out, 3);

    MTL::Size grid_dims_single = MTL::Size(tn, tm, 1);
  compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4);
  compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 5);

    for (int i = 0; i < batch_size_out; ++i) {
      auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
      auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
      auto c_off = elem_to_loc(M * N * i, c.shape(), c.strides());
  compute_encoder->setBytes(
      batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
  compute_encoder->setBytes(
      batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);

      auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
      auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
      auto c_buf = static_cast<const MTL::Buffer*>(c.buffer().ptr());
      auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());

      compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0);
      compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
      compute_encoder->setBuffer(c_buf, c_off * c.itemsize(), 2);
      compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 3);

      compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 4);
      compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims);
    }
  }
  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);

  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
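For context, here is a minimal sketch of the kind of call that stops triggering a copy with this change. It is written against the public mlx::core C++ API (the "mlx/mlx.h" header and the random::normal, slice, matmul, and eval ops are assumed from the MLX source tree; they are not part of this diff):

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // A row-contiguous (64, 128) matrix and a column slice of it. The slice is
  // a strided view: its last-axis stride is 1, but its row stride (128) no
  // longer equals its width (32).
  array a = random::normal({64, 128});
  array a_cols = slice(a, {0, 0}, {64, 32});

  // Before this change, check_transpose also required stx == shape(-1), so
  // a_cols was packed into a temporary copy before the GEMM. Now only
  // sty == 1 is needed and the kernel reads the view directly with lda = 128.
  array b = random::normal({32, 16});
  array y = matmul(a_cols, b); // (64, 16)
  eval(y);
  return 0;
}

The AddMM path follows the same pattern, additionally threading C's collapsed batch strides through GEMMAddMMParams so a bias whose batch strides are irregular (for example broadcast) is handled directly.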