mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
Fix vmap constant output size (#1524)
* use inputs to determine output size * remove noop vmap tests
This commit is contained in:
@@ -686,6 +686,17 @@ std::vector<array> vmap_replace(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
int vmap_size = -1;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
if (in_axes[i] >= 0) {
|
||||
vmap_size = inputs[i].shape(in_axes[i]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (vmap_size == -1) {
|
||||
throw std::invalid_argument("At least one of in_axes must be non-None.");
|
||||
}
|
||||
|
||||
std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
|
||||
std::unordered_set<std::uintptr_t> needs_vmap;
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
@@ -782,7 +793,11 @@ std::vector<array> vmap_replace(
|
||||
}
|
||||
outputs.push_back(out);
|
||||
} else {
|
||||
outputs.push_back(s_outputs[i]);
|
||||
// When the output has no input dependencies
|
||||
// use the size of the vmapped axis in the inputs to expand the output
|
||||
array output = expand_dims(s_outputs[i], out_axes[i]);
|
||||
output = repeat(output, vmap_size, out_axes[i]);
|
||||
outputs.push_back(output);
|
||||
}
|
||||
}
|
||||
return outputs;
|
||||
|
Reference in New Issue
Block a user