Implement the 'where' primitive for conditional selection (#664)

This commit is contained in:
Rifur13
2024-02-22 18:10:48 -05:00
committed by GitHub
parent ad4a45e615
commit 126c9869c8
23 changed files with 991 additions and 56 deletions

View File

@@ -73,6 +73,7 @@ void time_unary_ops() {
void time_binary_ops() {
int M = 1000, N = 100, K = 10;
auto condition = random::randint(0, 2, {M, N, K});
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
@@ -84,7 +85,9 @@ void time_binary_ops() {
TIME(divide, a, b, device);
TIME(maximum, a, b, device);
TIME(minimum, a, b, device);
TIME(where, condition, a, b, device);
condition = array({true});
b = random::uniform({1});
eval(b);
TIMEM("scalar", add, a, b, device);
@@ -93,7 +96,9 @@ void time_binary_ops() {
TIMEM("scalar", multiply, a, b, device);
TIMEM("vector-scalar", divide, a, b, device);
TIMEM("scalar-vector", divide, b, a, device);
TIMEM("scalar-vector", where, condition, a, b, device);
condition = broadcast_to(array({true}), {1000, 100});
a = broadcast_to(random::uniform({1}), {1000, 100});
b = broadcast_to(random::uniform({1}), {1000, 100});
eval(a, b);
@@ -101,6 +106,7 @@ void time_binary_ops() {
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
TIMEM("scalar-scalar broadcast", divide, a, b, device);
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
}
void time_strided_ops() {