Multi output primitives (#330)

* Multi-output primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-01-08 16:39:08 -08:00
committed by GitHub
parent f45f70f133
commit f099ebe535
26 changed files with 2313 additions and 1039 deletions

View File

@@ -233,6 +233,20 @@ void time_gather_scatter() {
TIME(single_element_add);
}
void time_divmod() {
auto a = random::normal({1000});
auto b = random::normal({1000});
eval({a, b});
auto divmod_fused = [&a, &b]() { return divmod(a, b); };
TIME(divmod_fused);
auto divmod_separate = [&a, &b]() {
return std::vector<array>{floor_divide(a, b), remainder(a, b)};
};
TIME(divmod_separate);
}
int main() {
std::cout << "Benchmarks for " << default_device() << std::endl;
time_creation_ops();
@@ -246,4 +260,5 @@ int main() {
time_matmul();
time_reductions();
time_gather_scatter();
time_divmod();
}