From 449b43762e3f970576f054e54066123c0f37246e Mon Sep 17 00:00:00 2001
From: Diogo
Date: Sun, 7 Jan 2024 12:01:09 -0500
Subject: [PATCH] Add inner / outer op (#348)

* inner / outer impl

* python tests

* ops list and ack

* updated descriptions

* use test helper

* removed dtype check and flatten outer to 1-D

* updated docs

* just use the reshape to flatten
---
 ACKNOWLEDGMENTS.md       |  2 +-
 docs/src/python/ops.rst  |  2 ++
 mlx/ops.cpp              | 21 +++++++++++++++----
 mlx/ops.h                |  6 ++++++
 python/src/ops.cpp       | 42 ++++++++++++++++++++++++++++++++++++++
 python/tests/test_ops.py | 28 +++++++++++++++++++++++++
 tests/ops_tests.cpp      | 44 ++++++++++++++++++++++++++++++++++++++++
 7 files changed, 140 insertions(+), 5 deletions(-)

diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index b4108657f..dce0ddc81 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -10,7 +10,7 @@ MLX was developed with contributions from the following individuals:
 - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
-- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot` and safetensor support
+- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer` and safetensor support
 - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index 4e399524e..ffeee71da 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -49,6 +49,7 @@ Operations
    greater
    greater_equal
    identity
+   inner
    less
    less_equal
    linspace
@@ -71,6 +72,7 @@ Operations
    negative
    ones
    ones_like
+   outer
    partition
    pad
    prod
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 2d9410d94..aff0d46ed 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -2848,10 +2848,6 @@ array tensordot(
     throw std::invalid_argument(
         "[tensordot] dims[0] and dims[1] must have the same number of dimensions.");
   }
-  if (a.dtype() != b.dtype()) {
-    throw std::invalid_argument(
-        "[tensordot] a and b must have the same dtype.");
-  }
   int csize = 1;
   auto x = a;
   auto y = b;
@@ -2905,4 +2901,21 @@ array tensordot(
   return reshape(matmul(x, y, s), rshape, s);
 }

+array outer(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  return multiply(
+      reshape(a, {static_cast<int>(a.size()), 1}, s), flatten(b, s), s);
+}
+
+array inner(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  if (a.ndim() == 0 || b.ndim() == 0) {
+    return multiply(a, b, s);
+  }
+  if (a.shape(-1) != b.shape(-1)) {
+    throw std::invalid_argument(
+        "[inner] a and b must have the same last dimension.");
+  }
+
+  return tensordot(a, b, {{-1}, {-1}}, s);
+}
+
 } // namespace mlx::core
diff --git a/mlx/ops.h b/mlx/ops.h
index 0f7b52da4..31c2eb905 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1075,6 +1075,12 @@ array tensordot(
     const std::pair<std::vector<int>, std::vector<int>>& dims,
     StreamOrDevice s = {});

+/** Compute the outer product of two vectors. */
+array outer(const array& a, const array& b, StreamOrDevice s = {});
+
+/** Compute the inner product of two vectors. */
+array inner(const array& a, const array& b, StreamOrDevice s = {});
+
 /** Load array map from .safetensors file format */
 std::unordered_map<std::string, array> load_safetensors(
     std::shared_ptr<io::Reader> in_stream,
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 15c8bf69d..f41049b82 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -3250,4 +3250,46 @@ void init_ops(py::module_& m) {
       Returns:
         result (array): The tensor dot product.
       )pbdoc");
+
+  m.def(
+      "inner",
+      &inner,
+      "a"_a,
+      "b"_a,
+      py::pos_only(),
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        inner(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
+
+        Ordinary inner product of vectors for 1-D arrays; in higher dimensions, a sum product over the last axes.
+
+        Args:
+            a (array): Input array
+            b (array): Input array
+
+        Returns:
+            result (array): The inner product.
+      )pbdoc");
+
+  m.def(
+      "outer",
+      &outer,
+      "a"_a,
+      "b"_a,
+      py::pos_only(),
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        outer(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
+
+        Compute the outer product of two 1-D arrays; if the arrays passed are not 1-D, they are flattened beforehand.
+
+        Args:
+            a (array): Input array
+            b (array): Input array
+
+        Returns:
+            result (array): The outer product.
+      )pbdoc");
 }
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 777a23cbe..393c8dc9e 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1547,6 +1547,34 @@ class TestOps(mlx_tests.MLXTestCase):
             dims=([2, 1, 3], [1, 2, 0]),
         )

+    def test_inner(self):
+        self.assertCmpNumpy([(3,), (3,)], mx.inner, np.inner)
+        self.assertCmpNumpy([(1, 1, 2), (3, 2)], mx.inner, np.inner)
+        self.assertCmpNumpy([(2, 3, 4), (4,)], mx.inner, np.inner)
+
+    def test_outer(self):
+        self.assertCmpNumpy([(3,), (3,)], mx.outer, np.outer)
+        self.assertCmpNumpy(
+            [
+                mx.ones(
+                    5,
+                ),
+                mx.linspace(-2, 2, 5),
+            ],
+            mx.outer,
+            np.outer,
+        )
+        self.assertCmpNumpy(
+            [
+                1j * mx.linspace(2, -2, 5),
+                mx.ones(
+                    5,
+                ),
+            ],
+            mx.outer,
+            np.outer,
+        )
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 0a1dea0ad..d34497d29 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -2315,3 +2315,47 @@ TEST_CASE("tensordot") {
       {3, 6});
   CHECK(array_equal(z, expected).item<bool>());
 }
+
+TEST_CASE("outer") {
+  auto x = arange(1.0, 5.0);
+  auto y = arange(1.0, 4.0);
+  auto z = outer(x, y);
+  auto expected = array(
+      {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}, {4, 3});
+  CHECK(array_equal(z, expected).item<bool>());
+
+  x = ones({5});
+  y = linspace(-2., 2., 5);
+  z = outer(x, y);
+  expected = array(
+      {-2., -1., 0., 1., 2., -2., -1., 0., 1., 2., -2., -1., 0.,
+       1.,  2.,  -2., -1., 0., 1., 2., -2., -1., 0., 1., 2.},
+      {5, 5});
+  CHECK(array_equal(z, expected).item<bool>());
+}
+
+TEST_CASE("inner") {
+  CHECK_THROWS_AS(
+      inner(reshape(arange(5.), {1, 5}), reshape(arange(6.), {2, 3})),
+      std::invalid_argument);
+  auto x = array({1., 2., 3.});
+  auto y = array({0., 1., 0.});
+  auto z = inner(x, y);
+  CHECK_EQ(z.item<float>(), 2.f);
+
+  x = reshape(arange(24.), {2, 3, 4});
+  y = arange(4.);
+  z = inner(x, y);
+  auto expected = array({14., 38., 62., 86., 110., 134.}, {2, 3});
+  CHECK(array_equal(z, expected).item<bool>());
+
+  x = reshape(arange(2.), {1, 1, 2});
+  y = reshape(arange(6.), {3, 2});
+  z = inner(x, y);
+  expected = array({1., 3., 5.}, {1, 1, 3});
+  CHECK(array_equal(z, expected).item<bool>());
+
+  z = inner(eye(2), array(7.));
+  expected = array({7., 0., 0., 7.}, {2, 2});
+  CHECK(array_equal(z, expected).item<bool>());
+}
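
As a quick illustration of the semantics exercised by the tests above (a minimal Python sketch, not part of the patch; it assumes the new ops are exposed as `mx.inner` / `mx.outer` and that the standard `mlx.core` helpers `mx.array`, `mx.arange`, `mx.ones`, and `.reshape` behave as in NumPy):

    # Usage sketch; assumes the ops added by this patch are available.
    import mlx.core as mx

    # 1-D inputs: ordinary dot product, mirroring the C++ "inner" test.
    a = mx.array([1.0, 2.0, 3.0])
    b = mx.array([0.0, 1.0, 0.0])
    print(mx.inner(a, b))            # dot product: 2

    # Higher-dimensional inputs: sum product over the last axes,
    # so shapes (2, 3, 4) and (4,) give a (2, 3) result.
    x = mx.arange(24.0).reshape(2, 3, 4)
    y = mx.arange(4.0)
    print(mx.inner(x, y).shape)      # shape (2, 3)

    # outer flattens non-1-D inputs first: shapes (2, 2) and (3,) give (4, 3).
    u = mx.ones((2, 2))
    v = mx.arange(3.0)
    print(mx.outer(u, v).shape)      # shape (4, 3)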