Add view op (#1179)

* add view primitive

* nit

* fix view
Awni Hannun 2024-06-04 08:05:27 -07:00 committed by GitHub
parent 81def6ac76
commit ea9090bbc4
14 changed files with 202 additions and 11 deletions

View File

@@ -156,6 +156,7 @@ Operations
tril
triu
var
view
where
zeros
zeros_like

View File

@@ -5,7 +5,6 @@
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"

View File

@@ -590,4 +590,36 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
}
}
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for buffer copying (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < strides.size() - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
copy_inplace(in, tmp, CopyType::General);
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core
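
The shared-buffer branch above reinterprets the existing allocation by rescaling every stride except the innermost by ibytes/obytes, and converts the element count the same way. A minimal Python sketch of that bookkeeping, assuming strides are measured in elements as in MLX (view_strides is an illustrative helper, not an MLX function):

def view_strides(in_strides, in_data_size, ibytes, obytes):
    # Mirror of the shared-buffer path: all strides but the last are scaled
    # by the ratio of the old and new element sizes.
    strides = list(in_strides)
    for i in range(len(strides) - 1):
        strides[i] = strides[i] * ibytes // obytes
    # Element count of the reinterpreted buffer.
    data_size = in_data_size * ibytes // obytes
    return strides, data_size

# Example: a row-contiguous (4, 2, 4) int32 array viewed as int16.
print(view_strides((8, 4, 1), 32, 4, 2))  # ([16, 8, 1], 64)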

View File

@@ -422,4 +422,35 @@ void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
"[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
}
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for buffer copying (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < strides.size() - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
copy_gpu_inplace(in, tmp, CopyType::General, stream());
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core
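
When the input is neither row contiguous nor contiguous along its last axis, the branch above first materializes a row-contiguous copy and only then reinterprets the bytes. A small Python illustration of when that fallback is taken, assuming only what the branch condition and the tests below state (round-trip equality, not buffer sharing):

import mlx.core as mx

# A transposed array: its last axis is no longer contiguous, so a view to a
# smaller dtype cannot just rescale strides and takes the copy path.
a = mx.random.randint(shape=(4, 4), low=-100, high=100).T
out = a.view(mx.int16)          # copy path: last-axis stride != 1
roundtrip = out.view(mx.int32)  # viewing back recovers the original values
print(mx.array_equal(roundtrip, a).item())  # True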

View File

@@ -106,5 +106,6 @@ NO_CPU(Tan)
NO_CPU(Tanh)
NO_CPU(Transpose)
NO_CPU(Inverse)
NO_CPU(View)
} // namespace mlx::core

View File

@@ -108,6 +108,7 @@ NO_GPU(Tanh)
NO_GPU(Transpose)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU(View)
namespace fast {
NO_GPU_MULTI(LayerNorm)

View File

@@ -4378,4 +4378,32 @@ array operator>>(const array& a, const array& b) {
return right_shift(a, b);
}
array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) {
if (a.dtype() == dtype) {
return a;
}
auto out_shape = a.shape();
auto ibytes = size_of(a.dtype());
auto obytes = size_of(dtype);
if (a.ndim() == 0 && ibytes != obytes) {
throw std::invalid_argument(
"[view] Changing the type of a scalar is only allowed"
" for types with the same size.");
} else {
if (ibytes < obytes) {
if (out_shape.back() % (obytes / ibytes) != 0) {
throw std::invalid_argument(
"[view] When viewing as a larger dtype, the size in bytes of the last"
" axis must be a multiple of the requested type size.");
}
out_shape.back() /= (obytes / ibytes);
} else {
// Type size ratios are always integers
out_shape.back() *= (ibytes / obytes);
}
}
return array(
out_shape, dtype, std::make_shared<View>(to_stream(s), dtype), {a});
}
} // namespace mlx::core
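
The shape rule only ever touches the last axis: viewing as a smaller dtype multiplies it by ibytes/obytes, viewing as a larger dtype divides it and requires the division to be exact. A quick Python illustration, assuming a little-endian layout as in the C++ test at the end of this diff:

import mlx.core as mx

a = mx.array([1, 2, 3, 4], dtype=mx.int64)  # shape (4,), 8-byte elements
small = a.view(mx.int32)                    # shape (8,), 4-byte elements
print(small)                                # values 1, 0, 2, 0, 3, 0, 4, 0 (little endian)
print(small.view(mx.int64).shape)           # back to (4,)

# Viewing as a larger dtype needs the last axis to divide evenly, so
# mx.array([1, 2, 3], dtype=mx.int32).view(mx.int64) raises an error.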

View File

@@ -1300,6 +1300,7 @@ array operator<<(const array& a, const array& b);
array right_shift(const array& a, const array& b, StreamOrDevice s = {});
array operator>>(const array& a, const array& b);
array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
/** @} */
} // namespace mlx::core

View File

@@ -3930,4 +3930,21 @@ std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
return {{linalg::inv(a, stream())}, {ax}};
}
std::pair<std::vector<array>, std::vector<int>> View::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {{view(inputs[0], dtype_, stream())}, axes};
}
void View::print(std::ostream& os) {
os << "View" << dtype_;
}
bool View::is_equivalent(const Primitive& other) const {
const View& a_other = static_cast<const View&>(other);
return (dtype_ == a_other.dtype_);
}
} // namespace mlx::core
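
Because view only rewrites the trailing axis, the vmap rule can pass the batch axes through unchanged. A short sketch of what this enables, with the output shape inferred from the rule above rather than taken from the tests:

import mlx.core as mx

a = mx.random.randint(shape=(3, 4, 2), low=-100, high=100)  # int32, batch of 3

# Each (4, 2) int32 slice is reinterpreted as (4, 4) int16; the mapped
# leading axis is untouched by the View primitive's vmap rule.
batched = mx.vmap(lambda x: x.view(mx.int16))(a)
print(batched.shape)  # (3, 4, 4)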

View File

@@ -2021,6 +2021,22 @@ class Uniform : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class View : public UnaryPrimitive {
public:
explicit View(Stream stream, Dtype dtype)
: UnaryPrimitive(stream), dtype_(dtype) {};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
void print(std::ostream& os) override;
bool is_equivalent(const Primitive& other) const override;
private:
Dtype dtype_;
};
class Transpose : public UnaryPrimitive {
public:
explicit Transpose(Stream stream, const std::vector<int>& axes)

View File

@@ -1276,5 +1276,14 @@ void init_array(nb::module_& m) {
},
nb::kw_only(),
"stream"_a = nb::none(),
"See :func:`conj`.");
"See :func:`conj`.")
.def(
"view",
[](const ScalarOrArray& a, const Dtype& dtype, StreamOrDevice s) {
return view(to_array(a), dtype, s);
},
"dtype"_a,
nb::kw_only(),
"stream"_a = nb::none(),
"See :func:`view`.");
}

View File

@@ -3159,8 +3159,6 @@ void init_ops(nb::module_& m) {
R"pbdoc(
1D convolution over an input with several channels
Note: Only the default ``groups=1`` is currently supported.
Args:
input (array): input array of shape (``N``, ``H``, ``C_in``)
weight (array): weight array of shape (``C_out``, ``H``, ``C_in``)
@@ -3219,8 +3217,6 @@ void init_ops(nb::module_& m) {
R"pbdoc(
2D convolution over an input with several channels
Note: Only the default ``groups=1`` is currently supported.
Args:
input (array): input array of shape ``(N, H, W, C_in)``
weight (array): weight array of shape ``(C_out, H, W, C_in)``
@@ -3390,11 +3386,6 @@ void init_ops(nb::module_& m) {
R"pbdoc(
General convolution over an input with several channels
.. note::
* Only 1d and 2d convolutions are supported at the moment
* the default ``groups=1`` is currently supported.
Args:
input (array): Input array of shape ``(N, ..., C_in)``
weight (array): Weight array of shape ``(C_out, ..., C_in)``
@@ -4356,4 +4347,32 @@ void init_ops(nb::module_& m) {
Returns:
array: The bitwise right shift ``a >> b``.
)pbdoc");
m.def(
"view",
[](const ScalarOrArray& a, const Dtype& dtype, StreamOrDevice s) {
return view(to_array(a), dtype, s);
},
nb::arg(),
"dtype"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def view(a: Union[scalar, array], dtype: Dtype, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
View the array as a different type.
The output shape changes along the last axis if the input array's
type and the input ``dtype`` do not have the same size.
Note: the view op does not imply that the input and output arrays share
their underlying data. The view only guarantees that the binary
representation of each element (or group of elements) is the same.
Args:
a (array): Input array or scalar.
dtype (Dtype): The data type to change to.
Returns:
array: The array with the new type.
)pbdoc");
}
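
A short usage example of the binding above, using NumPy's view as the reference just as the Python test below does; it shows nothing beyond what the docstring states:

import mlx.core as mx
import numpy as np

a = mx.random.randint(shape=(4, 2, 4), low=-100, high=100)  # int32
out = mx.view(a, mx.int16)                # free-function form
same = a.view(mx.int16)                   # method form from the array bindings
expected = np.array(a).view(np.int16)

print(out.shape)                                # last axis doubles: (4, 2, 8)
print(np.array_equal(np.array(out), expected))  # True
print(mx.array_equal(out, same).item())         # True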

View File

@@ -2333,6 +2333,30 @@ class TestOps(mlx_tests.MLXTestCase):
out_np = a.conj()
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
def test_view(self):
a = mx.random.randint(shape=(4, 2, 4), low=-100, high=100)
a_np = np.array(a)
for t in ["bool_", "int16", "float32", "int64"]:
out = a.view(getattr(mx, t))
expected = a_np.view(getattr(np, t))
self.assertTrue(np.array_equal(out, expected, equal_nan=True))
# Irregular strides
a = mx.random.randint(shape=(2, 4), low=-100, high=100)
a = mx.broadcast_to(a, shape=(4, 2, 4))
for t in ["bool_", "int16", "float32", "int64"]:
out = a.view(getattr(mx, t))
a_out = out.view(mx.int32)
self.assertTrue(mx.array_equal(a_out, a, equal_nan=True))
a = mx.random.randint(shape=(4, 4), low=-100, high=100).T
for t in ["bool_", "int16", "float32", "int64"]:
out = a.view(getattr(mx, t))
a_out = out.view(mx.int32)
self.assertTrue(mx.array_equal(a_out, a, equal_nan=True))
if __name__ == "__main__":
unittest.main()

View File

@@ -3473,3 +3473,15 @@ TEST_CASE("test trace") {
auto out4 = trace(in, 0, 1, 2, float32);
CHECK(array_equal(out4, array({3, 11}, {2})).item<bool>());
}
TEST_CASE("test view") {
auto in = array(3);
CHECK_THROWS(view(in, int64));
in = array({1, 2, 3});
CHECK_THROWS(view(in, int64));
in = array({1, 2, 3, 4}, int64);
auto out = view(in, int32);
CHECK(array_equal(out, array({1, 0, 2, 0, 3, 0, 4, 0})).item<bool>());
}