mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
allow take to work with integer index (#1440)
This commit is contained in:
60
mlx/ops.cpp
60
mlx/ops.cpp
@@ -504,7 +504,20 @@ array squeeze(
|
||||
shape.push_back(a.shape(i));
|
||||
}
|
||||
}
|
||||
return reshape(a, shape, s);
|
||||
return reshape(a, std::move(shape), s);
|
||||
}
|
||||
|
||||
array squeeze(const array& a, int axis, StreamOrDevice s /* = {} */) {
|
||||
int ax = axis < 0 ? axis + a.ndim() : axis;
|
||||
if (ax < 0 || ax >= a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[squeeze] Invalid axis " << axis << " for array with " << a.ndim()
|
||||
<< " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
auto shape = a.shape();
|
||||
shape.erase(shape.begin() + ax);
|
||||
return reshape(a, std::move(shape), s);
|
||||
}
|
||||
|
||||
array squeeze(const array& a, StreamOrDevice s /* = {} */) {
|
||||
@@ -657,10 +670,15 @@ array slice(
|
||||
|
||||
array slice(
|
||||
const array& a,
|
||||
const std::vector<int>& start,
|
||||
const std::vector<int>& stop,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return slice(a, start, stop, std::vector<int>(a.ndim(), 1), to_stream(s));
|
||||
return slice(
|
||||
a,
|
||||
std::move(start),
|
||||
std::move(stop),
|
||||
std::vector<int>(a.ndim(), 1),
|
||||
to_stream(s));
|
||||
}
|
||||
|
||||
/** Update a slice from the source array */
|
||||
@@ -2715,13 +2733,43 @@ array take(
|
||||
// Squeeze the axis we take over
|
||||
std::vector<int> out_shape = out.shape();
|
||||
out_shape.erase(out_shape.begin() + indices.ndim() + axis);
|
||||
return reshape(out, out_shape, s);
|
||||
return reshape(out, std::move(out_shape), s);
|
||||
}
|
||||
|
||||
array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) {
|
||||
return take(reshape(a, {-1}, s), indices, 0, s);
|
||||
}
|
||||
|
||||
array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) {
|
||||
// Check for valid axis
|
||||
if (axis + static_cast<int>(a.ndim()) < 0 ||
|
||||
axis >= static_cast<int>(a.ndim())) {
|
||||
std::ostringstream msg;
|
||||
msg << "[take] Received invalid axis " << axis << " for array with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Check for valid take
|
||||
if (a.size() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"[take] Cannot do a non-empty take from an array with zero elements.");
|
||||
}
|
||||
|
||||
// Handle negative axis
|
||||
axis = axis < 0 ? a.ndim() + axis : axis;
|
||||
|
||||
std::vector<int> starts(a.ndim(), 0);
|
||||
std::vector<int> stops = a.shape();
|
||||
starts[axis] = index;
|
||||
stops[axis] = index + 1;
|
||||
return squeeze(slice(a, std::move(starts), std::move(stops), s), axis, s);
|
||||
}
|
||||
|
||||
array take(const array& a, int index, StreamOrDevice s /* = {} */) {
|
||||
return take(reshape(a, {-1}, s), index, 0, s);
|
||||
}
|
||||
|
||||
array take_along_axis(
|
||||
const array& a,
|
||||
const array& indices,
|
||||
@@ -2764,7 +2812,7 @@ array take_along_axis(
|
||||
// Squeeze out the slice shape
|
||||
std::vector<int> out_shape(
|
||||
out.shape().begin(), out.shape().begin() + a.ndim());
|
||||
return reshape(out, out_shape, s);
|
||||
return reshape(out, std::move(out_shape), s);
|
||||
}
|
||||
|
||||
array put_along_axis(
|
||||
|
10
mlx/ops.h
10
mlx/ops.h
@@ -144,9 +144,7 @@ array squeeze(
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Remove singleton dimensions at the given axis. */
|
||||
inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) {
|
||||
return squeeze(a, std::vector<int>{axis}, s);
|
||||
}
|
||||
array squeeze(const array& a, int axis, StreamOrDevice s = {});
|
||||
|
||||
/** Remove all singleton dimensions. */
|
||||
array squeeze(const array& a, StreamOrDevice s = {});
|
||||
@@ -171,8 +169,8 @@ array slice(
|
||||
/** Slice an array with a stride of 1 in each dimension. */
|
||||
array slice(
|
||||
const array& a,
|
||||
const std::vector<int>& start,
|
||||
const std::vector<int>& stop,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Update a slice from the source array */
|
||||
@@ -936,9 +934,11 @@ array take(
|
||||
const array& indices,
|
||||
int axis,
|
||||
StreamOrDevice s = {});
|
||||
array take(const array& a, int index, int axis, StreamOrDevice s = {});
|
||||
|
||||
/** Take array entries at the given indices treating the array as flattened. */
|
||||
array take(const array& a, const array& indices, StreamOrDevice s = {});
|
||||
array take(const array& a, int index, StreamOrDevice s = {});
|
||||
|
||||
/** Take array entries given indices along the axis */
|
||||
array take_along_axis(
|
||||
|
Reference in New Issue
Block a user