Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more

Author: Awni Hannun
Date: 2024-12-09 11:09:02 -08:00 (committed by GitHub)
Parent: 35b412c099
Commit: 40c62c1321
102 changed files with 1262 additions and 1705 deletions


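The diff below is largely a mechanical rename of the shape/stride container types to the Shape and Strides aliases. Those aliases are defined elsewhere in the repository; a minimal sketch of what they presumably look like (the exact element type of Shape is an assumption here):

#include <cstdint>
#include <vector>

// Presumed aliases (an assumption for illustration only):
// per-axis sizes stay 32-bit ints, strides become signed 64-bit everywhere.
using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;
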
@@ -21,8 +21,8 @@ namespace {
inline auto collapse_batches(const array& a, const array& b) {
// Get and check the shape for the batched dims
- std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
- std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
+ Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
+ Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
if (A_bshape != B_bshape) {
std::ostringstream msg;
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
@@ -30,8 +30,8 @@ inline auto collapse_batches(const array& a, const array& b) {
throw std::runtime_error(msg.str());
}
- std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
- std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
+ Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
+ Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
auto [batch_shape, batch_strides] =
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
@@ -50,9 +50,9 @@ inline auto collapse_batches(const array& a, const array& b) {
inline auto collapse_batches(const array& a, const array& b, const array& c) {
// Get and check the shape for the batched dims
- std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
- std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
- std::vector<int> C_bshape{c.shape().begin(), c.shape().end() - 2};
+ Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
+ Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
+ Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
if (A_bshape != B_bshape || A_bshape != C_bshape) {
std::ostringstream msg;
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
@@ -60,9 +60,9 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
throw std::runtime_error(msg.str());
}
- std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
- std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
- std::vector<size_t> C_bstride{c.strides().begin(), c.strides().end() - 2};
+ Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
+ Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
+ Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
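Both overloads of collapse_batches feed the batch shape and per-operand batch strides into collapse_contiguous_dims, which merges adjacent batch dimensions whenever every operand is contiguous across them, shrinking what the kernels must iterate over. A standalone sketch of that idea (not MLX's implementation; Shape/Strides are the presumed aliases from above):

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;

// Fuse dims k and k + 1 whenever, for every operand,
// strides[k] == strides[k + 1] * shape[k + 1]; the fused dim keeps the
// innermost stride of the group.
std::pair<Shape, std::vector<Strides>> collapse_dims_sketch(
    const Shape& shape, const std::vector<Strides>& strides) {
  Shape out_shape;
  std::vector<Strides> out_strides(strides.size());
  int i = static_cast<int>(shape.size()) - 1;
  while (i >= 0) {
    int64_t merged = shape[i];
    int j = i;
    auto contiguous = [&](int k) { // can dims k and k + 1 be fused?
      for (const auto& st : strides) {
        if (st[k] != st[k + 1] * shape[k + 1]) {
          return false;
        }
      }
      return true;
    };
    while (j > 0 && contiguous(j - 1)) {
      --j;
      merged *= shape[j];
    }
    out_shape.push_back(static_cast<int32_t>(merged));
    for (size_t n = 0; n < strides.size(); ++n) {
      out_strides[n].push_back(strides[n][i]);
    }
    i = j - 1;
  }
  std::reverse(out_shape.begin(), out_shape.end());
  for (auto& st : out_strides) {
    std::reverse(st.begin(), st.end());
  }
  return {out_shape, out_strides};
}
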
@@ -82,6 +82,25 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
}
+ std::tuple<bool, int64_t, array> check_transpose(
+ std::vector<array>& copies,
+ const Stream& s,
+ const array& arr,
+ bool is_vector) {
+ auto stx = arr.strides()[arr.ndim() - 2];
+ auto sty = arr.strides()[arr.ndim() - 1];
+ if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
+ return std::make_tuple(false, stx, arr);
+ } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
+ return std::make_tuple(true, sty, arr);
+ } else {
+ array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+ copy_gpu(arr, arr_copy, CopyType::General, s);
+ copies.push_back(arr_copy);
+ return std::make_tuple(false, arr.shape(-1), arr_copy);
+ }
+ };
} // namespace
///////////////////////////////////////////////////////////////////////////////
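The new shared check_transpose helper added above replaces four identical lambdas further down in this file. It classifies each operand by its last two strides; a simplified, self-contained illustration of that classification (a hypothetical helper, ignoring the vector special case and the actual copy machinery):

#include <cstdint>
#include <tuple>

// Given the strides of a (trailing) 2-D operand, decide whether the GEMM can
// read it directly, read it as a transpose, or must first copy it to a packed
// row-major buffer. Returns the layout and the leading dimension to use.
enum class Layout { RowMajor, Transposed, NeedsCopy };

std::tuple<Layout, int64_t> classify(
    int64_t row_stride, int64_t col_stride, int32_t cols) {
  if (col_stride == 1) {
    return {Layout::RowMajor, row_stride};   // ld = stride between rows
  } else if (row_stride == 1) {
    return {Layout::Transposed, col_stride}; // ld = stride between columns
  }
  return {Layout::NeedsCopy, cols};          // after a packed copy, ld = cols
}
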
@@ -180,11 +199,11 @@ void steel_matmul_regular(
int ldd,
bool transpose_a,
bool transpose_b,
- std::vector<int> batch_shape,
- std::vector<size_t> batch_strides,
- size_t A_batch_stride,
- size_t B_batch_stride,
- size_t matrix_stride_out,
+ Shape batch_shape,
+ Strides batch_strides,
+ int64_t A_batch_stride,
+ int64_t B_batch_stride,
+ int64_t matrix_stride_out,
std::vector<array>& copies) {
using namespace mlx::steel;
@@ -268,9 +287,9 @@ void steel_matmul_regular(
/* const int ldd = */ ldd,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
- /* const size_t batch_stride_a = */ A_batch_stride,
- /* const size_t batch_stride_b = */ B_batch_stride,
- /* const size_t batch_stride_d = */ matrix_stride_out,
+ /* const int64_t batch_stride_a = */ A_batch_stride,
+ /* const int64_t batch_stride_b = */ B_batch_stride,
+ /* const int64_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
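The changed comments in this initializer imply that the corresponding GEMMParams fields, declared in the steel kernel headers and not part of this hunk, move from size_t to int64_t as well. A sketch of just the tail of that struct, under that assumption:

#include <cstdint>

// Assumed tail of the params struct after this change; the real GEMMParams in
// the steel headers has additional leading members (M, N, K, lda, ldb, ...).
struct GEMMParamsTailSketch {
  int ldd;
  int tiles_n;
  int tiles_m;
  int64_t batch_stride_a; // was size_t
  int64_t batch_stride_b; // was size_t
  int64_t batch_stride_d; // was size_t
  int swizzle_log;
  int gemm_k_iterations_aligned;
  int batch_ndim;
};
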
@@ -314,9 +333,9 @@ void steel_matmul(
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
- std::vector<int> batch_shape /* = {} */,
- std::vector<size_t> A_batch_stride /* = {} */,
- std::vector<size_t> B_batch_stride /* = {} */) {
+ Shape batch_shape /* = {} */,
+ Strides A_batch_stride /* = {} */,
+ Strides B_batch_stride /* = {} */) {
using namespace mlx::steel;
if (batch_shape.empty()) {
@@ -447,7 +466,7 @@ void steel_matmul(
/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
- std::vector<size_t> batch_strides = A_batch_stride;
+ auto batch_strides = A_batch_stride;
batch_strides.insert(
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
@@ -505,24 +524,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
- auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
- auto stx = arr.strides()[arr.ndim() - 2];
- auto sty = arr.strides()[arr.ndim() - 1];
- if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
- return std::make_tuple(false, stx, arr);
- } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
- return std::make_tuple(true, sty, arr);
- } else {
- array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
- copy_gpu(arr, arr_copy, CopyType::General, s);
- copies.push_back(arr_copy);
- size_t stx = arr.shape(-1);
- return std::make_tuple(false, stx, arr_copy);
- }
- };
- auto [a_transposed, a_cols, a] = check_transpose(a_pre, M == 1);
- auto [b_transposed, b_cols, b] = check_transpose(b_pre, N == 1);
+ auto [a_transposed, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
+ auto [b_transposed, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
@@ -662,9 +665,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
/* bool transpose_a = */ a_transposed,
/* bool transpose_b = */ b_transposed,
/* std::vector<array>& = */ copies,
- /* std::vector<int> batch_shape = */ batch_shape,
- /* std::vector<size_t> A_batch_stride = */ A_batch_stride,
- /* std::vector<size_t> B_batch_stride = */ B_batch_stride);
+ /* Shape batch_shape = */ batch_shape,
+ /* Strides A_batch_stride = */ A_batch_stride,
+ /* Strides B_batch_stride = */ B_batch_stride);
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -691,24 +694,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
- auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
- auto stx = arr.strides()[arr.ndim() - 2];
- auto sty = arr.strides()[arr.ndim() - 1];
- if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
- return std::make_tuple(false, stx, arr);
- } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
- return std::make_tuple(true, sty, arr);
- } else {
- array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
- copy_gpu(arr, arr_copy, CopyType::General, s);
- copies.push_back(arr_copy);
- size_t stx = arr.shape(-1);
- return std::make_tuple(false, stx, arr_copy);
- }
- };
- auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
- auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
+ auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
+ auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
array c = c_pre;
int ldc = c.strides()[c.ndim() - 2];
@@ -723,7 +710,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =
collapse_batches(a, b, c);
- size_t matrix_stride_out = size_t(M) * size_t(N);
+ int64_t matrix_stride_out = M * static_cast<int64_t>(N);
auto batch_size_out = out.size() / (matrix_stride_out);
// Collapse batches into M if needed
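Note the explicit widening in the replacement line: with plain int operands the product would be computed in 32 bits before being stored, so one operand must be cast up front (the old code used size_t casts, the new one int64_t). A minimal illustration of why the cast goes on an operand rather than the result:

#include <cstdint>

int64_t matrix_stride(int M, int N) {
  // return M * N;                     // 32-bit multiply; overflows for e.g. M = N = 65536
  return static_cast<int64_t>(M) * N;  // widen first, then multiply in 64 bits
}
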
@@ -1044,9 +1031,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
- /* const size_t batch_stride_a = */ A_batch_stride.back(),
- /* const size_t batch_stride_b = */ B_batch_stride.back(),
- /* const size_t batch_stride_d = */ matrix_stride_out,
+ /* const int64_t batch_stride_a = */ A_batch_stride.back(),
+ /* const int64_t batch_stride_b = */ B_batch_stride.back(),
+ /* const int64_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -1054,7 +1041,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
GEMMAddMMParams params{
/* const int ldc = */ ldc,
/* const int fdc = */ fdc,
- /* const size_t batch_stride_c = */ C_batch_stride.back(),
+ /* const int64_t batch_stride_c = */ C_batch_stride.back(),
/* const float alpha = */ alpha_,
/* const float beta = */ beta_};
@@ -1065,7 +1052,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
- std::vector<size_t> batch_strides = A_batch_stride;
+ Strides batch_strides = A_batch_stride;
batch_strides.insert(
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
batch_strides.insert(
@@ -1120,24 +1107,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
- auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
- auto stx = arr.strides()[arr.ndim() - 2];
- auto sty = arr.strides()[arr.ndim() - 1];
- if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
- return std::make_tuple(false, stx, arr);
- } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
- return std::make_tuple(true, sty, arr);
- } else {
- array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
- copy_gpu(arr, arr_copy, CopyType::General, s);
- copies.push_back(arr_copy);
- size_t stx = arr.shape(-1);
- return std::make_tuple(false, stx, arr_copy);
- }
- };
- auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
- auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
+ auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
+ auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
int lda = a_cols;
int ldb = b_cols;
@@ -1156,20 +1127,20 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
return decltype(v){v.begin(), v.end() - 2};
};
- std::vector<int> batch_shape{1};
- std::vector<size_t> A_batch_stride{0};
- std::vector<size_t> B_batch_stride{0};
- std::vector<size_t> outmask_bstride{0};
- std::vector<size_t> Amask_bstride{0};
- std::vector<size_t> Bmask_bstride{0};
- size_t A_batch_str = 0;
- size_t B_batch_str = 0;
+ Shape batch_shape{1};
+ Strides A_batch_stride{0};
+ Strides B_batch_stride{0};
+ Strides outmask_bstride{0};
+ Strides Amask_bstride{0};
+ Strides Bmask_bstride{0};
+ int64_t A_batch_str = 0;
+ int64_t B_batch_str = 0;
- std::vector<size_t> batch_strides;
+ Strides batch_strides;
if (out.ndim() > 2) {
- std::vector<int> bshape{out.shape().begin(), out.shape().end() - 2};
- std::vector<std::vector<size_t>> bstrides;
+ Shape bshape{out.shape().begin(), out.shape().end() - 2};
+ std::vector<Strides> bstrides;
for (auto& arr : inputs) {
bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2);
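One reason this hunk stays small: the get_batch_dims helper, whose tail is visible at the top of the hunk, deduces its return type from its argument, so it slices a Shape into a Shape and a Strides into a Strides without any edit when the stride element type widens (the same style of helper appears again in GatherMM below). A sketch of that pattern, reusing the presumed aliases from earlier:

#include <cstdint>
#include <type_traits>
#include <vector>

using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;

// Return type follows the container that was passed in, so the helper does not
// care whether strides are size_t or int64_t.
auto get_batch_dims_sketch = [](const auto& v) {
  return std::decay_t<decltype(v)>{v.begin(), v.end() - 2};
};

// Usage: drop the trailing (matrix) dims, keep only the batch dims.
// Shape   b_shape   = get_batch_dims_sketch(Shape{2, 3, 64, 32});        // {2, 3}
// Strides b_strides = get_batch_dims_sketch(Strides{6144, 2048, 32, 1}); // {6144, 2048}
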
@@ -1196,10 +1167,10 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
}
} else {
- batch_strides = std::vector<size_t>(inputs.size(), 0);
+ batch_strides = Strides(inputs.size(), 0);
}
- size_t matrix_stride_out = size_t(M) * N;
+ int64_t matrix_stride_out = static_cast<int64_t>(M) * N;
size_t batch_size_out = out.size() / (matrix_stride_out);
/////////////////////////////////////////////////////////////////////////////
@@ -1306,7 +1277,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Get mask params
std::vector<int> mask_strides;
- std::vector<size_t> mask_batch_strides;
+ Strides mask_batch_strides;
if (has_out_mask) {
auto& out_mask = inputs[2];
@@ -1436,9 +1407,9 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
- /* const size_t batch_stride_a = */ A_batch_str,
- /* const size_t batch_stride_b = */ B_batch_str,
- /* const size_t batch_stride_d = */ matrix_stride_out,
+ /* const int64_t batch_stride_a = */ A_batch_str,
+ /* const int64_t batch_stride_b = */ B_batch_str,
+ /* const int64_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -1524,24 +1495,8 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
- auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
- auto stx = arr.strides()[arr.ndim() - 2];
- auto sty = arr.strides()[arr.ndim() - 1];
- if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
- return std::make_tuple(false, stx, arr);
- } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
- return std::make_tuple(true, sty, arr);
- } else {
- array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
- copy_gpu(arr, arr_copy, CopyType::General, s);
- copies.push_back(arr_copy);
- size_t stx = arr.shape(-1);
- return std::make_tuple(false, stx, arr_copy);
- }
- };
- auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
- auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
+ auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
+ auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
int lda = a_cols;
int ldb = b_cols;
@@ -1556,20 +1511,20 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
- std::vector<int> batch_shape = get_batch_dims(out.shape());
- std::vector<size_t> batch_strides;
+ Shape batch_shape = get_batch_dims(out.shape());
+ Strides batch_strides;
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
- size_t lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
+ auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
batch_strides.insert(
batch_strides.end(),
rhs_indices.strides().begin(),
rhs_indices.strides().end());
- size_t rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
+ auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
int batch_ndim = batch_shape.size();
@@ -1582,10 +1537,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int batch_ndim_B = b.ndim() - 2;
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B};
- std::vector<int> batch_shape_A = get_batch_dims(a.shape());
- std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
- std::vector<int> batch_shape_B = get_batch_dims(b.shape());
- std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
+ Shape batch_shape_A = get_batch_dims(a.shape());
+ Strides batch_strides_A = get_batch_dims(a.strides());
+ Shape batch_shape_B = get_batch_dims(b.shape());
+ Strides batch_strides_B = get_batch_dims(b.strides());
if (batch_ndim_A == 0) {
batch_shape_A = {1};
@@ -1597,7 +1552,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_strides_B = {0};
}
- size_t matrix_stride_out = size_t(M) * N;
+ auto matrix_stride_out = static_cast<int64_t>(M) * N;
auto batch_size_out = out.size() / matrix_stride_out;
/////////////////////////////////////////////////////////////////////////////
@@ -1801,9 +1756,9 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
- /* const size_t batch_stride_a = */ lhs_indices_str,
- /* const size_t batch_stride_b = */ rhs_indices_str,
- /* const size_t batch_stride_d = */ matrix_stride_out,
+ /* const int64_t batch_stride_a = */ lhs_indices_str,
+ /* const int64_t batch_stride_b = */ rhs_indices_str,
+ /* const int64_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ batch_ndim};