Flatten and unflatten (#1692)

* flatten and unflatten

* fix grad

* fix shape infer

* use squeeze + unsqueeze in get_item
This commit is contained in:
Awni Hannun
2024-12-11 21:51:37 -08:00
committed by GitHub
parent 0bf19037ca
commit 4e1e9520e1
19 changed files with 363 additions and 93 deletions

View File

@@ -82,6 +82,7 @@ bool allows_shapeless(const Primitive& p) {
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
typeid(p) == typeid(Squeeze) || typeid(p) == typeid(ExpandDims) ||
typeid(p) == typeid(Flatten) || typeid(p) == typeid(Unflatten) ||
typeid(p) == typeid(fast::AffineQuantize) ||
typeid(p) == typeid(fast::LayerNorm) ||
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||