mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-25 03:31:17 +08:00
allow take to work with integer index (#1440)
This commit is contained in:
parent
5b6f38df2b
commit
718aea3f1d
60
mlx/ops.cpp
60
mlx/ops.cpp
@ -504,7 +504,20 @@ array squeeze(
|
||||
shape.push_back(a.shape(i));
|
||||
}
|
||||
}
|
||||
return reshape(a, shape, s);
|
||||
return reshape(a, std::move(shape), s);
|
||||
}
|
||||
|
||||
array squeeze(const array& a, int axis, StreamOrDevice s /* = {} */) {
|
||||
int ax = axis < 0 ? axis + a.ndim() : axis;
|
||||
if (ax < 0 || ax >= a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[squeeze] Invalid axis " << axis << " for array with " << a.ndim()
|
||||
<< " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
auto shape = a.shape();
|
||||
shape.erase(shape.begin() + ax);
|
||||
return reshape(a, std::move(shape), s);
|
||||
}
|
||||
|
||||
array squeeze(const array& a, StreamOrDevice s /* = {} */) {
|
||||
@ -657,10 +670,15 @@ array slice(
|
||||
|
||||
array slice(
|
||||
const array& a,
|
||||
const std::vector<int>& start,
|
||||
const std::vector<int>& stop,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return slice(a, start, stop, std::vector<int>(a.ndim(), 1), to_stream(s));
|
||||
return slice(
|
||||
a,
|
||||
std::move(start),
|
||||
std::move(stop),
|
||||
std::vector<int>(a.ndim(), 1),
|
||||
to_stream(s));
|
||||
}
|
||||
|
||||
/** Update a slice from the source array */
|
||||
@ -2715,13 +2733,43 @@ array take(
|
||||
// Squeeze the axis we take over
|
||||
std::vector<int> out_shape = out.shape();
|
||||
out_shape.erase(out_shape.begin() + indices.ndim() + axis);
|
||||
return reshape(out, out_shape, s);
|
||||
return reshape(out, std::move(out_shape), s);
|
||||
}
|
||||
|
||||
array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) {
|
||||
return take(reshape(a, {-1}, s), indices, 0, s);
|
||||
}
|
||||
|
||||
array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) {
|
||||
// Check for valid axis
|
||||
if (axis + static_cast<int>(a.ndim()) < 0 ||
|
||||
axis >= static_cast<int>(a.ndim())) {
|
||||
std::ostringstream msg;
|
||||
msg << "[take] Received invalid axis " << axis << " for array with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Check for valid take
|
||||
if (a.size() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"[take] Cannot do a non-empty take from an array with zero elements.");
|
||||
}
|
||||
|
||||
// Handle negative axis
|
||||
axis = axis < 0 ? a.ndim() + axis : axis;
|
||||
|
||||
std::vector<int> starts(a.ndim(), 0);
|
||||
std::vector<int> stops = a.shape();
|
||||
starts[axis] = index;
|
||||
stops[axis] = index + 1;
|
||||
return squeeze(slice(a, std::move(starts), std::move(stops), s), axis, s);
|
||||
}
|
||||
|
||||
array take(const array& a, int index, StreamOrDevice s /* = {} */) {
|
||||
return take(reshape(a, {-1}, s), index, 0, s);
|
||||
}
|
||||
|
||||
array take_along_axis(
|
||||
const array& a,
|
||||
const array& indices,
|
||||
@ -2764,7 +2812,7 @@ array take_along_axis(
|
||||
// Squeeze out the slice shape
|
||||
std::vector<int> out_shape(
|
||||
out.shape().begin(), out.shape().begin() + a.ndim());
|
||||
return reshape(out, out_shape, s);
|
||||
return reshape(out, std::move(out_shape), s);
|
||||
}
|
||||
|
||||
array put_along_axis(
|
||||
|
10
mlx/ops.h
10
mlx/ops.h
@ -144,9 +144,7 @@ array squeeze(
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Remove singleton dimensions at the given axis. */
|
||||
inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) {
|
||||
return squeeze(a, std::vector<int>{axis}, s);
|
||||
}
|
||||
array squeeze(const array& a, int axis, StreamOrDevice s = {});
|
||||
|
||||
/** Remove all singleton dimensions. */
|
||||
array squeeze(const array& a, StreamOrDevice s = {});
|
||||
@ -171,8 +169,8 @@ array slice(
|
||||
/** Slice an array with a stride of 1 in each dimension. */
|
||||
array slice(
|
||||
const array& a,
|
||||
const std::vector<int>& start,
|
||||
const std::vector<int>& stop,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Update a slice from the source array */
|
||||
@ -936,9 +934,11 @@ array take(
|
||||
const array& indices,
|
||||
int axis,
|
||||
StreamOrDevice s = {});
|
||||
array take(const array& a, int index, int axis, StreamOrDevice s = {});
|
||||
|
||||
/** Take array entries at the given indices treating the array as flattened. */
|
||||
array take(const array& a, const array& indices, StreamOrDevice s = {});
|
||||
array take(const array& a, int index, StreamOrDevice s = {});
|
||||
|
||||
/** Take array entries given indices along the axis */
|
||||
array take_along_axis(
|
||||
|
@ -1398,13 +1398,15 @@ void init_ops(nb::module_& m) {
|
||||
m.def(
|
||||
"take",
|
||||
[](const array& a,
|
||||
const array& indices,
|
||||
const std::variant<int, array>& indices,
|
||||
const std::optional<int>& axis,
|
||||
StreamOrDevice s) {
|
||||
if (axis.has_value()) {
|
||||
return take(a, indices, axis.value(), s);
|
||||
if (auto pv = std::get_if<int>(&indices); pv) {
|
||||
return axis ? take(a, *pv, axis.value(), s) : take(a, *pv, s);
|
||||
} else {
|
||||
return take(a, indices, s);
|
||||
auto indices_ = std::get<array>(indices);
|
||||
return axis ? take(a, indices_, axis.value(), s)
|
||||
: take(a, indices_, s);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
@ -1413,7 +1415,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def take(a: array, /, indices: Union[int, array], axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Take elements along an axis.
|
||||
|
||||
@ -1425,7 +1427,7 @@ void init_ops(nb::module_& m) {
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
indices (array): Input array with integral type.
|
||||
indices (int or array): Integer index or input array with integral type.
|
||||
axis (int, optional): Axis along which to perform the take. If unspecified
|
||||
the array is treated as a flattened 1-D vector.
|
||||
|
||||
|
@ -1059,6 +1059,13 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
|
||||
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
|
||||
|
||||
# Take with integer index
|
||||
a = mx.arange(8).reshape(2, 4)
|
||||
out = mx.take(a, 1, axis=0)
|
||||
self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6, 7])))
|
||||
out = mx.take(a, 1, axis=1)
|
||||
self.assertTrue(mx.array_equal(out, mx.array([1, 5])))
|
||||
|
||||
def test_take_along_axis(self):
|
||||
a_np = np.arange(8).reshape(2, 2, 2)
|
||||
a_mlx = mx.array(a_np)
|
||||
|
Loading…
Reference in New Issue
Block a user