mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
More fixes for arrays with large sizes (#1405)
* compile works for big arrays when contiguous * style * nits in docs * a bunch more stuff * update jit * update jit * use constant for shapes and strides and remove elem_to_loc overload * use kernel instantiation * docs nits * update binary and ternary * comments
This commit is contained in:
@@ -44,58 +44,26 @@ std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
|
||||
//
|
||||
// When multiple arrays are passed they should all have the same shape. The
|
||||
// collapsed axes are also the same so one shape is returned.
|
||||
template <typename stride_t>
|
||||
inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
|
||||
collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<std::vector<stride_t>> strides) {
|
||||
// Make a vector that has axes separated with -1. Collapse all axes between
|
||||
// -1.
|
||||
std::vector<int> to_collapse;
|
||||
if (shape.size() > 0) {
|
||||
to_collapse.push_back(0);
|
||||
for (int i = 1; i < shape.size(); i++) {
|
||||
bool contiguous = true;
|
||||
for (const std::vector<stride_t>& st : strides) {
|
||||
if (st[i] * shape[i] != st[i - 1]) {
|
||||
contiguous = false;
|
||||
}
|
||||
if (!contiguous) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!contiguous) {
|
||||
to_collapse.push_back(-1);
|
||||
}
|
||||
to_collapse.push_back(i);
|
||||
}
|
||||
to_collapse.push_back(-1);
|
||||
}
|
||||
|
||||
std::vector<int> out_shape;
|
||||
std::vector<std::vector<stride_t>> out_strides(strides.size());
|
||||
for (int i = 0; i < to_collapse.size(); i++) {
|
||||
int current_shape = shape[to_collapse[i]];
|
||||
while (to_collapse[++i] != -1) {
|
||||
current_shape *= shape[to_collapse[i]];
|
||||
}
|
||||
out_shape.push_back(current_shape);
|
||||
for (int j = 0; j < strides.size(); j++) {
|
||||
const std::vector<stride_t>& st = strides[j];
|
||||
out_strides[j].push_back(st[to_collapse[i - 1]]);
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_tuple(out_shape, out_strides);
|
||||
}
|
||||
const std::vector<std::vector<int64_t>>& strides,
|
||||
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<std::vector<size_t>>& strides,
|
||||
size_t size_cap = std::numeric_limits<int32_t>::max());
|
||||
|
||||
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(const std::vector<array>& xs) {
|
||||
collapse_contiguous_dims(
|
||||
const std::vector<array>& xs,
|
||||
size_t size_cap = std::numeric_limits<size_t>::max()) {
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
for (auto& x : xs) {
|
||||
strides.emplace_back(x.strides());
|
||||
}
|
||||
return collapse_contiguous_dims(xs[0].shape(), strides);
|
||||
return collapse_contiguous_dims(xs[0].shape(), strides, size_cap);
|
||||
}
|
||||
|
||||
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||
|
||||
Reference in New Issue
Block a user