mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-01 08:38:12 +08:00
ExpandDims primitive (#1687)
* add squeeze primitive * simplify squeeze, use in gather * fix * fix * fix * fix * fix no cpu * use squeeze in matmul and friends * expand dims primitive * comment
This commit is contained in:
@@ -144,23 +144,23 @@ array mlx_gather_nd(
|
||||
int slice_index = 0;
|
||||
for (int i = 0; i < gather_indices.size(); i++) {
|
||||
if (is_slice[i]) {
|
||||
std::vector<int> index_shape(max_dims + num_slices, 1);
|
||||
Shape index_shape(max_dims + num_slices, 1);
|
||||
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
|
||||
gather_indices[i] = reshape(gather_indices[i], index_shape);
|
||||
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
|
||||
slice_index++;
|
||||
} else {
|
||||
std::vector<int> index_shape = gather_indices[i].shape();
|
||||
auto index_shape = gather_indices[i].shape();
|
||||
index_shape.insert(index_shape.end(), num_slices, 1);
|
||||
gather_indices[i] = reshape(gather_indices[i], index_shape);
|
||||
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// reshape them so that the int/array indices are last
|
||||
for (int i = 0; i < gather_indices.size(); i++) {
|
||||
if (i < num_slices) {
|
||||
std::vector<int> index_shape(max_dims + num_slices, 1);
|
||||
Shape index_shape(max_dims + num_slices, 1);
|
||||
index_shape[i] = gather_indices[i].shape(0);
|
||||
gather_indices[i] = reshape(gather_indices[i], index_shape);
|
||||
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -172,19 +172,11 @@ array mlx_gather_nd(
|
||||
std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
|
||||
src = gather(src, gather_indices, axes, slice_sizes);
|
||||
|
||||
// Squeeze the dims
|
||||
std::vector<int> out_shape;
|
||||
out_shape.insert(
|
||||
out_shape.end(),
|
||||
src.shape().begin(),
|
||||
src.shape().begin() + max_dims + num_slices);
|
||||
out_shape.insert(
|
||||
out_shape.end(),
|
||||
src.shape().begin() + max_dims + num_slices + indices.size(),
|
||||
src.shape().end());
|
||||
src = reshape(src, out_shape);
|
||||
|
||||
return src;
|
||||
// Squeeze the array index dims
|
||||
for (auto& ax : axes) {
|
||||
ax += max_dims + num_slices;
|
||||
}
|
||||
return squeeze(src, axes);
|
||||
}
|
||||
|
||||
auto mlx_expand_ellipsis(
|
||||
|
||||
Reference in New Issue
Block a user