mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add vmap to scatter (#1200)
* Add vmap to scatter * updates * vmap updates + a few more tests * bug fix --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
11
mlx/ops.cpp
11
mlx/ops.cpp
@@ -1103,6 +1103,9 @@ array moveaxis(
|
||||
};
|
||||
source = check_ax(source);
|
||||
destination = check_ax(destination);
|
||||
if (source == destination) {
|
||||
return a;
|
||||
}
|
||||
std::vector<int> reorder(a.ndim());
|
||||
std::iota(reorder.begin(), reorder.end(), 0);
|
||||
reorder.erase(reorder.begin() + source);
|
||||
@@ -2715,9 +2718,8 @@ array scatter(
|
||||
if (updates.ndim() != (a.ndim() + idx_shape.size())) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scatter] Updates with " << updates.ndim()
|
||||
<< " dimensions does not match the sum of the array and indices "
|
||||
"dimensions "
|
||||
<< a.ndim() + idx_shape.size() << ".";
|
||||
<< " dimensions does not match the sum of the array (" << a.ndim()
|
||||
<< ") and indices (" << idx_shape.size() << ") dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
for (int i = 0; i < idx_shape.size(); ++i) {
|
||||
@@ -2759,11 +2761,12 @@ array scatter(
|
||||
inputs.insert(inputs.begin(), a);
|
||||
// TODO promote or cast?
|
||||
inputs.push_back(astype(updates, a.dtype(), s));
|
||||
|
||||
return array(
|
||||
a.shape(),
|
||||
a.dtype(),
|
||||
std::make_shared<Scatter>(to_stream(s), mode, axes),
|
||||
inputs);
|
||||
std::move(inputs));
|
||||
}
|
||||
|
||||
array scatter(
|
||||
|
||||
Reference in New Issue
Block a user