From ea9090bbc43bb0e58a792be8b43638d68743f178 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 4 Jun 2024 08:05:27 -0700
Subject: [PATCH] Add view op (#1179)

* add view primitive

* nit

* fix view
---
 docs/src/python/ops.rst                   |  1 +
 mlx/backend/common/default_primitives.cpp |  1 -
 mlx/backend/common/primitives.cpp         | 32 ++++++++++++++++++++
 mlx/backend/metal/primitives.cpp          | 31 +++++++++++++++++++
 mlx/backend/no_cpu/primitives.cpp         |  1 +
 mlx/backend/no_metal/primitives.cpp       |  1 +
 mlx/ops.cpp                               | 28 +++++++++++++++++
 mlx/ops.h                                 |  1 +
 mlx/primitives.cpp                        | 17 +++++++++++
 mlx/primitives.h                          | 16 ++++++++++
 python/src/array.cpp                      | 11 ++++++-
 python/src/ops.cpp                        | 37 +++++++++++++++++------
 python/tests/test_ops.py                  | 24 +++++++++++++++
 tests/ops_tests.cpp                       | 12 ++++++++
 14 files changed, 202 insertions(+), 11 deletions(-)

diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index 745d12d50..0b50ea244 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -156,6 +156,7 @@ Operations
    tril
    triu
    var
+   view
    where
    zeros
    zeros_like
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index 6c9648461..5164d9579 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -5,7 +5,6 @@
 #else
 #include <cblas.h>
 #endif
-
 #include <cstring>
 
 #include "mlx/array.h"
diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp
index 1d1f66ce9..29c4dcfd0 100644
--- a/mlx/backend/common/primitives.cpp
+++ b/mlx/backend/common/primitives.cpp
@@ -590,4 +590,36 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
+void View::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  auto ibytes = size_of(in.dtype());
+  auto obytes = size_of(out.dtype());
+  // Conditions for buffer copying (disjunction):
+  // - type size is the same
+  // - type size is smaller and the last axis is contiguous
+  // - the entire array is row contiguous
+  if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
+      in.flags().row_contiguous) {
+    auto strides = in.strides();
+    for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
+      strides[i] *= ibytes;
+      strides[i] /= obytes;
+    }
+    out.copy_shared_buffer(
+        in, strides, in.flags(), in.data_size() * ibytes / obytes);
+  } else {
+    auto tmp = array(in.shape(), in.dtype(), nullptr, {});
+    tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
+    copy_inplace(in, tmp, CopyType::General);
+
+    auto flags = out.flags();
+    flags.contiguous = true;
+    flags.row_contiguous = true;
+    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
+    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
+    out.move_shared_buffer(tmp, out.strides(), flags, out.size());
+  }
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 8d73edc48..3e249a1d4 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -422,4 +422,35 @@ void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
       "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
 }
 
+void View::eval_gpu(const std::vector<array>& inputs, array& out) {
+  auto& in = inputs[0];
+  auto ibytes = size_of(in.dtype());
+  auto obytes = size_of(out.dtype());
+  // Conditions for buffer copying (disjunction):
+  // - type size is the same
+  // - type size is smaller and the last axis is contiguous
+  // - the entire array is row contiguous
+  if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
+      in.flags().row_contiguous) {
+    auto strides = in.strides();
+    for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
+      strides[i] *= ibytes;
+      strides[i] /= obytes;
+    }
+    out.copy_shared_buffer(
+        in, strides, in.flags(), in.data_size() * ibytes / obytes);
+  } else {
+    auto tmp = array(in.shape(), in.dtype(), nullptr, {});
+    tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
+    copy_gpu_inplace(in, tmp, CopyType::General, stream());
+
+    auto flags = out.flags();
+    flags.contiguous = true;
+    flags.row_contiguous = true;
+    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
+    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
+    out.move_shared_buffer(tmp, out.strides(), flags, out.size());
+  }
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp
index e5e733555..f652a556c 100644
--- a/mlx/backend/no_cpu/primitives.cpp
+++ b/mlx/backend/no_cpu/primitives.cpp
@@ -106,5 +106,6 @@ NO_CPU(Tan)
 NO_CPU(Tanh)
 NO_CPU(Transpose)
 NO_CPU(Inverse)
+NO_CPU(View)
 
 } // namespace mlx::core
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 3398a023b..a5dd87369 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -108,6 +108,7 @@ NO_GPU(Tanh)
 NO_GPU(Transpose)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
+NO_GPU(View)
 
 namespace fast {
 NO_GPU_MULTI(LayerNorm)
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index f0d155c96..d38b8e6f7 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -4378,4 +4378,32 @@ array operator>>(const array& a, const array& b) {
   return right_shift(a, b);
 }
 
+array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) {
+  if (a.dtype() == dtype) {
+    return a;
+  }
+  auto out_shape = a.shape();
+  auto ibytes = size_of(a.dtype());
+  auto obytes = size_of(dtype);
+  if (a.ndim() == 0 && ibytes != obytes) {
+    throw std::invalid_argument(
+        "[view] Changing the type of a scalar is only allowed"
+        " for types with the same size.");
+  } else {
+    if (ibytes < obytes) {
+      if (out_shape.back() % (obytes / ibytes) != 0) {
+        throw std::invalid_argument(
+            "[view] When viewing as a larger dtype, the size in bytes of the last"
+            " axis must be a multiple of the requested type size.");
+      }
+      out_shape.back() /= (obytes / ibytes);
+    } else {
+      // Type size ratios are always integers
+      out_shape.back() *= (ibytes / obytes);
+    }
+  }
+  return array(
+      out_shape, dtype, std::make_shared<View>(to_stream(s), dtype), {a});
+}
+
 } // namespace mlx::core
diff --git a/mlx/ops.h b/mlx/ops.h
index 17bc11255..934edf619 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1300,6 +1300,7 @@ array operator<<(const array& a, const array& b);
 array right_shift(const array& a, const array& b, StreamOrDevice s = {});
 array operator>>(const array& a, const array& b);
 
+array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
 /** @} */
 
 } // namespace mlx::core
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 21cdcb4e9..a4fdc26f6 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -3930,4 +3930,21 @@ std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
   return {{linalg::inv(a, stream())}, {ax}};
 }
 
+std::pair<std::vector<array>, std::vector<int>> View::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 1);
+  assert(axes.size() == 1);
+  return {{view(inputs[0], dtype_, stream())}, axes};
+}
+
+void View::print(std::ostream& os) {
+  os << "View" << dtype_;
+}
+
+bool View::is_equivalent(const Primitive& other) const {
+  const View& a_other = static_cast<const View&>(other);
+  return (dtype_ == a_other.dtype_);
+}
+
 } // namespace mlx::core
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 517ce0b29..88f5def4b 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -2021,6 +2021,22 @@ class Uniform : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class View : public UnaryPrimitive {
+ public:
+  explicit View(Stream stream, Dtype dtype)
+      : UnaryPrimitive(stream), dtype_(dtype) {};
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  void print(std::ostream& os) override;
+  bool is_equivalent(const Primitive& other) const override;
+
+ private:
+  Dtype dtype_;
+};
+
 class Transpose : public UnaryPrimitive {
  public:
   explicit Transpose(Stream stream, const std::vector<int>& axes)
diff --git a/python/src/array.cpp b/python/src/array.cpp
index 15825a926..da70a114e 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -1276,5 +1276,14 @@ void init_array(nb::module_& m) {
       },
       nb::kw_only(),
       "stream"_a = nb::none(),
-      "See :func:`conj`.");
+      "See :func:`conj`.")
+      .def(
+          "view",
+          [](const ScalarOrArray& a, const Dtype& dtype, StreamOrDevice s) {
+            return view(to_array(a), dtype, s);
+          },
+          "dtype"_a,
+          nb::kw_only(),
+          "stream"_a = nb::none(),
+          "See :func:`view`.");
 }
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index b919faed6..8bfe4e124 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -3159,8 +3159,6 @@ void init_ops(nb::module_& m) {
       R"pbdoc(
         1D convolution over an input with several channels
 
-        Note: Only the default ``groups=1`` is currently supported.
-
         Args:
             input (array): input array of shape (``N``, ``H``, ``C_in``)
             weight (array): weight array of shape (``C_out``, ``H``, ``C_in``)
@@ -3219,8 +3217,6 @@ void init_ops(nb::module_& m) {
       R"pbdoc(
         2D convolution over an input with several channels
 
-        Note: Only the default ``groups=1`` is currently supported.
-
         Args:
             input (array): input array of shape ``(N, H, W, C_in)``
             weight (array): weight array of shape ``(C_out, H, W, C_in)``
@@ -3390,11 +3386,6 @@ void init_ops(nb::module_& m) {
       R"pbdoc(
         General convolution over an input with several channels
 
-        .. note::
-
-          * Only 1d and 2d convolutions are supported at the moment
-          * the default ``groups=1`` is currently supported.
-
         Args:
             input (array): Input array of shape ``(N, ..., C_in)``
             weight (array): Weight array of shape ``(C_out, ..., C_in)``
@@ -4356,4 +4347,32 @@ void init_ops(nb::module_& m) {
       Returns:
         array: The bitwise right shift ``a >> b``.
       )pbdoc");
+  m.def(
+      "view",
+      [](const ScalarOrArray& a, const Dtype& dtype, StreamOrDevice s) {
+        return view(to_array(a), dtype, s);
+      },
+      nb::arg(),
+      "dtype"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def view(a: Union[scalar, array], dtype: Dtype, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        View the array as a different type.
+
+        The output shape changes along the last axis if the input array's
+        type and the input ``dtype`` do not have the same size.
+
+        Note: the view op does not imply that the input and output arrays
+        share their underlying data. The view only guarantees that the binary
+        representation of each element (or group of elements) is the same.
+
+        Args:
+            a (array): Input array or scalar.
+            dtype (Dtype): The data type to change to.
+
+        Returns:
+            array: The array with the new type.
+ )pbdoc"); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 8ee2412bf..209ddcae2 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2333,6 +2333,30 @@ class TestOps(mlx_tests.MLXTestCase): out_np = a.conj() self.assertTrue(np.array_equal(np.array(out_mlx), out_np)) + def test_view(self): + a = mx.random.randint(shape=(4, 2, 4), low=-100, high=100) + a_np = np.array(a) + + for t in ["bool_", "int16", "float32", "int64"]: + out = a.view(getattr(mx, t)) + expected = a_np.view(getattr(np, t)) + self.assertTrue(np.array_equal(out, expected, equal_nan=True)) + + # Irregular strides + a = mx.random.randint(shape=(2, 4), low=-100, high=100) + a = mx.broadcast_to(a, shape=(4, 2, 4)) + + for t in ["bool_", "int16", "float32", "int64"]: + out = a.view(getattr(mx, t)) + a_out = out.view(mx.int32) + self.assertTrue(mx.array_equal(a_out, a, equal_nan=True)) + + a = mx.random.randint(shape=(4, 4), low=-100, high=100).T + for t in ["bool_", "int16", "float32", "int64"]: + out = a.view(getattr(mx, t)) + a_out = out.view(mx.int32) + self.assertTrue(mx.array_equal(a_out, a, equal_nan=True)) + if __name__ == "__main__": unittest.main() diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 7aa3a3450..f93341f96 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -3473,3 +3473,15 @@ TEST_CASE("test trace") { auto out4 = trace(in, 0, 1, 2, float32); CHECK(array_equal(out4, array({3, 11}, {2})).item()); } + +TEST_CASE("test view") { + auto in = array(3); + CHECK_THROWS(view(in, int64)); + + in = array({1, 2, 3}); + CHECK_THROWS(view(in, int64)); + + in = array({1, 2, 3, 4}, int64); + auto out = view(in, int32); + CHECK(array_equal(out, array({1, 0, 2, 0, 3, 0, 4, 0})).item()); +}