allow take to work with integer index (#1440)

This commit is contained in:
Awni Hannun 2024-09-26 15:58:03 -07:00 committed by GitHub
parent 5b6f38df2b
commit 718aea3f1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 74 additions and 17 deletions

View File

@ -504,7 +504,20 @@ array squeeze(
shape.push_back(a.shape(i));
}
}
return reshape(a, shape, s);
return reshape(a, std::move(shape), s);
}
/**
 * Remove the singleton dimension at `axis`.
 *
 * A negative `axis` counts back from the last dimension. Throws
 * std::invalid_argument if `axis` is out of range for `a` or if the selected
 * dimension does not have size 1.
 */
array squeeze(const array& a, int axis, StreamOrDevice s /* = {} */) {
  // Work in int so the comparisons below are not mixed signed/unsigned.
  const int ndim = static_cast<int>(a.ndim());
  const int ax = axis < 0 ? axis + ndim : axis;
  if (ax < 0 || ax >= ndim) {
    std::ostringstream msg;
    msg << "[squeeze] Invalid axis " << axis << " for array with " << a.ndim()
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  // Check the dimension explicitly instead of leaving it to reshape: a
  // reshape that only compares element counts (presumably the case here)
  // would accept dropping a non-singleton axis whenever another dimension
  // of the array has size zero.
  if (a.shape(ax) != 1) {
    std::ostringstream msg;
    msg << "[squeeze] Cannot squeeze axis " << axis << " with size "
        << a.shape(ax) << " which is not equal to 1.";
    throw std::invalid_argument(msg.str());
  }
  auto shape = a.shape();
  shape.erase(shape.begin() + ax);
  return reshape(a, std::move(shape), s);
}
array squeeze(const array& a, StreamOrDevice s /* = {} */) {
@ -657,10 +670,15 @@ array slice(
array slice(
const array& a,
const std::vector<int>& start,
const std::vector<int>& stop,
std::vector<int> start,
std::vector<int> stop,
StreamOrDevice s /* = {} */) {
return slice(a, start, stop, std::vector<int>(a.ndim(), 1), to_stream(s));
return slice(
a,
std::move(start),
std::move(stop),
std::vector<int>(a.ndim(), 1),
to_stream(s));
}
/** Update a slice from the source array */
@ -2715,13 +2733,43 @@ array take(
// Squeeze the axis we take over
std::vector<int> out_shape = out.shape();
out_shape.erase(out_shape.begin() + indices.ndim() + axis);
return reshape(out, out_shape, s);
return reshape(out, std::move(out_shape), s);
}
/** Gather entries of `a` at `indices`, viewing `a` as a flattened 1-D array. */
array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) {
  // Collapse every dimension into one, then take along that single axis.
  auto flattened = reshape(a, {-1}, s);
  return take(flattened, indices, 0, s);
}
/**
 * Take the sub-array of `a` at position `index` along `axis`; the indexed
 * axis is removed from the result.
 *
 * Negative `axis` and `index` values count back from the end. Throws
 * std::invalid_argument if `axis` is out of range, if `a` has zero elements,
 * or if `index` is out of bounds for the selected axis.
 */
array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) {
  const int ndim = static_cast<int>(a.ndim());
  // Check for valid axis
  if (axis + ndim < 0 || axis >= ndim) {
    std::ostringstream msg;
    msg << "[take] Received invalid axis " << axis << " for array with "
        << a.ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  // Check for valid take
  if (a.size() == 0) {
    throw std::invalid_argument(
        "[take] Cannot do a non-empty take from an array with zero elements.");
  }
  // Handle negative axis
  axis = axis < 0 ? ndim + axis : axis;

  // Resolve and validate the index here rather than passing it through raw:
  // for index == -1 the bounds [index, index + 1) become [-1, 0), where the
  // stop of 0 can no longer denote the end of the axis, and index + 1
  // overflows for INT_MAX. Checking up front also gives a clear error
  // instead of whatever the slice/squeeze calls below would report.
  const int dim = a.shape(axis);
  const int idx = index < 0 ? index + dim : index;
  if (idx < 0 || idx >= dim) {
    std::ostringstream msg;
    msg << "[take] Index " << index << " is out of bounds for axis " << axis
        << " with size " << dim << ".";
    throw std::invalid_argument(msg.str());
  }

  // Select the full extent of every axis except `axis`, where exactly one
  // position is kept; that length-1 axis is then squeezed away.
  std::vector<int> starts(ndim, 0);
  std::vector<int> stops = a.shape();
  starts[axis] = idx;
  stops[axis] = idx + 1;
  return squeeze(slice(a, std::move(starts), std::move(stops), s), axis, s);
}
/** Take the entry of `a` at `index`, viewing `a` as a flattened 1-D array. */
array take(const array& a, int index, StreamOrDevice s /* = {} */) {
  // Collapse every dimension into one, then index along that single axis.
  auto flattened = reshape(a, {-1}, s);
  return take(flattened, index, 0, s);
}
array take_along_axis(
const array& a,
const array& indices,
@ -2764,7 +2812,7 @@ array take_along_axis(
// Squeeze out the slice shape
std::vector<int> out_shape(
out.shape().begin(), out.shape().begin() + a.ndim());
return reshape(out, out_shape, s);
return reshape(out, std::move(out_shape), s);
}
array put_along_axis(

View File

@ -144,9 +144,7 @@ array squeeze(
StreamOrDevice s = {});
/** Remove singleton dimensions at the given axis. */
inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) {
return squeeze(a, std::vector<int>{axis}, s);
}
array squeeze(const array& a, int axis, StreamOrDevice s = {});
/** Remove all singleton dimensions. */
array squeeze(const array& a, StreamOrDevice s = {});
@ -171,8 +169,8 @@ array slice(
/** Slice an array with a stride of 1 in each dimension. */
array slice(
const array& a,
const std::vector<int>& start,
const std::vector<int>& stop,
std::vector<int> start,
std::vector<int> stop,
StreamOrDevice s = {});
/** Update a slice from the source array */
@ -936,9 +934,11 @@ array take(
const array& indices,
int axis,
StreamOrDevice s = {});
/** Take the sub-array at the given integer index along the axis. The indexed
 * axis is removed from the result. */
array take(const array& a, int index, int axis, StreamOrDevice s = {});
/** Take array entries at the given indices treating the array as flattened. */
array take(const array& a, const array& indices, StreamOrDevice s = {});
/** Take the array entry at the given integer index treating the array as
 * flattened. */
array take(const array& a, int index, StreamOrDevice s = {});
/** Take array entries given indices along the axis */
array take_along_axis(

View File

@ -1398,13 +1398,15 @@ void init_ops(nb::module_& m) {
m.def(
"take",
[](const array& a,
const array& indices,
const std::variant<int, array>& indices,
const std::optional<int>& axis,
StreamOrDevice s) {
if (axis.has_value()) {
return take(a, indices, axis.value(), s);
if (auto pv = std::get_if<int>(&indices); pv) {
return axis ? take(a, *pv, axis.value(), s) : take(a, *pv, s);
} else {
return take(a, indices, s);
auto indices_ = std::get<array>(indices);
return axis ? take(a, indices_, axis.value(), s)
: take(a, indices_, s);
}
},
nb::arg(),
@ -1413,7 +1415,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
"def take(a: array, /, indices: Union[int, array], axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Take elements along an axis.
@ -1425,7 +1427,7 @@ void init_ops(nb::module_& m) {
Args:
a (array): Input array.
indices (array): Input array with integral type.
indices (int or array): Integer index or input array with integral type.
axis (int, optional): Axis along which to perform the take. If unspecified
the array is treated as a flattened 1-D vector.

View File

@ -1059,6 +1059,13 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
# Take with integer index
a = mx.arange(8).reshape(2, 4)
out = mx.take(a, 1, axis=0)
self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6, 7])))
out = mx.take(a, 1, axis=1)
self.assertTrue(mx.array_equal(out, mx.array([1, 5])))
def test_take_along_axis(self):
a_np = np.arange(8).reshape(2, 2, 2)
a_mlx = mx.array(a_np)