mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
scatter axis + gather axis primitives (#1813)
* scatter axis + gather axis primitives * add transforms * comment
This commit is contained in:
142
mlx/ops.cpp
142
mlx/ops.cpp
@@ -68,7 +68,7 @@ array indices_or_default(
|
||||
Shape shape(x.shape().begin(), x.shape().end() - 2);
|
||||
int total =
|
||||
std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
|
||||
return reshape(arange(total, uint32, s), shape, s);
|
||||
return reshape(arange(total, uint32, s), std::move(shape), s);
|
||||
}
|
||||
|
||||
std::pair<int, int> extract_quantized_matmul_dims(
|
||||
@@ -3080,28 +3080,64 @@ array take_along_axis(
|
||||
// Allow negative axis
|
||||
axis = axis < 0 ? a.ndim() + axis : axis;
|
||||
|
||||
std::vector<array> nd_indices;
|
||||
Shape index_shape(a.ndim(), 1);
|
||||
for (int i = 0; i < a.ndim(); ++i) {
|
||||
if (i == axis) {
|
||||
nd_indices.push_back(indices);
|
||||
} else {
|
||||
// Reshape so they can be broadcast
|
||||
index_shape[i] = a.shape(i);
|
||||
nd_indices.push_back(reshape(arange(a.shape(i), s), index_shape, s));
|
||||
index_shape[i] = 1;
|
||||
}
|
||||
}
|
||||
std::vector<int> dims(a.ndim());
|
||||
std::iota(dims.begin(), dims.end(), 0);
|
||||
Shape slice_sizes(a.ndim(), 1);
|
||||
auto out = gather(a, nd_indices, dims, slice_sizes, s);
|
||||
// Broadcast indices and input ignoring the take axis
|
||||
auto inputs = broadcast_arrays({a, indices}, {axis - int(a.ndim())}, s);
|
||||
|
||||
// Squeeze out the slice shape
|
||||
for (auto& d : dims) {
|
||||
d += a.ndim();
|
||||
auto out_shape = inputs[1].shape();
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
std::make_shared<GatherAxis>(to_stream(s), axis),
|
||||
std::move(inputs));
|
||||
}
|
||||
|
||||
/**
 * Shared implementation behind put_along_axis and scatter_add_axis.
 *
 * Builds a ScatterAxis primitive over `a`, `indices` and `values` along
 * `axis`. `values` is cast to the dtype of `a`; `indices` and the cast values
 * are broadcast against each other and then against `a` while ignoring the
 * scatter axis. The result has the shape of the (broadcast) `a` and its dtype.
 *
 * @param a       Array providing the output dtype and (broadcast) shape.
 * @param indices Indices along `axis`; must have the same number of
 *                dimensions as `a`.
 * @param values  Updates; extra leading dimensions beyond those of `indices`
 *                are squeezed out before broadcasting.
 * @param axis    Axis to scatter along; negative values count from the end.
 * @param mode    Reduction mode. ScatterAxis::None selects the
 *                "[put_along_axis]" error prefix, any other mode selects
 *                "[scatter_add_axis]".
 * @param s       Stream or device to run on.
 * @throws std::invalid_argument if `axis` is outside [-ndim, ndim) or if
 *         `indices.ndim() != a.ndim()`.
 */
array scatter_axis(
    const array& a,
    const array& indices,
    const array& values,
    int axis,
    ScatterAxis::ReduceType mode,
    StreamOrDevice s) {
  std::string prefix =
      (mode == ScatterAxis::None) ? "[put_along_axis]" : "[scatter_add_axis]";

  // Do the range check in signed arithmetic. Mixing `int axis` with an
  // unsigned ndim() would promote the sum to unsigned, so a "< 0" test could
  // never fire and out-of-range negative axes would slip through.
  const int ndim = static_cast<int>(a.ndim());
  if (axis < -ndim || axis >= ndim) {
    std::ostringstream msg;
    msg << prefix << " Received invalid axis " << axis << " for array with "
        << ndim << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  if (indices.ndim() != a.ndim()) {
    std::ostringstream msg;
    msg << prefix << " Indices of dimension " << indices.ndim()
        << " does not match array of dimension " << a.ndim() << ".";
    throw std::invalid_argument(msg.str());
  }

  auto upd = astype(values, a.dtype(), s);

  // Squeeze leading singletons out of update
  if (upd.ndim() > indices.ndim()) {
    std::vector<int> sq_ax(upd.ndim() - indices.ndim());
    std::iota(sq_ax.begin(), sq_ax.end(), 0);
    upd = squeeze(upd, sq_ax, s);
  }

  // Indices and updates must agree with each other on every axis
  auto inputs = broadcast_arrays({indices, upd}, s);
  inputs.insert(inputs.begin(), a);

  // Allow negative axis
  axis = axis < 0 ? ndim + axis : axis;

  // Broadcast src, indices, values while ignoring the take axis
  inputs = broadcast_arrays(inputs, {axis - ndim}, s);

  auto out_shape = inputs[0].shape();
  return array(
      std::move(out_shape),
      a.dtype(),
      std::make_shared<ScatterAxis>(to_stream(s), mode, axis),
      std::move(inputs));
}
|
||||
|
||||
array put_along_axis(
|
||||
@@ -3110,45 +3146,16 @@ array put_along_axis(
|
||||
const array& values,
|
||||
int axis,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
|
||||
std::ostringstream msg;
|
||||
msg << "[put_along_axis] Received invalid axis " << " for array with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
return scatter_axis(a, indices, values, axis, ScatterAxis::None, s);
|
||||
}
|
||||
|
||||
if (indices.ndim() != a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[put_along_axis] Indices of dimension " << indices.ndim()
|
||||
<< " does not match array of dimension " << a.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Allow negative axis
|
||||
axis = axis < 0 ? a.ndim() + axis : axis;
|
||||
|
||||
std::vector<array> nd_indices;
|
||||
Shape index_shape(a.ndim(), 1);
|
||||
for (int i = 0; i < a.ndim(); ++i) {
|
||||
if (i == axis) {
|
||||
nd_indices.push_back(indices);
|
||||
} else {
|
||||
// Reshape so they can be broadcast
|
||||
index_shape[i] = a.shape(i);
|
||||
nd_indices.push_back(reshape(arange(a.shape(i), s), index_shape, s));
|
||||
index_shape[i] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
auto update = astype(broadcast_to(values, indices.shape(), s), a.dtype(), s);
|
||||
{
|
||||
auto update_shape = update.shape();
|
||||
update_shape.resize(update_shape.size() + a.ndim(), 1);
|
||||
update = reshape(update, std::move(update_shape), s);
|
||||
}
|
||||
std::vector<int> dims(a.ndim());
|
||||
std::iota(dims.begin(), dims.end(), 0);
|
||||
return scatter(a, nd_indices, update, dims, s);
|
||||
/**
 * Forward to scatter_axis using the ScatterAxis::Sum reduction mode.
 *
 * All argument validation is performed by scatter_axis, which reports errors
 * for this mode under the "[scatter_add_axis]" prefix.
 */
array scatter_add_axis(
    const array& a,
    const array& indices,
    const array& values,
    int axis,
    StreamOrDevice s /* = {} */) {
  const auto reduce_mode = ScatterAxis::Sum;
  return scatter_axis(a, indices, values, axis, reduce_mode, s);
}
|
||||
|
||||
/** Scatter updates to given indices */
|
||||
@@ -3157,8 +3164,8 @@ array scatter(
|
||||
const std::vector<array>& indices,
|
||||
const array& updates,
|
||||
const std::vector<int>& axes,
|
||||
Scatter::ReduceType mode /*= Scatter::ReduceType::None*/,
|
||||
StreamOrDevice s /*= {}*/) {
|
||||
Scatter::ReduceType mode,
|
||||
StreamOrDevice s) {
|
||||
// Checks that indices, dimensions, and slice_sizes are all valid
|
||||
if (indices.size() > a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
@@ -3962,6 +3969,19 @@ array gather_qmm(
|
||||
std::tie(lhs_indices, rhs_indices) =
|
||||
broadcast_arrays(lhs_indices, rhs_indices, s);
|
||||
|
||||
if (!issubdtype(lhs_indices.dtype(), integer)) {
|
||||
throw std::invalid_argument(
|
||||
"[gather_qmm] Got lhs_indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
|
||||
if (!issubdtype(rhs_indices.dtype(), integer)) {
|
||||
throw std::invalid_argument(
|
||||
"[gather_qmm] Got rhs_indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
|
||||
lhs_indices = astype(lhs_indices, uint32, s);
|
||||
rhs_indices = astype(rhs_indices, uint32, s);
|
||||
|
||||
// Compute the full output shape
|
||||
auto out_shape = lhs_indices.shape();
|
||||
out_shape.push_back(x.shape(-2));
|
||||
|
||||
Reference in New Issue
Block a user