Masked gemv (#1211)
@@ -786,38 +786,47 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {

     // Determine dispatch kernel
     int tm = 4, tn = 4;
-    int bm, bn, n_out_per_tgp;
+    int sm = 1, sn = 32;
+    int bm = 1, bn = 1;
+    int n_out_per_tgp;
     std::ostringstream kname;

     if (transpose_mat) {
-      bm = 8;
-      bn = 8;
-      if (out_vector_len >= 24576) {
-        bn = 128;
-      } else if (out_vector_len >= 16384) {
-        bn = 64;
-      } else if (out_vector_len >= 8192) {
-        bn = 32;
-      }
+      if (in_vector_len >= 8192 && out_vector_len >= 2048) {
+        sm = 4;
+        sn = 8;
+      } else {
+        sm = 8;
+        sn = 4;
+      }
+
+      if (out_vector_len >= 2048) {
+        bn = 16;
+      } else if (out_vector_len >= 512) {
+        bn = 4;
+      } else {
+        bn = 2;
+      }

       // Specialized kernel for very small outputs
       tn = out_vector_len < tn ? 1 : tn;

-      n_out_per_tgp = bn * tn;
+      n_out_per_tgp = bn * sn * tn;
       kname << "gemv_t_" << type_to_name(out);

     } else {
       bm = out_vector_len >= 4096 ? 8 : 4;
-      bn = 32;
+      sn = 32;

       // Specialized kernel for very small outputs
       tm = out_vector_len < tm ? 1 : tm;

-      n_out_per_tgp = bm * tm;
+      n_out_per_tgp = bm * sm * tm;
       kname << "gemv_" << type_to_name(out);
     }

-    kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
+    kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
+          << tm << "_tn" << tn;
     kname << "_nc" << !contiguous_kernel << "_axpby0";

     // Encode and dispatch kernel
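Note: the dispatch above now selects a simdgroup grid (sm x sn) in addition to the threadgroup tile (bm x bn) and the per-thread results (tm x tn), and bakes all six values into the kernel name used to look up the template instantiation. A minimal standalone sketch, not MLX code, of what the transposed branch requests for a hypothetical problem with in_vector_len = out_vector_len = 4096 and a float32 output ("float32" standing in for type_to_name(out)):

    // Sketch only: mirrors the transposed-branch parameter selection above.
    #include <iostream>
    #include <sstream>

    int main() {
      int in_vector_len = 4096, out_vector_len = 4096;
      int tm = 4, tn = 4;
      int sm = 1, sn = 32;
      int bm = 1, bn = 1;

      // transpose_mat == true branch of the dispatch logic
      if (in_vector_len >= 8192 && out_vector_len >= 2048) {
        sm = 4; sn = 8;
      } else {
        sm = 8; sn = 4;
      }
      if (out_vector_len >= 2048) {
        bn = 16;
      } else if (out_vector_len >= 512) {
        bn = 4;
      } else {
        bn = 2;
      }
      tn = out_vector_len < tn ? 1 : tn;

      int n_out_per_tgp = bn * sn * tn; // outputs each threadgroup produces

      std::ostringstream kname;
      kname << "gemv_t_float32" << "_bm" << bm << "_bn" << bn << "_sm" << sm
            << "_sn" << sn << "_tm" << tm << "_tn" << tn << "_nc0_axpby0";
      // Prints: gemv_t_float32_bm1_bn16_sm8_sn4_tm4_tn4_nc0_axpby0 (256 outputs/tgp)
      std::cout << kname.str() << " (" << n_out_per_tgp << " outputs/tgp)\n";
      return 0;
    }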
@@ -826,7 +835,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder->setComputePipelineState(kernel);

     int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
-    MTL::Size group_dims = MTL::Size(bn, bm, 1);
+    MTL::Size group_dims = MTL::Size(32, bn, bm);
     MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);

     compute_encoder.set_input_array(mat, 0);
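Note: each threadgroup is now launched as 32 x bn x bm, i.e. a full 32-lane simdgroup along x, so group_dims carries the simdgroup layout rather than a flat bn x bm block. The grid then covers the output vector in chunks of n_out_per_tgp via ceiling division. A small sketch of the sizing arithmetic, with made-up numbers:

    // Sketch only: the threadgroup/grid sizing used above.
    #include <cassert>

    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    int main() {
      int out_vector_len = 1000, n_out_per_tgp = 256;
      int bm = 1, bn = 16;

      int n_tgp = ceil_div(out_vector_len, n_out_per_tgp); // 4 threadgroups
      int threads_per_tgp = 32 * bn * bm; // 512 threads: 16 simdgroups of 32

      assert(n_tgp * n_out_per_tgp >= out_vector_len); // tail is covered
      assert(threads_per_tgp == 512);
      return 0;
    }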
@@ -838,11 +847,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder->setBytes(&mat_ld, sizeof(int), 6);

     compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
-    compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10);
-    compute_encoder->setBytes(
-        batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11);
-    compute_encoder->setBytes(
-        batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
+    set_vector_bytes(compute_encoder, batch_shape, 10);
+    set_vector_bytes(compute_encoder, batch_strides_vec, 11);
+    set_vector_bytes(compute_encoder, batch_strides_mat, 12);

     compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

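Note: the repeated setBytes calls are folded into a set_vector_bytes helper. Its definition is not part of this diff; judging from the calls it replaces, it is presumably a thin wrapper along these lines (name and signature assumed; the real helper lives elsewhere in the Metal backend and may differ):

    // Hedged sketch of a set_vector_bytes-style helper, inferred from the
    // setBytes calls it replaces; Encoder is whatever wrapper type
    // compute_encoder actually is.
    #include <vector>

    template <typename T, typename Encoder>
    void set_vector_bytes(Encoder& enc, const std::vector<T>& vec, int idx) {
      enc->setBytes(vec.data(), vec.size() * sizeof(T), idx);
    }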
@@ -910,15 +917,19 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep

+  int M = a_pre.shape(-2);
+  int N = b_pre.shape(-1);
+  int K = a_pre.shape(-1);
+
   // Keep a vector with copies to be cleared in the completed buffer to release
   // the arrays
   std::vector<array> copies;
-  auto check_transpose = [&copies, &s](const array& arr) {
+  auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
     auto stx = arr.strides()[arr.ndim() - 2];
     auto sty = arr.strides()[arr.ndim() - 1];
-    if (sty == 1) {
+    if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
       return std::make_tuple(false, stx, arr);
-    } else if (stx == 1) {
+    } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
       return std::make_tuple(true, sty, arr);
     } else {
       array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
@@ -929,12 +940,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
   };

-  auto [transpose_a, a_cols, a] = check_transpose(a_pre);
-  auto [transpose_b, b_cols, b] = check_transpose(b_pre);
-
-  int M = a.shape(-2);
-  int N = b.shape(-1);
-  int K = a.shape(-1);
+  auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
+  auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);

   array c = c_pre;
   int ldc = c.strides()[c.ndim() - 2];
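Note: M, N, and K move above check_transpose because the call sites now need them (M == 1, N == 1), and the is_vector flag tightens the layout test for gemv operands: a row (or column) vector must be fully contiguous, not merely unit-stride along one axis, otherwise it is copied before dispatch. A toy illustration of the row-major branch, with made-up shape and strides:

    // Sketch only: the relaxed check (sty == 1) would accept this view,
    // the vector-aware check does not.
    #include <cassert>
    #include <vector>

    int main() {
      // A (1, 8) row vector viewed inside a wider (1, 16) buffer.
      std::vector<long> shape = {1, 8};
      std::vector<long> strides = {16, 1};
      bool is_vector = (shape[0] == 1); // M == 1 at the call site

      long stx = strides[0], sty = strides[1];
      bool row_major_ok = (sty == 1) && (!is_vector || stx == shape[1]);

      // stx (16) != shape[1] (8): the gemv path copies to a dense buffer.
      assert(!row_major_ok);
      return 0;
    }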
@@ -997,38 +1004,47 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {

     // Determine dispatch kernel
     int tm = 4, tn = 4;
-    int bm, bn, n_out_per_tgp;
+    int sm = 1, sn = 32;
+    int bm = 1, bn = 1;
+    int n_out_per_tgp;
     std::ostringstream kname;

     if (transpose_mat) {
-      bm = 8;
-      bn = 8;
-      if (out_vector_len >= 24576) {
-        bn = 128;
-      } else if (out_vector_len >= 16384) {
-        bn = 64;
-      } else if (out_vector_len >= 8192) {
-        bn = 32;
-      }
+      if (in_vector_len >= 8192 && out_vector_len >= 2048) {
+        sm = 4;
+        sn = 8;
+      } else {
+        sm = 8;
+        sn = 4;
+      }
+
+      if (out_vector_len >= 2048) {
+        bn = 16;
+      } else if (out_vector_len >= 512) {
+        bn = 4;
+      } else {
+        bn = 2;
+      }

       // Specialized kernel for very small outputs
       tn = out_vector_len < tn ? 1 : tn;

-      n_out_per_tgp = bn * tn;
+      n_out_per_tgp = bn * sn * tn;
       kname << "gemv_t_" << type_to_name(out);

     } else {
       bm = out_vector_len >= 4096 ? 8 : 4;
-      bn = 32;
+      sn = 32;

       // Specialized kernel for very small outputs
       tm = out_vector_len < tm ? 1 : tm;

-      n_out_per_tgp = bm * tm;
+      n_out_per_tgp = bm * sm * tm;
       kname << "gemv_" << type_to_name(out);
     }

-    kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
+    kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
+          << tm << "_tn" << tn;
     kname << "_nc" << !contiguous_kernel << "_axpby1";

     // Encode and dispatch kernel
@@ -1037,7 +1053,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder->setComputePipelineState(kernel);

     int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
-    MTL::Size group_dims = MTL::Size(bn, bm, 1);
+    MTL::Size group_dims = MTL::Size(32, bn, bm);
     MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);

     compute_encoder.set_input_array(mat, 0);
@@ -1344,15 +1360,19 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep

+  int M = a_pre.shape(-2);
+  int N = b_pre.shape(-1);
+  int K = a_pre.shape(-1);
+
   // Keep a vector with copies to be cleared in the completed buffer to release
   // the arrays
   std::vector<array> copies;
-  auto check_transpose = [&copies, &s](const array& arr) {
+  auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
     auto stx = arr.strides()[arr.ndim() - 2];
     auto sty = arr.strides()[arr.ndim() - 1];
-    if (sty == 1) {
+    if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
       return std::make_tuple(false, stx, arr);
-    } else if (stx == 1) {
+    } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
       return std::make_tuple(true, sty, arr);
     } else {
       array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
@@ -1363,33 +1383,38 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
   };

-  auto [transpose_a, a_cols, a] = check_transpose(a_pre);
-  auto [transpose_b, b_cols, b] = check_transpose(b_pre);
+  auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
+  auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);

   int lda = a_cols;
   int ldb = b_cols;

-  int M = a.shape(-2);
-  int N = b.shape(-1);
-  int K = a.shape(-1);
-
   /////////////////////////////////////////////////////////////////////////////
   // Check and collapse batch dimensions

+  bool has_op_mask = inputs.size() > 3;
+  bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
+
+  // Prepare kernel name
+  std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
+  std::string op_mask_nm = has_op_mask ? type_to_name(inputs.back()) : "nomask";
+
+  auto get_batch_dims = [](const auto& v) {
+    return decltype(v){v.begin(), v.end() - 2};
+  };
+
   std::vector<int> batch_shape{1};
+  std::vector<size_t> A_batch_stride{0};
+  std::vector<size_t> B_batch_stride{0};
+  std::vector<size_t> outmask_bstride{0};
+  std::vector<size_t> Amask_bstride{0};
+  std::vector<size_t> Bmask_bstride{0};
   size_t A_batch_str = 0;
   size_t B_batch_str = 0;

   std::vector<size_t> batch_strides;

   if (out.ndim() > 2) {
-    auto get_batch_dims = [](const auto& v) {
-      return decltype(v){v.begin(), v.end() - 2};
-    };
-
     std::vector<int> bshape{out.shape().begin(), out.shape().end() - 2};
     std::vector<std::vector<size_t>> bstrides;
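Note: mask presence is decoded purely from the input count; inputs is [a, b], optionally followed by an output mask and/or a pair of operand masks. A quick check of the decoding used above:

    // Sketch only: inputs = [a, b, (out_mask)?, (a_mask, b_mask)?]
    #include <cassert>
    #include <cstddef>

    int main() {
      auto has_op_mask = [](std::size_t n) { return n > 3; };
      auto has_out_mask = [](std::size_t n) { return n == 3 || n == 5; };

      assert(!has_op_mask(2) && !has_out_mask(2)); // a, b
      assert(!has_op_mask(3) && has_out_mask(3));  // a, b, out_mask
      assert(has_op_mask(4) && !has_out_mask(4));  // a, b, a_mask, b_mask
      assert(has_op_mask(5) && has_out_mask(5));   // all five inputs
      return 0;
    }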
@@ -1397,14 +1422,26 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
      bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2);
    }

-    auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);
-    batch_shape = bshape_c;
-    A_batch_str = bstrides_c[0].back();
-    B_batch_str = bstrides_c[1].back();
+    // auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);
+    batch_shape = bshape;
+    A_batch_str = bstrides[0].back();
+    B_batch_str = bstrides[1].back();

-    for (auto& bstr : bstrides_c) {
+    for (auto& bstr : bstrides) {
      batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end());
    }

+    A_batch_stride = bstrides[0];
+    B_batch_stride = bstrides[1];
+
+    if (has_out_mask) {
+      outmask_bstride = bstrides[2];
+    }
+    if (has_op_mask) {
+      Amask_bstride = bstrides[has_out_mask + 2];
+      Bmask_bstride = bstrides[has_out_mask + 3];
+    }
+
  } else {
    batch_strides = std::vector<size_t>(inputs.size(), 0);
  }
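Note: collapse_contiguous_dims is left commented out here, presumably because every input, including up to three masks, now keeps its own batch-stride vector, and collapsing batch dimensions is only valid if applied identically across all of them. Downstream, a flat batch index is expanded through batch_shape against each per-input stride vector, roughly as in this sketch (helper name illustrative, not MLX's):

    // Sketch only: expanding a flat batch index through per-input strides.
    #include <cstddef>
    #include <vector>

    std::size_t batch_offset(
        std::size_t batch_idx,
        const std::vector<int>& batch_shape,
        const std::vector<std::size_t>& batch_strides) {
      std::size_t loc = 0;
      for (int i = static_cast<int>(batch_shape.size()) - 1; i >= 0; --i) {
        loc += (batch_idx % batch_shape[i]) * batch_strides[i];
        batch_idx /= batch_shape[i];
      }
      return loc;
    }

    int main() {
      // Batch shape (2, 3); a mask broadcast over the first axis has stride
      // 0 there, so batch index 4 -> (1, 1) -> offset 1 rather than 4.
      std::vector<int> batch_shape = {2, 3};
      std::vector<std::size_t> mask_bstride = {0, 1};
      return batch_offset(4, batch_shape, mask_bstride) == 1 ? 0 : 1;
    }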
@@ -1412,6 +1449,174 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   size_t matrix_stride_out = size_t(M) * N;
   size_t batch_size_out = out.size() / (matrix_stride_out);

+  /////////////////////////////////////////////////////////////////////////////
+  // Gemv specialization
+
+  // Route to gemv if needed
+  if (std::min(M, N) == 1) {
+    // Collect problem info
+    bool is_b_matrix = N != 1;
+
+    auto& mat = is_b_matrix ? b : a;
+    auto& vec = is_b_matrix ? a : b;
+    bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
+    int in_vector_len = K;
+    int out_vector_len = is_b_matrix ? N : M;
+
+    int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
+    int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
+    int mat_ld = is_b_matrix ? b_cols : a_cols;
+
+    auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
+    auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
+
+    auto mask_bstrides_mat = is_b_matrix ? Bmask_bstride : Amask_bstride;
+    auto mask_bstrides_vec = is_b_matrix ? Amask_bstride : Bmask_bstride;
+
+    auto mat_mask_idx = int(has_out_mask) + (is_b_matrix ? 3 : 2);
+    auto vec_mask_idx = int(has_out_mask) + (is_b_matrix ? 2 : 3);
+
+    // Determine if inputs have simple batching / broadcasting
+    bool contiguous_kernel = (batch_shape.size() == 1);
+
+    int batch_ndim = batch_shape.size();
+
+    // Determine dispatch kernel
+    int tm = 4, tn = 4;
+    int sm = 1, sn = 32;
+    int bm = 1, bn = 1;
+    int n_out_per_tgp;
+    std::ostringstream kname;
+
+    if (transpose_mat) {
+      sm = 8;
+      sn = 4;
+      bm = 1;
+      bn = (block_size_ == 64 && out_vector_len >= 2048) ? 4 : 2;
+      tm = block_size_ == 32 ? 4 : 8;
+      tn = 4;
+
+      // Specialized kernel for very small outputs
+      tn = out_vector_len < tn ? 1 : tn;
+
+      n_out_per_tgp = bn * sn * tn;
+      kname << "gemv_t";
+
+    } else {
+      if (block_size_ == 32) {
+        sm = 4;
+        sn = 8;
+        bm = 2;
+      } else {
+        sm = 2;
+        sn = 16;
+        bm = out_vector_len >= 512 ? 4 : 2;
+      }
+
+      // Specialized kernel for very small outputs
+      tm = out_vector_len < tm ? 1 : tm;
+
+      n_out_per_tgp = bm * sm * tm;
+      kname << "gemv";
+    }
+
+    kname << "_outmask_" << out_mask_nm;
+    kname << "_opmask_" << op_mask_nm;
+    kname << "_" << type_to_name(out);
+    kname << "_bm" << bm << "_bn" << bn;
+    kname << "_sm" << sm << "_sn" << sn;
+    kname << "_tm" << tm << "_tn" << tn;
+    kname << "_nc" << !contiguous_kernel;
+
+    // Encode and dispatch kernel
+    auto& compute_encoder = d.get_command_encoder(s.index);
+    auto kernel = d.get_kernel(kname.str());
+    compute_encoder->setComputePipelineState(kernel);
+
+    int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
+    MTL::Size group_dims = MTL::Size(32, bn, bm);
+    MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
+
+    // Get mask params
+    std::vector<int> mask_strides;
+    std::vector<size_t> mask_batch_strides;
+    if (has_out_mask) {
+      auto& out_mask = inputs[2];
+
+      if (transpose_mat) {
+        mask_strides.push_back(out_mask.strides(out.shape(-2) == 1 ? -1 : -2));
+        mask_strides.push_back(out_mask.strides(out.shape(-2) == 1 ? -2 : -1));
+      } else {
+        mask_strides.push_back(out_mask.strides(out.shape(-1) == 1 ? -1 : -2));
+        mask_strides.push_back(out_mask.strides(out.shape(-1) == 1 ? -2 : -1));
+      }
+
+      mask_batch_strides.insert(
+          mask_batch_strides.end(),
+          outmask_bstride.begin(),
+          outmask_bstride.end());
+
+      compute_encoder.set_input_array(out_mask, 20);
+    }
+
+    if (has_op_mask) {
+      auto& mat_mask = inputs[mat_mask_idx];
+
+      if (transpose_mat) {
+        mask_strides.push_back(mat_mask.strides(!is_b_matrix ? -2 : -1));
+        mask_strides.push_back(mat_mask.strides(!is_b_matrix ? -1 : -2));
+      } else {
+        mask_strides.push_back(mat_mask.strides(is_b_matrix ? -2 : -1));
+        mask_strides.push_back(mat_mask.strides(is_b_matrix ? -1 : -2));
+      }
+
+      mask_batch_strides.insert(
+          mask_batch_strides.end(),
+          mask_bstrides_mat.begin(),
+          mask_bstrides_mat.end());
+
+      compute_encoder.set_input_array(mat_mask, 21);
+
+      auto& vec_mask = inputs[vec_mask_idx];
+      if (transpose_mat) {
+        mask_strides.push_back(vec_mask.strides(vec.shape(-2) == 1 ? -1 : -2));
+        mask_strides.push_back(vec_mask.strides(vec.shape(-2) == 1 ? -2 : -1));
+      } else {
+        mask_strides.push_back(vec_mask.strides(vec.shape(-1) == 1 ? -1 : -2));
+        mask_strides.push_back(vec_mask.strides(vec.shape(-1) == 1 ? -2 : -1));
+      }
+
+      mask_batch_strides.insert(
+          mask_batch_strides.end(),
+          mask_bstrides_vec.begin(),
+          mask_bstrides_vec.end());
+
+      compute_encoder.set_input_array(vec_mask, 22);
+    }
+
+    // Get gemv params
+    compute_encoder.set_input_array(mat, 0);
+    compute_encoder.set_input_array(vec, 1);
+    compute_encoder.set_output_array(out, 3);
+
+    compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
+    compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
+    compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
+    compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
+    set_vector_bytes(compute_encoder, batch_shape, 10);
+    set_vector_bytes(compute_encoder, batch_strides_vec, 11);
+    set_vector_bytes(compute_encoder, batch_strides_mat, 12);
+
+    set_vector_bytes(compute_encoder, mask_strides, 23);
+    set_vector_bytes(compute_encoder, mask_batch_strides, 24);
+
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+    return;
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Regular kernel dispatch

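Note: as a reference for what the new kernels compute, here is a plain scalar version of a gemv with a block-level mask on the matrix operand, written for clarity only; the Metal kernels above tile this across simdgroups, block_size mirrors block_size_ (32 or 64), and the output and vector masks are handled analogously:

    // Scalar reference sketch, not the Metal kernel.
    #include <vector>

    std::vector<float> masked_gemv(
        const std::vector<float>& mat, // (out_len, in_len), row-major
        const std::vector<float>& vec, // (in_len)
        const std::vector<bool>& mat_mask, // one entry per block_size^2 tile
        int out_len,
        int in_len,
        int block_size) {
      int blocks_in = (in_len + block_size - 1) / block_size;
      std::vector<float> out(out_len, 0.0f);
      for (int i = 0; i < out_len; ++i) {
        for (int j = 0; j < in_len; ++j) {
          // One mask element gates a block_size x block_size tile of mat.
          if (mat_mask[(i / block_size) * blocks_in + (j / block_size)]) {
            out[i] += mat[i * in_len + j] * vec[j];
          }
        }
      }
      return out;
    }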
@@ -1421,10 +1626,6 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   bool mn_aligned = M % bm == 0 && N % bn == 0;
   bool k_aligned = K % bk == 0;

-  // Prepare kernel name
-  std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
-  std::string op_mask_nm = has_op_mask ? type_to_name(inputs.back()) : "nomask";
-
   std::ostringstream kname;
   kname << "steel_gemm_block_outmask_" << out_mask_nm << "_opmask_"
         << op_mask_nm << "_" << (transpose_a ? 't' : 'n')
@@ -1554,15 +1755,19 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep

+  int M = a_pre.shape(-2);
+  int N = b_pre.shape(-1);
+  int K = a_pre.shape(-1);
+
   // Keep a vector with copies to be cleared in the completed buffer to release
   // the arrays
   std::vector<array> copies;
-  auto check_transpose = [&copies, &s](const array& arr) {
+  auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
     auto stx = arr.strides()[arr.ndim() - 2];
     auto sty = arr.strides()[arr.ndim() - 1];
-    if (sty == 1) {
+    if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
       return std::make_tuple(false, stx, arr);
-    } else if (stx == 1) {
+    } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
       return std::make_tuple(true, sty, arr);
     } else {
       array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
@@ -1573,16 +1778,12 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
   };

-  auto [transpose_a, a_cols, a] = check_transpose(a_pre);
-  auto [transpose_b, b_cols, b] = check_transpose(b_pre);
+  auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
+  auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);

   int lda = a_cols;
   int ldb = b_cols;

-  int M = a.shape(-2);
-  int N = b.shape(-1);
-  int K = a.shape(-1);
-
   /////////////////////////////////////////////////////////////////////////////
   // Check and collapse batch dimensions

@@ -1673,38 +1874,47 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {

     // Determine dispatch kernel
     int tm = 4, tn = 4;
-    int bm, bn, n_out_per_tgp;
+    int sm = 1, sn = 32;
+    int bm = 1, bn = 1;
+    int n_out_per_tgp;
     std::ostringstream kname;

     if (transpose_mat) {
-      bm = 8;
-      bn = 8;
-      if (out_vector_len >= 24576) {
-        bn = 128;
-      } else if (out_vector_len >= 16384) {
-        bn = 64;
-      } else if (out_vector_len >= 8192) {
-        bn = 32;
-      }
+      if (in_vector_len >= 8192 && out_vector_len >= 2048) {
+        sm = 4;
+        sn = 8;
+      } else {
+        sm = 8;
+        sn = 4;
+      }
+
+      if (out_vector_len >= 2048) {
+        bn = 16;
+      } else if (out_vector_len >= 512) {
+        bn = 4;
+      } else {
+        bn = 2;
+      }

       // Specialized kernel for very small outputs
       tn = out_vector_len < tn ? 1 : tn;

-      n_out_per_tgp = bn * tn;
-      kname << "gemv_t_bs_" << type_to_name(out);
+      n_out_per_tgp = bn * sn * tn;
+      kname << "gemv_t_gather_" << type_to_name(out);

     } else {
       bm = out_vector_len >= 4096 ? 8 : 4;
-      bn = 32;
+      sn = 32;

       // Specialized kernel for very small outputs
       tm = out_vector_len < tm ? 1 : tm;

-      n_out_per_tgp = bm * tm;
-      kname << "gemv_bs_" << type_to_name(out);
+      n_out_per_tgp = bm * sm * tm;
+      kname << "gemv_gather_" << type_to_name(out);
     }

-    kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
+    kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
+          << tm << "_tn" << tn;

     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
@@ -1712,7 +1922,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder->setComputePipelineState(kernel);

     int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
-    MTL::Size group_dims = MTL::Size(bn, bm, 1);
+    MTL::Size group_dims = MTL::Size(32, bn, bm);
     MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);

     compute_encoder.set_input_array(mat, 0);