Implement the 'where' primitive for conditional selection (#664)

This commit is contained in:
Rifur13
2024-02-22 18:10:48 -05:00
committed by GitHub
parent ad4a45e615
commit 126c9869c8
23 changed files with 991 additions and 56 deletions

View File

@@ -47,6 +47,10 @@ bool is_binary(const Primitive& p) {
typeid(p) == typeid(Subtract));
}
bool is_ternary(const Primitive& p) {
return typeid(p) == typeid(Select);
}
bool is_broadcast(const Primitive& p) {
return typeid(p) == typeid(Broadcast);
}
@@ -60,14 +64,16 @@ bool is_reduction(const Primitive& p) {
}
bool is_fusable(const Primitive& p) {
return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p) ||
is_noop(p);
}
bool allows_shapeless(const Primitive& p) {
return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) ||
is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition);
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||
typeid(p) == typeid(Select);
}
Compiled::Compiled(