Add view op (#1179)

* add view primitive

* nit

* fix view
This commit is contained in:
Awni Hannun
2024-06-04 08:05:27 -07:00
committed by GitHub
parent 81def6ac76
commit ea9090bbc4
14 changed files with 202 additions and 11 deletions

View File

@@ -4378,4 +4378,32 @@ array operator>>(const array& a, const array& b) {
return right_shift(a, b);
}
array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) {
if (a.dtype() == dtype) {
return a;
}
auto out_shape = a.shape();
auto ibytes = size_of(a.dtype());
auto obytes = size_of(dtype);
if (a.ndim() == 0 && ibytes != obytes) {
throw std::invalid_argument(
"[view] Changing the type of a scalar is only allowed"
" for types with the same size.");
} else {
if (ibytes < obytes) {
if (out_shape.back() % (obytes / ibytes) != 0) {
throw std::invalid_argument(
"[view] When viewing as a larger dtype, the size in bytes of the last"
" axis must be a multiple of the requested type size.");
}
out_shape.back() /= (obytes / ibytes);
} else {
// Type size ratios are always integers
out_shape.back() *= (ibytes / obytes);
}
}
return array(
out_shape, dtype, std::make_shared<View>(to_stream(s), dtype), {a});
}
} // namespace mlx::core