mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Flatten and unflatten (#1692)
* flatten and unflatten * fix grad * fix shape infer * use squeeze + unsqueeze in get_item
This commit is contained in:
@@ -82,6 +82,7 @@ bool allows_shapeless(const Primitive& p) {
|
||||
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
|
||||
typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
|
||||
typeid(p) == typeid(Squeeze) || typeid(p) == typeid(ExpandDims) ||
|
||||
typeid(p) == typeid(Flatten) || typeid(p) == typeid(Unflatten) ||
|
||||
typeid(p) == typeid(fast::AffineQuantize) ||
|
||||
typeid(p) == typeid(fast::LayerNorm) ||
|
||||
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||
|
||||
|
||||
Reference in New Issue
Block a user