mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 00:31:12 +08:00
Add inner / outer op (#348)
* inner / outer impl * python tests * ops list and ack * updated descriptions * use test helper * removed dtype check and flatten outer to 1-D * updated docs * just use the reshape to flatten
This commit is contained in:
parent
6ea6b4258d
commit
449b43762e
@ -10,7 +10,7 @@ MLX was developed with contributions from the following individuals:
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops.
|
||||
- Juarez Bochi: Fixed bug in cross attention.
|
||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot` and safetensor support
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer` and safetensor support
|
||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
|
@ -49,6 +49,7 @@ Operations
|
||||
greater
|
||||
greater_equal
|
||||
identity
|
||||
inner
|
||||
less
|
||||
less_equal
|
||||
linspace
|
||||
@ -71,6 +72,7 @@ Operations
|
||||
negative
|
||||
ones
|
||||
ones_like
|
||||
outer
|
||||
partition
|
||||
pad
|
||||
prod
|
||||
|
21
mlx/ops.cpp
21
mlx/ops.cpp
@ -2848,10 +2848,6 @@ array tensordot(
|
||||
throw std::invalid_argument(
|
||||
"[tensordot] dims[0] and dims[1] must have the same number of dimensions.");
|
||||
}
|
||||
if (a.dtype() != b.dtype()) {
|
||||
throw std::invalid_argument(
|
||||
"[tensordot] a and b must have the same dtype.");
|
||||
}
|
||||
int csize = 1;
|
||||
auto x = a;
|
||||
auto y = b;
|
||||
@ -2905,4 +2901,21 @@ array tensordot(
|
||||
return reshape(matmul(x, y, s), rshape, s);
|
||||
}
|
||||
|
||||
array outer(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  // Outer product via broadcasting: shape `a` into a column vector
  // (size x 1) and flatten `b` to 1-D; elementwise multiply then
  // broadcasts to the (a.size, b.size) result.
  auto column = reshape(a, {static_cast<int>(a.size()), 1}, s);
  auto row = flatten(b, s);
  return multiply(column, row, s);
}
|
||||
|
||||
array inner(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  // Inner product: sum-product over the last axis of each input.
  // Non-scalar inputs must agree on that last dimension.
  if (a.ndim() > 0 && b.ndim() > 0) {
    if (a.shape(-1) != b.shape(-1)) {
      throw std::invalid_argument(
          "[inner] a and b must have the same last dimension.");
    }
    return tensordot(a, b, {{-1}, {-1}}, s);
  }
  // A scalar operand degenerates to an elementwise multiply.
  return multiply(a, b, s);
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1075,6 +1075,12 @@ array tensordot(
|
||||
const std::pair<std::vector<int>, std::vector<int>>& dims,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Compute the outer product of two vectors (non-1-D inputs are flattened). */
array outer(const array& a, const array& b, StreamOrDevice s = {});

/** Compute the inner product: a sum-product over the last axes of `a` and `b`. */
array inner(const array& a, const array& b, StreamOrDevice s = {});
|
||||
|
||||
/** Load array map from .safetensors file format */
|
||||
std::unordered_map<std::string, array> load_safetensors(
|
||||
std::shared_ptr<io::Reader> in_stream,
|
||||
|
@ -3250,4 +3250,46 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
result (array): The tensor dot product.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"inner",
|
||||
&inner,
|
||||
"a"_a,
|
||||
"b"_a,
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
inner(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Ordinary inner product of vectors for 1-D arrays, in higher dimensions a sum product over the last axes.
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
b (array): Input array
|
||||
|
||||
Returns:
|
||||
result (array): The inner product.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"outer",
|
||||
&outer,
|
||||
"a"_a,
|
||||
"b"_a,
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
outer(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
          Compute the outer product of two 1-D arrays; if the arrays passed are not 1-D, they are flattened beforehand.
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
b (array): Input array
|
||||
|
||||
Returns:
|
||||
result (array): The outer product.
|
||||
)pbdoc");
|
||||
}
|
||||
|
@ -1547,6 +1547,34 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
dims=([2, 1, 3], [1, 2, 0]),
|
||||
)
|
||||
|
||||
def test_inner(self):
    # Compare mx.inner against np.inner for 1-D and higher-rank shapes.
    shape_cases = [
        [(3,), (3,)],
        [(1, 1, 2), (3, 2)],
        [(2, 3, 4), (4,)],
    ]
    for shapes in shape_cases:
        self.assertCmpNumpy(shapes, mx.inner, np.inner)
|
||||
|
||||
def test_outer(self):
    # Shape-based comparison for the simple 1-D x 1-D case.
    self.assertCmpNumpy([(3,), (3,)], mx.outer, np.outer)
    # Explicit-array cases, including a complex-valued operand.
    array_cases = [
        [mx.ones(5), mx.linspace(-2, 2, 5)],
        [1j * mx.linspace(2, -2, 5), mx.ones(5)],
    ]
    for operands in array_cases:
        self.assertCmpNumpy(operands, mx.outer, np.outer)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -2315,3 +2315,47 @@ TEST_CASE("tensordot") {
|
||||
{3, 6});
|
||||
CHECK(array_equal(z, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("outer") {
  // 1-D x 1-D: (4,) outer (3,) -> (4, 3).
  auto lhs = arange(1.0, 5.0);
  auto rhs = arange(1.0, 4.0);
  auto result = outer(lhs, rhs);
  auto want = array(
      {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}, {4, 3});
  CHECK(array_equal(result, want).item<bool>());

  // Ones times a linspace replicates the linspace on every row.
  lhs = ones({5});
  rhs = linspace(-2., 2., 5);
  result = outer(lhs, rhs);
  want = array(
      {-2., -1., 0., 1., 2., -2., -1., 0., 1., 2., -2., -1., 0.,
       1., 2., -2., -1., 0., 1., 2., -2., -1., 0., 1., 2.},
      {5, 5});
  CHECK(array_equal(result, want).item<bool>());
}
|
||||
|
||||
TEST_CASE("inner") {
  // Mismatched last dimensions (5 vs 3) must be rejected.
  CHECK_THROWS_AS(
      inner(reshape(arange(5.), {1, 5}), reshape(arange(6.), {2, 3})),
      std::invalid_argument);

  // 1-D dot product: [1,2,3] . [0,1,0] == 2.
  auto lhs = array({1., 2., 3.});
  auto rhs = array({0., 1., 0.});
  auto result = inner(lhs, rhs);
  CHECK_EQ(result.item<float>(), 2.f);

  // (2,3,4) with (4,): contracts the trailing axis -> (2,3).
  lhs = reshape(arange(24.), {2, 3, 4});
  rhs = arange(4.);
  result = inner(lhs, rhs);
  auto want = array({14., 38., 62., 86., 110., 134.}, {2, 3});
  CHECK(array_equal(result, want).item<bool>());

  // (1,1,2) with (3,2) -> (1,1,3).
  lhs = reshape(arange(2.), {1, 1, 2});
  rhs = reshape(arange(6.), {3, 2});
  result = inner(lhs, rhs);
  want = array({1., 3., 5.}, {1, 1, 3});
  CHECK(array_equal(result, want).item<bool>());

  // Scalar operand scales elementwise: 7 * eye(2).
  result = inner(eye(2), array(7.));
  want = array({7., 0., 0., 7.}, {2, 2});
  CHECK(array_equal(result, want).item<bool>());
}
|
||||
|
Loading…
Reference in New Issue
Block a user