ExpandDims primitive (#1687)

* add squeeze primitive

* simplify squeeze, use in gather

* fix

* fix

* fix

* fix

* fix no cpu

* use squeeze in matmul and friends

* expand dims primitive

* comment
This commit is contained in:
Awni Hannun
2024-12-10 16:39:07 -08:00
committed by GitHub
parent 310ad8d9db
commit f76a49e555
13 changed files with 373 additions and 184 deletions

View File

@@ -144,23 +144,23 @@ array mlx_gather_nd(
int slice_index = 0;
for (int i = 0; i < gather_indices.size(); i++) {
if (is_slice[i]) {
std::vector<int> index_shape(max_dims + num_slices, 1);
Shape index_shape(max_dims + num_slices, 1);
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
gather_indices[i] = reshape(gather_indices[i], index_shape);
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
slice_index++;
} else {
std::vector<int> index_shape = gather_indices[i].shape();
auto index_shape = gather_indices[i].shape();
index_shape.insert(index_shape.end(), num_slices, 1);
gather_indices[i] = reshape(gather_indices[i], index_shape);
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
}
}
} else {
// Reshape the first num_slices indices so that the int/array indices are last
for (int i = 0; i < gather_indices.size(); i++) {
if (i < num_slices) {
std::vector<int> index_shape(max_dims + num_slices, 1);
Shape index_shape(max_dims + num_slices, 1);
index_shape[i] = gather_indices[i].shape(0);
gather_indices[i] = reshape(gather_indices[i], index_shape);
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
}
}
}
@@ -172,19 +172,11 @@ array mlx_gather_nd(
std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
src = gather(src, gather_indices, axes, slice_sizes);
// Squeeze the dims
std::vector<int> out_shape;
out_shape.insert(
out_shape.end(),
src.shape().begin(),
src.shape().begin() + max_dims + num_slices);
out_shape.insert(
out_shape.end(),
src.shape().begin() + max_dims + num_slices + indices.size(),
src.shape().end());
src = reshape(src, out_shape);
return src;
// Squeeze the array index dims
for (auto& ax : axes) {
ax += max_dims + num_slices;
}
return squeeze(src, axes);
}
auto mlx_expand_ellipsis(