mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
28
mlx/ops.cpp
28
mlx/ops.cpp
@@ -4378,4 +4378,32 @@ array operator>>(const array& a, const array& b) {
|
||||
return right_shift(a, b);
|
||||
}
|
||||
|
||||
array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) {
|
||||
if (a.dtype() == dtype) {
|
||||
return a;
|
||||
}
|
||||
auto out_shape = a.shape();
|
||||
auto ibytes = size_of(a.dtype());
|
||||
auto obytes = size_of(dtype);
|
||||
if (a.ndim() == 0 && ibytes != obytes) {
|
||||
throw std::invalid_argument(
|
||||
"[view] Changing the type of a scalar is only allowed"
|
||||
" for types with the same size.");
|
||||
} else {
|
||||
if (ibytes < obytes) {
|
||||
if (out_shape.back() % (obytes / ibytes) != 0) {
|
||||
throw std::invalid_argument(
|
||||
"[view] When viewing as a larger dtype, the size in bytes of the last"
|
||||
" axis must be a multiple of the requested type size.");
|
||||
}
|
||||
out_shape.back() /= (obytes / ibytes);
|
||||
} else {
|
||||
// Type size ratios are always integers
|
||||
out_shape.back() *= (ibytes / obytes);
|
||||
}
|
||||
}
|
||||
return array(
|
||||
out_shape, dtype, std::make_shared<View>(to_stream(s), dtype), {a});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user