mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Implement the 'where' primitive for conditional selection (#664)
This commit is contained in:
@@ -47,6 +47,10 @@ bool is_binary(const Primitive& p) {
|
||||
typeid(p) == typeid(Subtract));
|
||||
}
|
||||
|
||||
bool is_ternary(const Primitive& p) {
|
||||
return typeid(p) == typeid(Select);
|
||||
}
|
||||
|
||||
bool is_broadcast(const Primitive& p) {
|
||||
return typeid(p) == typeid(Broadcast);
|
||||
}
|
||||
@@ -60,14 +64,16 @@ bool is_reduction(const Primitive& p) {
|
||||
}
|
||||
|
||||
bool is_fusable(const Primitive& p) {
|
||||
return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
|
||||
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p) ||
|
||||
is_noop(p);
|
||||
}
|
||||
|
||||
bool allows_shapeless(const Primitive& p) {
|
||||
return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) ||
|
||||
is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
|
||||
typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
|
||||
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition);
|
||||
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||
|
||||
typeid(p) == typeid(Select);
|
||||
}
|
||||
|
||||
Compiled::Compiled(
|
||||
|
||||
Reference in New Issue
Block a user