mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-10 19:26:42 +08:00
166 lines
4.9 KiB
C++
166 lines
4.9 KiB
C++
// Copyright © 2023-2024 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include <cstdlib>
#include <limits>
#include <tuple>
#include <vector>

#include "mlx/array.h"
|
|
|
|
namespace mlx::core {
|
|
|
|
// Map a flat (row-major) element index to a memory offset given per-axis
// strides. The index is decomposed into per-axis coordinates starting from the
// last axis; each coordinate is weighted by that axis' stride.
template <typename stride_t>
inline stride_t elem_to_loc(
    int elem,
    const std::vector<int>& shape,
    const std::vector<stride_t>& strides) {
  stride_t offset = 0;
  for (auto axis = shape.size(); axis-- > 0;) {
    // rem is the coordinate along `axis`; quot carries over to outer axes.
    const auto split = ldiv(elem, shape[axis]);
    offset += split.rem * strides[axis];
    elem = split.quot;
  }
  return offset;
}
|
|
|
|
inline size_t elem_to_loc(int elem, const array& a) {
|
|
if (a.flags().row_contiguous) {
|
|
return elem;
|
|
}
|
|
return elem_to_loc(elem, a.shape(), a.strides());
|
|
}
|
|
|
|
// Row-major strides for `shape`: the last axis has stride 1 and every earlier
// axis' stride is the next axis' stride times the next axis' extent.
template <typename stride_t>
std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
  std::vector<stride_t> strides(shape.size(), 1);
  auto axis = shape.size();
  while (axis > 1) {
    --axis;
    strides[axis - 1] = strides[axis] * shape[axis];
  }
  return strides;
}
|
|
|
|
// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
//
// Axis i is merged into the preceding collapsed axis only when, for every
// stride vector st, st[i] * shape[i] == st[i - 1]. A collapsed axis keeps the
// stride of its innermost (last) member axis. A merge is also skipped when
// the merged extent would exceed std::numeric_limits<int>::max(), because the
// returned shape is stored as int.
//
// Params:
//   shape   - the common shape of all arrays.
//   strides - one stride vector per array, each with shape.size() entries.
// Returns: (collapsed shape, one collapsed stride vector per input).
template <typename stride_t>
inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
collapse_contiguous_dims(
    const std::vector<int>& shape,
    const std::vector<std::vector<stride_t>>& strides) {
  std::vector<int> out_shape;
  std::vector<std::vector<stride_t>> out_strides(strides.size());
  if (shape.empty()) {
    return std::make_tuple(out_shape, out_strides);
  }

  // Axis 0 always starts the first collapsed group.
  out_shape.push_back(shape[0]);
  for (size_t j = 0; j < strides.size(); ++j) {
    out_strides[j].push_back(strides[j][0]);
  }

  for (size_t i = 1; i < shape.size(); ++i) {
    // Axis i continues the current group only if it is contiguous with the
    // previous axis in every array.
    bool contiguous = true;
    for (const std::vector<stride_t>& st : strides) {
      if (st[i] * shape[i] != st[i - 1]) {
        contiguous = false;
        break;
      }
    }

    // Compute the merged extent in size_t so the check itself cannot
    // overflow; merging past INT_MAX would overflow the int shape entry.
    const size_t merged = static_cast<size_t>(out_shape.back()) *
        static_cast<size_t>(shape[i]);
    const bool fits =
        merged <= static_cast<size_t>(std::numeric_limits<int>::max());

    if (contiguous && fits) {
      out_shape.back() *= shape[i];
      for (size_t j = 0; j < strides.size(); ++j) {
        out_strides[j].back() = strides[j][i];
      }
    } else {
      out_shape.push_back(shape[i]);
      for (size_t j = 0; j < strides.size(); ++j) {
        out_strides[j].push_back(strides[j][i]);
      }
    }
  }

  return std::make_tuple(out_shape, out_strides);
}
|
|
|
|
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
|
collapse_contiguous_dims(const std::vector<array>& xs) {
|
|
std::vector<std::vector<size_t>> strides;
|
|
for (auto& x : xs) {
|
|
strides.emplace_back(x.strides());
|
|
}
|
|
return collapse_contiguous_dims(xs[0].shape(), strides);
|
|
}
|
|
|
|
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
|
inline auto collapse_contiguous_dims(Arrays&&... xs) {
|
|
return collapse_contiguous_dims(
|
|
std::vector<array>{std::forward<Arrays>(xs)...});
|
|
}
|
|
|
|
// The single array version of the above.
// An axis is folded into the previous collapsed axis when its stride times its
// extent equals the previous axis' stride. The fold is skipped when the
// product of the extents would exceed std::numeric_limits<int>::max().
inline std::tuple<std::vector<int>, std::vector<size_t>>
collapse_contiguous_dims(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  std::vector<int> out_shape;
  std::vector<size_t> out_strides;
  if (shape.empty()) {
    return std::make_tuple(out_shape, out_strides);
  }

  out_shape.push_back(shape[0]);
  out_strides.push_back(strides[0]);
  for (size_t axis = 1; axis < shape.size(); ++axis) {
    const bool contiguous = strides[axis] * shape[axis] == out_strides.back();
    const bool fits = out_shape.back() * static_cast<size_t>(shape[axis]) <=
        std::numeric_limits<int>::max();
    if (contiguous && fits) {
      out_shape.back() *= shape[axis];
      out_strides.back() = strides[axis];
    } else {
      out_shape.push_back(shape[axis]);
      out_strides.push_back(strides[axis]);
    }
  }

  return std::make_tuple(out_shape, out_strides);
}
|
|
|
|
// Inspect a shape/strides pair and report:
//   - the product of the extents of axes whose stride is > 0,
//   - whether the strides match a row-major layout,
//   - whether the strides match a column-major layout.
// Axes of extent 1 are ignored by both layout checks.
template <typename stride_t>
inline auto check_contiguity(
    const std::vector<int>& shape,
    const std::vector<stride_t>& strides) {
  const int ndim = shape.size();
  size_t data_size = 1;
  size_t col_expected = 1;
  size_t row_expected = 1;
  bool row_contig = true;
  bool col_contig = true;

  // Column-major: expected strides grow from the first axis outward. The same
  // pass accumulates the data size.
  for (int ax = 0; ax < ndim; ++ax) {
    col_contig &= shape[ax] == 1 || strides[ax] == col_expected;
    col_expected *= shape[ax];
    if (strides[ax] > 0) {
      data_size *= shape[ax];
    }
  }

  // Row-major: expected strides grow from the last axis inward.
  for (int ax = ndim - 1; ax >= 0; --ax) {
    row_contig &= shape[ax] == 1 || strides[ax] == row_expected;
    row_expected *= shape[ax];
  }

  return std::make_tuple(data_size, row_contig, col_contig);
}
|
|
|
|
inline bool is_donatable(const array& in, const array& out) {
|
|
constexpr size_t donation_extra = 16384;
|
|
|
|
return in.is_donatable() && in.itemsize() == out.itemsize() &&
|
|
in.buffer_size() <= out.nbytes() + donation_extra;
|
|
}
|
|
|
|
} // namespace mlx::core
|