Add view op (#1179)

* add view primitive

* nit

* fix view
This commit is contained in:
Awni Hannun
2024-06-04 08:05:27 -07:00
committed by GitHub
parent 81def6ac76
commit ea9090bbc4
14 changed files with 202 additions and 11 deletions

View File

@@ -3930,4 +3930,21 @@ std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
return {{linalg::inv(a, stream())}, {ax}};
}
std::pair<std::vector<array>, std::vector<int>> View::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {{view(inputs[0], dtype_, stream())}, axes};
}
void View::print(std::ostream& os) {
os << "View" << dtype_;
}
bool View::is_equivalent(const Primitive& other) const {
const View& a_other = static_cast<const View&>(other);
return (dtype_ == a_other.dtype_);
}
} // namespace mlx::core