mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add a SliceUpdate op and primitive (#850)
* Enable copy to work with int64 strides * Fix uniform buffer indices for copy kernel arguments * Update utils.h * Remove manual unrolling of elem to loc loop * GPU copy updated to handle negative strides * Add slice update primitive
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -8,11 +8,12 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline size_t elem_to_loc(
|
||||
template <typename stride_t>
|
||||
inline stride_t elem_to_loc(
|
||||
int elem,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides) {
|
||||
size_t loc = 0;
|
||||
const std::vector<stride_t>& strides) {
|
||||
stride_t loc = 0;
|
||||
for (int i = shape.size() - 1; i >= 0; --i) {
|
||||
auto q_and_r = ldiv(elem, shape[i]);
|
||||
loc += q_and_r.rem * strides[i];
|
||||
@@ -34,10 +35,11 @@ inline size_t elem_to_loc(int elem, const array& a) {
|
||||
//
|
||||
// When multiple arrays are passed they should all have the same shape. The
|
||||
// collapsed axes are also the same so one shape is returned.
|
||||
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
template <typename stride_t>
|
||||
inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
|
||||
collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<std::vector<size_t>> strides) {
|
||||
const std::vector<std::vector<stride_t>> strides) {
|
||||
// Make a vector that has axes separated with -1. Collapse all axes between
|
||||
// -1.
|
||||
std::vector<int> to_collapse;
|
||||
@@ -45,7 +47,7 @@ collapse_contiguous_dims(
|
||||
to_collapse.push_back(0);
|
||||
for (int i = 1; i < shape.size(); i++) {
|
||||
bool contiguous = true;
|
||||
for (const std::vector<size_t>& st : strides) {
|
||||
for (const std::vector<stride_t>& st : strides) {
|
||||
if (st[i] * shape[i] != st[i - 1]) {
|
||||
contiguous = false;
|
||||
}
|
||||
@@ -62,7 +64,7 @@ collapse_contiguous_dims(
|
||||
}
|
||||
|
||||
std::vector<int> out_shape;
|
||||
std::vector<std::vector<size_t>> out_strides(strides.size());
|
||||
std::vector<std::vector<stride_t>> out_strides(strides.size());
|
||||
for (int i = 0; i < to_collapse.size(); i++) {
|
||||
int current_shape = shape[to_collapse[i]];
|
||||
while (to_collapse[++i] != -1) {
|
||||
@@ -70,7 +72,7 @@ collapse_contiguous_dims(
|
||||
}
|
||||
out_shape.push_back(current_shape);
|
||||
for (int j = 0; j < strides.size(); j++) {
|
||||
const std::vector<size_t>& st = strides[j];
|
||||
const std::vector<stride_t>& st = strides[j];
|
||||
out_strides[j].push_back(st[to_collapse[i - 1]]);
|
||||
}
|
||||
}
|
||||
@@ -94,4 +96,27 @@ collapse_contiguous_dims(Arrays... xs) {
|
||||
std::vector<array>{std::forward<Arrays>(xs)...});
|
||||
}
|
||||
|
||||
// Inspect a shape/strides pair and report how the described data is laid out.
//
// Returns a tuple (data_size, is_row_contiguous, is_col_contiguous):
//   * data_size          - product of the extents of every axis whose stride
//                          is strictly positive (axes with stride <= 0 do not
//                          contribute).
//   * is_row_contiguous  - true when the strides match a dense row-major
//                          layout, i.e. the last axis has stride 1 and each
//                          earlier axis has stride equal to the product of the
//                          extents that follow it. Axes of extent 1 are
//                          ignored since their stride is irrelevant.
//   * is_col_contiguous  - same as above for a dense column-major layout
//                          (first axis has stride 1).
//
// `shape` and `strides` are expected to have the same number of entries; an
// empty shape yields (1, true, true).
template <typename stride_t>
inline auto check_contiguity(
    const std::vector<int>& shape,
    const std::vector<stride_t>& strides) {
  size_t data_size = 1;
  // Expected strides are kept in stride_t so the comparisons below never mix
  // signed strides with an unsigned accumulator.
  stride_t f_stride = 1; // expected stride when walking axes first -> last
  stride_t b_stride = 1; // expected stride when walking axes last -> first
  bool is_row_contiguous = true;
  bool is_col_contiguous = true;

  for (int i = 0, ri = static_cast<int>(shape.size()) - 1; ri >= 0;
       i++, ri--) {
    // Forward walk: the first axis varies fastest, which is column-major.
    is_col_contiguous &= strides[i] == f_stride || shape[i] == 1;
    // Backward walk: the last axis varies fastest, which is row-major.
    is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
    f_stride *= shape[i];
    b_stride *= shape[ri];
    // Only axes with a positive stride contribute to the data size.
    if (strides[i] > 0) {
      data_size *= shape[i];
    }
  }

  return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user