NumberOfElements for shapeless compile and vmap fixes (#802)

This commit is contained in:
Angelos Katharopoulos
2024-03-13 10:34:14 -07:00
committed by GitHub
parent 29d0c10ee5
commit 76c919b4ec
13 changed files with 289 additions and 72 deletions

View File

@@ -653,6 +653,7 @@ std::vector<array> vmap_replace(
v_axes.push_back(-1);
}
}
auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
// For each primitive's outputs add its id, the vout id and the vax
auto outputs = a.outputs();