Add inner / outer op (#348)

* inner / outer impl

* python tests

* ops list and ack

* updated descriptions

* use test helper

* removed dtype check and flatten outer to 1-D

* updated docs

* just use the reshape to flatten
This commit is contained in:
Diogo 2024-01-07 12:01:09 -05:00 committed by GitHub
parent 6ea6b4258d
commit 449b43762e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 140 additions and 5 deletions

View File

@ -10,7 +10,7 @@ MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot` and safetensor support
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer` and safetensor support
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">

View File

@ -49,6 +49,7 @@ Operations
greater
greater_equal
identity
inner
less
less_equal
linspace
@ -71,6 +72,7 @@ Operations
negative
ones
ones_like
outer
partition
pad
prod

View File

@ -2848,10 +2848,6 @@ array tensordot(
throw std::invalid_argument(
"[tensordot] dims[0] and dims[1] must have the same number of dimensions.");
}
if (a.dtype() != b.dtype()) {
throw std::invalid_argument(
"[tensordot] a and b must have the same dtype.");
}
int csize = 1;
auto x = a;
auto y = b;
@ -2905,4 +2901,21 @@ array tensordot(
return reshape(matmul(x, y, s), rshape, s);
}
array outer(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  // View `a` as a column vector and `b` as a row vector, then let the
  // elementwise multiply broadcast them into the (a.size(), b.size()) grid.
  auto column = reshape(a, {static_cast<int>(a.size()), 1}, s);
  auto row = flatten(b, s);
  return multiply(column, row, s);
}
array inner(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  // A scalar operand degenerates to an elementwise multiply; this guard
  // must run before any shape(-1) access, which is invalid on 0-d arrays.
  if (a.ndim() == 0 || b.ndim() == 0) {
    return multiply(a, b, s);
  }
  if (a.shape(-1) != b.shape(-1)) {
    throw std::invalid_argument(
        "[inner] a and b must have the same last dimension.");
  }
  // Sum-product over the last axis of each operand.
  return tensordot(a, b, {{-1}, {-1}}, s);
}
} // namespace mlx::core

View File

@ -1075,6 +1075,12 @@ array tensordot(
const std::pair<std::vector<int>, std::vector<int>>& dims,
StreamOrDevice s = {});
/** Compute the outer product of two arrays, flattening inputs to 1-D first. */
array outer(const array& a, const array& b, StreamOrDevice s = {});
/** Compute the inner product of two vectors; higher-dimensional inputs contract over the last axes. */
array inner(const array& a, const array& b, StreamOrDevice s = {});
/** Load array map from .safetensors file format */
std::unordered_map<std::string, array> load_safetensors(
std::shared_ptr<io::Reader> in_stream,

View File

@ -3250,4 +3250,46 @@ void init_ops(py::module_& m) {
Returns:
result (array): The tensor dot product.
)pbdoc");
m.def(
"inner",
&inner,
"a"_a,
"b"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
inner(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
      Ordinary inner product of vectors for 1-D arrays; in higher dimensions, a sum product over the last axes.
Args:
a (array): Input array
b (array): Input array
Returns:
result (array): The inner product.
)pbdoc");
m.def(
"outer",
&outer,
"a"_a,
"b"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
outer(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
      Compute the outer product of two 1-D arrays; if the arrays passed are not 1-D, they will be flattened to 1-D beforehand.
Args:
a (array): Input array
b (array): Input array
Returns:
result (array): The outer product.
)pbdoc");
}

View File

@ -1547,6 +1547,34 @@ class TestOps(mlx_tests.MLXTestCase):
dims=([2, 1, 3], [1, 2, 0]),
)
def test_inner(self):
    # Shape pairs covering np.inner's 1-D case and sum-product over the
    # last axes for higher-dimensional operands.
    shape_pairs = [
        [(3,), (3,)],
        [(1, 1, 2), (3, 2)],
        [(2, 3, 4), (4,)],
    ]
    for shapes in shape_pairs:
        self.assertCmpNumpy(shapes, mx.inner, np.inner)
def test_outer(self):
    # 1-D x 1-D via shape spec.
    self.assertCmpNumpy([(3,), (3,)], mx.outer, np.outer)
    # Explicit operands: all-ones against a linspace.
    lhs = mx.ones((5,))
    rhs = mx.linspace(-2, 2, 5)
    self.assertCmpNumpy([lhs, rhs], mx.outer, np.outer)
    # Complex-valued left operand against all-ones.
    lhs_complex = 1j * mx.linspace(2, -2, 5)
    rhs_ones = mx.ones((5,))
    self.assertCmpNumpy([lhs_complex, rhs_ones], mx.outer, np.outer)
if __name__ == "__main__":
unittest.main()

View File

@ -2315,3 +2315,47 @@ TEST_CASE("tensordot") {
{3, 6});
CHECK(array_equal(z, expected).item<bool>());
}
TEST_CASE("outer") {
  // 4-vector outer 3-vector -> 4x3 multiplication table.
  {
    auto a = arange(1.0, 5.0);
    auto b = arange(1.0, 4.0);
    auto out = outer(a, b);
    auto want = array(
        {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}, {4, 3});
    CHECK(array_equal(out, want).item<bool>());
  }
  // All-ones against a linspace: every row repeats the linspace.
  {
    auto a = ones({5});
    auto b = linspace(-2., 2., 5);
    auto out = outer(a, b);
    auto want = array(
        {-2., -1., 0., 1., 2., -2., -1., 0., 1., 2., -2., -1., 0.,
         1.,  2.,  -2., -1., 0., 1., 2., -2., -1., 0., 1., 2.},
        {5, 5});
    CHECK(array_equal(out, want).item<bool>());
  }
}
TEST_CASE("inner") {
  // Mismatched trailing dimensions must be rejected.
  CHECK_THROWS_AS(
      inner(reshape(arange(5.), {1, 5}), reshape(arange(6.), {2, 3})),
      std::invalid_argument);

  // Plain 1-D dot product.
  {
    auto out = inner(array({1., 2., 3.}), array({0., 1., 0.}));
    CHECK_EQ(out.item<float>(), 2.f);
  }

  // (2,3,4) . (4,) contracts over the last axis of both operands.
  {
    auto out = inner(reshape(arange(24.), {2, 3, 4}), arange(4.));
    auto want = array({14., 38., 62., 86., 110., 134.}, {2, 3});
    CHECK(array_equal(out, want).item<bool>());
  }

  // (1,1,2) . (3,2) -> (1,1,3).
  {
    auto out =
        inner(reshape(arange(2.), {1, 1, 2}), reshape(arange(6.), {3, 2}));
    auto want = array({1., 3., 5.}, {1, 1, 3});
    CHECK(array_equal(out, want).item<bool>());
  }

  // A scalar operand degenerates to an elementwise multiply.
  {
    auto out = inner(eye(2), array(7.));
    auto want = array({7., 0., 0., 7.}, {2, 2});
    CHECK(array_equal(out, want).item<bool>());
  }
}