mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -3930,4 +3930,21 @@ std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
|
||||
return {{linalg::inv(a, stream())}, {ax}};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> View::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(axes.size() == 1);
|
||||
return {{view(inputs[0], dtype_, stream())}, axes};
|
||||
}
|
||||
|
||||
void View::print(std::ostream& os) {
|
||||
os << "View" << dtype_;
|
||||
}
|
||||
|
||||
bool View::is_equivalent(const Primitive& other) const {
|
||||
const View& a_other = static_cast<const View&>(other);
|
||||
return (dtype_ == a_other.dtype_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user