Mirror of https://github.com/ml-explore/mlx.git
Use int64 stride everywhere (#1671)
* use int64 stride everywhere
* fix ext
* fix ext
* more shape + cleanup
* one more
* few more
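The change below replaces ad-hoc std::vector<int> shapes and std::vector<size_t> strides with the Shape and Strides aliases, and moves scalar stride arithmetic from size_t to int64_t. As a rough sketch (an assumption for illustration only; the aliases are defined elsewhere in the tree, not in this diff), the types used throughout the hunks behave like:

// Sketch (assumption): aliases consistent with how this diff uses them.
#include <cstdint>
#include <vector>

using Shape = std::vector<int32_t>;   // per-axis sizes
using Strides = std::vector<int64_t>; // per-axis strides, signed 64-bit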
@@ -21,8 +21,8 @@ namespace {
 
 inline auto collapse_batches(const array& a, const array& b) {
   // Get and check the shape for the batched dims
-  std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
-  std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
+  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
+  Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
   if (A_bshape != B_bshape) {
     std::ostringstream msg;
     msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
@@ -30,8 +30,8 @@ inline auto collapse_batches(const array& a, const array& b) {
     throw std::runtime_error(msg.str());
   }
 
-  std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
-  std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
+  Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
+  Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
 
   auto [batch_shape, batch_strides] =
       collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
@@ -50,9 +50,9 @@ inline auto collapse_batches(const array& a, const array& b) {
 
 inline auto collapse_batches(const array& a, const array& b, const array& c) {
   // Get and check the shape for the batched dims
-  std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
-  std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
-  std::vector<int> C_bshape{c.shape().begin(), c.shape().end() - 2};
+  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
+  Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
+  Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
   if (A_bshape != B_bshape || A_bshape != C_bshape) {
     std::ostringstream msg;
     msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
@@ -60,9 +60,9 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
     throw std::runtime_error(msg.str());
   }
 
-  std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
-  std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
-  std::vector<size_t> C_bstride{c.strides().begin(), c.strides().end() - 2};
+  Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
+  Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
+  Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
 
   auto [batch_shape, batch_strides] = collapse_contiguous_dims(
       A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
@@ -82,6 +82,25 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
       batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
 }
 
+std::tuple<bool, int64_t, array> check_transpose(
+    std::vector<array>& copies,
+    const Stream& s,
+    const array& arr,
+    bool is_vector) {
+  auto stx = arr.strides()[arr.ndim() - 2];
+  auto sty = arr.strides()[arr.ndim() - 1];
+  if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
+    return std::make_tuple(false, stx, arr);
+  } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
+    return std::make_tuple(true, sty, arr);
+  } else {
+    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+    copy_gpu(arr, arr_copy, CopyType::General, s);
+    copies.push_back(arr_copy);
+    return std::make_tuple(false, arr.shape(-1), arr_copy);
+  }
+};
+
 } // namespace
 
 ///////////////////////////////////////////////////////////////////////////////
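Note: the check_transpose lambda that was previously duplicated inside the Matmul, AddMM, BlockMaskedMM, and GatherMM eval_gpu paths is hoisted here into one namespace-scope helper; callers now pass the copies vector and the stream explicitly, and the leading-dimension stride comes back as int64_t. A minimal usage sketch, mirroring the calls further down in this diff (not standalone; the surrounding eval_gpu code is assumed to supply a_pre, s, and M):

// Assumed caller context: a_pre is the input array, s the stream, M the row count.
std::vector<array> copies;
auto [a_transposed, a_cols, a] =
    check_transpose(copies, s, a_pre, /*is_vector=*/M == 1);
// a_transposed: treat the matrix as transposed in the GEMM call
// a_cols:       leading-dimension stride, now int64_t
// a:            the original array, or a contiguous copy appended to copies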
@@ -180,11 +199,11 @@ void steel_matmul_regular(
     int ldd,
     bool transpose_a,
     bool transpose_b,
-    std::vector<int> batch_shape,
-    std::vector<size_t> batch_strides,
-    size_t A_batch_stride,
-    size_t B_batch_stride,
-    size_t matrix_stride_out,
+    Shape batch_shape,
+    Strides batch_strides,
+    int64_t A_batch_stride,
+    int64_t B_batch_stride,
+    int64_t matrix_stride_out,
     std::vector<array>& copies) {
   using namespace mlx::steel;
 
@@ -268,9 +287,9 @@ void steel_matmul_regular(
       /* const int ldd = */ ldd,
       /* const int tiles_n = */ tn,
       /* const int tiles_m = */ tm,
-      /* const size_t batch_stride_a = */ A_batch_stride,
-      /* const size_t batch_stride_b = */ B_batch_stride,
-      /* const size_t batch_stride_d = */ matrix_stride_out,
+      /* const int64_t batch_stride_a = */ A_batch_stride,
+      /* const int64_t batch_stride_b = */ B_batch_stride,
+      /* const int64_t batch_stride_d = */ matrix_stride_out,
       /* const int swizzle_log = */ swizzle_log,
       /* const int gemm_k_iterations_aligned = */ (K / bk),
       /* const int batch_ndim = */ int(batch_shape.size())};
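The kernel parameter structs change in lock-step: the batch-stride fields that were size_t become int64_t. A sketch of the affected slice of steel::GEMMParams, reconstructed only from the commented initializers in this hunk (field order and the omitted members are assumptions, not the actual header):

struct GEMMParams {
  // ... leading fields (M, N, K, lda, ldb, ldd, tiles_n, tiles_m) omitted ...
  int64_t batch_stride_a; // was size_t before this change
  int64_t batch_stride_b; // was size_t
  int64_t batch_stride_d; // was size_t
  int swizzle_log;
  int gemm_k_iterations_aligned;
  int batch_ndim;
};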
@@ -314,9 +333,9 @@ void steel_matmul(
     bool transpose_a,
     bool transpose_b,
     std::vector<array>& copies,
-    std::vector<int> batch_shape /* = {} */,
-    std::vector<size_t> A_batch_stride /* = {} */,
-    std::vector<size_t> B_batch_stride /* = {} */) {
+    Shape batch_shape /* = {} */,
+    Strides A_batch_stride /* = {} */,
+    Strides B_batch_stride /* = {} */) {
   using namespace mlx::steel;
 
   if (batch_shape.empty()) {
@@ -447,7 +466,7 @@ void steel_matmul(
 
   /////////////////////////////////////////////////////////////////////////////
   // Regular kernel dispatch
-  std::vector<size_t> batch_strides = A_batch_stride;
+  auto batch_strides = A_batch_stride;
   batch_strides.insert(
       batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
 
@@ -505,24 +524,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Keep a vector with copies to be cleared in the completed buffer to release
   // the arrays
   std::vector<array> copies;
-  auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
-    auto stx = arr.strides()[arr.ndim() - 2];
-    auto sty = arr.strides()[arr.ndim() - 1];
-    if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
-      return std::make_tuple(false, stx, arr);
-    } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
-      return std::make_tuple(true, sty, arr);
-    } else {
-      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-      copy_gpu(arr, arr_copy, CopyType::General, s);
-      copies.push_back(arr_copy);
-      size_t stx = arr.shape(-1);
-      return std::make_tuple(false, stx, arr_copy);
-    }
-  };
-
-  auto [a_transposed, a_cols, a] = check_transpose(a_pre, M == 1);
-  auto [b_transposed, b_cols, b] = check_transpose(b_pre, N == 1);
+  auto [a_transposed, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
+  auto [b_transposed, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
 
   /////////////////////////////////////////////////////////////////////////////
   // Check and collapse batch dimensions
@@ -662,9 +665,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       /* bool transpose_a = */ a_transposed,
       /* bool transpose_b = */ b_transposed,
       /* std::vector<array>& = */ copies,
-      /* std::vector<int> batch_shape = */ batch_shape,
-      /* std::vector<size_t> A_batch_stride = */ A_batch_stride,
-      /* std::vector<size_t> B_batch_stride = */ B_batch_stride);
+      /* Shape batch_shape = */ batch_shape,
+      /* Strides A_batch_stride = */ A_batch_stride,
+      /* Strides B_batch_stride = */ B_batch_stride);
 }
 
 void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -691,24 +694,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Keep a vector with copies to be cleared in the completed buffer to release
   // the arrays
   std::vector<array> copies;
-  auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
-    auto stx = arr.strides()[arr.ndim() - 2];
-    auto sty = arr.strides()[arr.ndim() - 1];
-    if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
-      return std::make_tuple(false, stx, arr);
-    } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
-      return std::make_tuple(true, sty, arr);
-    } else {
-      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-      copy_gpu(arr, arr_copy, CopyType::General, s);
-      copies.push_back(arr_copy);
-      size_t stx = arr.shape(-1);
-      return std::make_tuple(false, stx, arr_copy);
-    }
-  };
-
-  auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
-  auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
+  auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
+  auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
 
   array c = c_pre;
   int ldc = c.strides()[c.ndim() - 2];
@@ -723,7 +710,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =
       collapse_batches(a, b, c);
 
-  size_t matrix_stride_out = size_t(M) * size_t(N);
+  int64_t matrix_stride_out = M * static_cast<int64_t>(N);
   auto batch_size_out = out.size() / (matrix_stride_out);
 
   // Collapse batches into M if needed
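One subtlety in the new form of matrix_stride_out: M and N are ordinary ints here, so casting one operand forces the multiplication itself to run in 64-bit arithmetic instead of widening an already-overflowed 32-bit product. A standalone illustration (values chosen only to show the overflow; not code from this repository):

#include <cstdint>

int main() {
  int M = 100000, N = 50000;
  int64_t overflowed = M * N;                     // multiply happens in 32-bit int: overflow (UB)
  int64_t correct = M * static_cast<int64_t>(N);  // multiply happens in int64_t: 5000000000
  (void)overflowed;
  (void)correct;
  return 0;
}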
@@ -1044,9 +1031,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       /* const int ldd = */ N,
       /* const int tiles_n = */ tn,
       /* const int tiles_m = */ tm,
-      /* const size_t batch_stride_a = */ A_batch_stride.back(),
-      /* const size_t batch_stride_b = */ B_batch_stride.back(),
-      /* const size_t batch_stride_d = */ matrix_stride_out,
+      /* const int64_t batch_stride_a = */ A_batch_stride.back(),
+      /* const int64_t batch_stride_b = */ B_batch_stride.back(),
+      /* const int64_t batch_stride_d = */ matrix_stride_out,
       /* const int swizzle_log = */ swizzle_log,
       /* const int gemm_k_iterations_aligned = */ (K / bk),
       /* const int batch_ndim = */ int(batch_shape.size())};
@@ -1054,7 +1041,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   GEMMAddMMParams params{
       /* const int ldc = */ ldc,
       /* const int fdc = */ fdc,
-      /* const size_t batch_stride_c = */ C_batch_stride.back(),
+      /* const int64_t batch_stride_c = */ C_batch_stride.back(),
       /* const float alpha = */ alpha_,
       /* const float beta = */ beta_};
 
@@ -1065,7 +1052,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   MTL::Size group_dims = MTL::Size(32, wn, wm);
   MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
 
-  std::vector<size_t> batch_strides = A_batch_stride;
+  Strides batch_strides = A_batch_stride;
   batch_strides.insert(
       batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
   batch_strides.insert(
@@ -1120,24 +1107,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Keep a vector with copies to be cleared in the completed buffer to release
   // the arrays
   std::vector<array> copies;
-  auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
-    auto stx = arr.strides()[arr.ndim() - 2];
-    auto sty = arr.strides()[arr.ndim() - 1];
-    if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
-      return std::make_tuple(false, stx, arr);
-    } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
-      return std::make_tuple(true, sty, arr);
-    } else {
-      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-      copy_gpu(arr, arr_copy, CopyType::General, s);
-      copies.push_back(arr_copy);
-      size_t stx = arr.shape(-1);
-      return std::make_tuple(false, stx, arr_copy);
-    }
-  };
-
-  auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
-  auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
+  auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
+  auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
 
   int lda = a_cols;
   int ldb = b_cols;
@@ -1156,20 +1127,20 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     return decltype(v){v.begin(), v.end() - 2};
   };
 
-  std::vector<int> batch_shape{1};
-  std::vector<size_t> A_batch_stride{0};
-  std::vector<size_t> B_batch_stride{0};
-  std::vector<size_t> outmask_bstride{0};
-  std::vector<size_t> Amask_bstride{0};
-  std::vector<size_t> Bmask_bstride{0};
-  size_t A_batch_str = 0;
-  size_t B_batch_str = 0;
+  Shape batch_shape{1};
+  Strides A_batch_stride{0};
+  Strides B_batch_stride{0};
+  Strides outmask_bstride{0};
+  Strides Amask_bstride{0};
+  Strides Bmask_bstride{0};
+  int64_t A_batch_str = 0;
+  int64_t B_batch_str = 0;
 
-  std::vector<size_t> batch_strides;
+  Strides batch_strides;
 
   if (out.ndim() > 2) {
-    std::vector<int> bshape{out.shape().begin(), out.shape().end() - 2};
-    std::vector<std::vector<size_t>> bstrides;
+    Shape bshape{out.shape().begin(), out.shape().end() - 2};
+    std::vector<Strides> bstrides;
 
     for (auto& arr : inputs) {
       bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2);
@@ -1196,10 +1167,10 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
    }
 
  } else {
-    batch_strides = std::vector<size_t>(inputs.size(), 0);
+    batch_strides = Strides(inputs.size(), 0);
  }
 
-  size_t matrix_stride_out = size_t(M) * N;
+  int64_t matrix_stride_out = static_cast<int64_t>(M) * N;
   size_t batch_size_out = out.size() / (matrix_stride_out);
 
   /////////////////////////////////////////////////////////////////////////////
@@ -1306,7 +1277,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // Get mask params
   std::vector<int> mask_strides;
-  std::vector<size_t> mask_batch_strides;
+  Strides mask_batch_strides;
   if (has_out_mask) {
     auto& out_mask = inputs[2];
 
@@ -1436,9 +1407,9 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       /* const int ldd = */ N,
       /* const int tiles_n = */ tn,
       /* const int tiles_m = */ tm,
-      /* const size_t batch_stride_a = */ A_batch_str,
-      /* const size_t batch_stride_b = */ B_batch_str,
-      /* const size_t batch_stride_d = */ matrix_stride_out,
+      /* const int64_t batch_stride_a = */ A_batch_str,
+      /* const int64_t batch_stride_b = */ B_batch_str,
+      /* const int64_t batch_stride_d = */ matrix_stride_out,
       /* const int swizzle_log = */ swizzle_log,
       /* const int gemm_k_iterations_aligned = */ (K / bk),
       /* const int batch_ndim = */ int(batch_shape.size())};
@@ -1524,24 +1495,8 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Keep a vector with copies to be cleared in the completed buffer to release
   // the arrays
   std::vector<array> copies;
-  auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
-    auto stx = arr.strides()[arr.ndim() - 2];
-    auto sty = arr.strides()[arr.ndim() - 1];
-    if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
-      return std::make_tuple(false, stx, arr);
-    } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
-      return std::make_tuple(true, sty, arr);
-    } else {
-      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-      copy_gpu(arr, arr_copy, CopyType::General, s);
-      copies.push_back(arr_copy);
-      size_t stx = arr.shape(-1);
-      return std::make_tuple(false, stx, arr_copy);
-    }
-  };
-
-  auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
-  auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
+  auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
+  auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
 
   int lda = a_cols;
   int ldb = b_cols;
@@ -1556,20 +1511,20 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& lhs_indices = inputs[2];
   auto& rhs_indices = inputs[3];
 
-  std::vector<int> batch_shape = get_batch_dims(out.shape());
-  std::vector<size_t> batch_strides;
+  Shape batch_shape = get_batch_dims(out.shape());
+  Strides batch_strides;
 
   batch_strides.insert(
       batch_strides.end(),
       lhs_indices.strides().begin(),
       lhs_indices.strides().end());
-  size_t lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
+  auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
 
   batch_strides.insert(
      batch_strides.end(),
      rhs_indices.strides().begin(),
      rhs_indices.strides().end());
-  size_t rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
+  auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
 
   int batch_ndim = batch_shape.size();
 
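The switch from size_t to auto for lhs_indices_str and rhs_indices_str follows from the container change: batch_strides is now a Strides (a vector of int64_t), so batch_strides.back() and the whole conditional expression yield int64_t, and auto tracks that element type instead of hard-coding an unsigned width. A small self-contained check (the Strides alias is an assumption, as noted above):

#include <cstdint>
#include <type_traits>
#include <vector>

using Strides = std::vector<int64_t>; // assumed alias

int main() {
  Strides batch_strides{12, 4};
  auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
  static_assert(std::is_same_v<decltype(lhs_indices_str), int64_t>,
                "common type of int and int64_t is int64_t");
  return 0;
}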
@@ -1582,10 +1537,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   int batch_ndim_B = b.ndim() - 2;
   std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B};
 
-  std::vector<int> batch_shape_A = get_batch_dims(a.shape());
-  std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
-  std::vector<int> batch_shape_B = get_batch_dims(b.shape());
-  std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
+  Shape batch_shape_A = get_batch_dims(a.shape());
+  Strides batch_strides_A = get_batch_dims(a.strides());
+  Shape batch_shape_B = get_batch_dims(b.shape());
+  Strides batch_strides_B = get_batch_dims(b.strides());
 
   if (batch_ndim_A == 0) {
     batch_shape_A = {1};
@@ -1597,7 +1552,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     batch_strides_B = {0};
   }
 
-  size_t matrix_stride_out = size_t(M) * N;
+  auto matrix_stride_out = static_cast<int64_t>(M) * N;
   auto batch_size_out = out.size() / matrix_stride_out;
 
   /////////////////////////////////////////////////////////////////////////////
@@ -1801,9 +1756,9 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       /* const int ldd = */ N,
       /* const int tiles_n = */ tn,
       /* const int tiles_m = */ tm,
-      /* const size_t batch_stride_a = */ lhs_indices_str,
-      /* const size_t batch_stride_b = */ rhs_indices_str,
-      /* const size_t batch_stride_d = */ matrix_stride_out,
+      /* const int64_t batch_stride_a = */ lhs_indices_str,
+      /* const int64_t batch_stride_b = */ rhs_indices_str,
+      /* const int64_t batch_stride_d = */ matrix_stride_out,
       /* const int swizzle_log = */ swizzle_log,
       /* const int gemm_k_iterations_aligned = */ (K / bk),
       /* const int batch_ndim = */ batch_ndim};