mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add vmap to scatter (#1200)
* Add vmap to scatter * updates * vmap updates + a few more tests * bug fix --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -563,8 +563,10 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
|
||||
detail::InTracing in_tracing;
|
||||
|
||||
if (in_axes.size() != inputs.size()) {
|
||||
throw std::invalid_argument(
|
||||
"[vmap] The number of in axes must match the number of inputs.");
|
||||
std::stringstream ss;
|
||||
ss << "[vmap] The number of in axes (" << in_axes.size()
|
||||
<< ") must match the number of inputs (" << inputs.size() << ").";
|
||||
throw std::invalid_argument(ss.str());
|
||||
}
|
||||
|
||||
// Some error checking and get the vmap axis size
|
||||
@@ -620,8 +622,10 @@ std::vector<array> vmap_replace(
|
||||
const std::vector<int>& in_axes,
|
||||
const std::vector<int>& out_axes) {
|
||||
if (out_axes.size() != s_outputs.size()) {
|
||||
throw std::invalid_argument(
|
||||
"[vmap] The number of out axes must match the number of outputs.");
|
||||
std::stringstream msg;
|
||||
msg << "[vmap] The number of out axes (" << out_axes.size()
|
||||
<< ") must match the number of outputs (" << s_outputs.size() << ").";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
|
||||
|
||||
Reference in New Issue
Block a user