mirror of https://github.com/ml-explore/mlx.git
synced 2025-07-22 01:21:14 +08:00

Add inner / outer op (#348)

* inner / outer impl
* python tests
* ops list and ack
* updated descriptions
* use test helper
* removed dtype check and flatten outer to 1-D
* updated docs
* just use the reshape to flatten

This commit is contained in:
parent 6ea6b4258d
commit 449b43762e
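
In short: the commit adds NumPy-style inner and outer products to the C++ API and the Python bindings. A minimal usage sketch of the 1-D case (values illustrative; full semantics follow in the diffs below):

import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
b = mx.array([0.0, 1.0, 0.0])

print(mx.inner(a, b))  # scalar: 1*0 + 2*1 + 3*0 = 2
print(mx.outer(a, b))  # 3x3 matrix with entries a[i] * b[j]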
ACKNOWLEDGMENTS.md

@@ -10,7 +10,7 @@ MLX was developed with contributions from the following individuals:
 - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
-- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot` and safetensor support
+- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer` and safetensor support
 - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.

 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
docs/src/python/ops.rst

@@ -49,6 +49,7 @@ Operations
 greater
 greater_equal
 identity
+inner
 less
 less_equal
 linspace
@@ -71,6 +72,7 @@ Operations
 negative
 ones
 ones_like
+outer
 partition
 pad
 prod
mlx/ops.cpp (21 lines changed)
@@ -2848,10 +2848,6 @@ array tensordot(
     throw std::invalid_argument(
         "[tensordot] dims[0] and dims[1] must have the same number of dimensions.");
   }
-  if (a.dtype() != b.dtype()) {
-    throw std::invalid_argument(
-        "[tensordot] a and b must have the same dtype.");
-  }
   int csize = 1;
   auto x = a;
   auto y = b;
@@ -2905,4 +2901,21 @@ array tensordot(
   return reshape(matmul(x, y, s), rshape, s);
 }

+array outer(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  return multiply(
+      reshape(a, {static_cast<int>(a.size()), 1}, s), flatten(b, s), s);
+}
+
+array inner(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  if (a.ndim() == 0 || b.ndim() == 0) {
+    return multiply(a, b, s);
+  }
+  if (a.shape(-1) != b.shape(-1)) {
+    throw std::invalid_argument(
+        "[inner] a and b must have the same last dimension.");
+  }
+
+  return tensordot(a, b, {{-1}, {-1}}, s);
+}
+
 } // namespace mlx::core
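
As a sketch of what the two implementations compute, here is a hypothetical NumPy equivalent (outer_sketch and inner_sketch are illustrative names, not MLX code): outer reshapes a into an (a.size, 1) column and broadcasts a multiply against the flattened b, while inner contracts the last axis of each input via tensordot.

import numpy as np

def outer_sketch(a, b):
    # Column of shape (a.size, 1) times flattened b: broadcasting
    # yields the (a.size, b.size) outer product, matching np.outer.
    return a.reshape(a.size, 1) * b.reshape(-1)

def inner_sketch(a, b):
    if a.ndim == 0 or b.ndim == 0:
        return a * b  # a scalar operand degenerates to an elementwise multiply
    if a.shape[-1] != b.shape[-1]:
        raise ValueError("a and b must have the same last dimension")
    # Sum product over the last axes, as np.inner does.
    return np.tensordot(a, b, axes=([-1], [-1]))

a = np.arange(24.0).reshape(2, 3, 4)
b = np.arange(4.0)
assert np.allclose(inner_sketch(a, b), np.inner(a, b))
assert np.allclose(outer_sketch(a, b), np.outer(a, b))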
mlx/ops.h

@@ -1075,6 +1075,12 @@ array tensordot(
     const std::pair<std::vector<int>, std::vector<int>>& dims,
     StreamOrDevice s = {});

+/** Compute the outer product of two vectors. */
+array outer(const array& a, const array& b, StreamOrDevice s = {});
+
+/** Compute the inner product of two vectors. */
+array inner(const array& a, const array& b, StreamOrDevice s = {});
+
 /** Load array map from .safetensors file format */
 std::unordered_map<std::string, array> load_safetensors(
     std::shared_ptr<io::Reader> in_stream,
python/src/ops.cpp

@@ -3250,4 +3250,46 @@ void init_ops(py::module_& m) {
       Returns:
           result (array): The tensor dot product.
     )pbdoc");
+
+  m.def(
+      "inner",
+      &inner,
+      "a"_a,
+      "b"_a,
+      py::pos_only(),
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        inner(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
+
+        Ordinary inner product of vectors for 1-D arrays; in higher dimensions, a sum product over the last axes.
+
+        Args:
+            a (array): Input array
+            b (array): Input array
+
+        Returns:
+            result (array): The inner product.
+      )pbdoc");
+
+  m.def(
+      "outer",
+      &outer,
+      "a"_a,
+      "b"_a,
+      py::pos_only(),
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        outer(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
+
+        Compute the outer product of two 1-D arrays. If the arrays passed are not 1-D, a flatten op will be run beforehand.
+
+        Args:
+            a (array): Input array
+            b (array): Input array
+
+        Returns:
+            result (array): The outer product.
+      )pbdoc");
 }
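
A short usage sketch of the two new bindings as documented above (shapes mirror the tests that follow):

import mlx.core as mx

# inner: sum product over the last axes in higher dimensions.
x = mx.reshape(mx.arange(24.0), (2, 3, 4))
y = mx.arange(4.0)
z = mx.inner(x, y)  # shape (2, 3)

# outer: inputs that are not 1-D are flattened first.
a = mx.ones((2, 2))
b = mx.arange(3.0)
w = mx.outer(a, b)  # shape (4, 3)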
python/tests/test_ops.py

@@ -1547,6 +1547,34 @@ class TestOps(mlx_tests.MLXTestCase):
             dims=([2, 1, 3], [1, 2, 0]),
         )

+    def test_inner(self):
+        self.assertCmpNumpy([(3,), (3,)], mx.inner, np.inner)
+        self.assertCmpNumpy([(1, 1, 2), (3, 2)], mx.inner, np.inner)
+        self.assertCmpNumpy([(2, 3, 4), (4,)], mx.inner, np.inner)
+
+    def test_outer(self):
+        self.assertCmpNumpy([(3,), (3,)], mx.outer, np.outer)
+        self.assertCmpNumpy(
+            [
+                mx.ones(
+                    5,
+                ),
+                mx.linspace(-2, 2, 5),
+            ],
+            mx.outer,
+            np.outer,
+        )
+        self.assertCmpNumpy(
+            [
+                1j * mx.linspace(2, -2, 5),
+                mx.ones(
+                    5,
+                ),
+            ],
+            mx.outer,
+            np.outer,
+        )


 if __name__ == "__main__":
     unittest.main()
tests/ops_tests.cpp

@@ -2315,3 +2315,47 @@ TEST_CASE("tensordot") {
       {3, 6});
   CHECK(array_equal(z, expected).item<bool>());
 }

+TEST_CASE("outer") {
+  auto x = arange(1.0, 5.0);
+  auto y = arange(1.0, 4.0);
+  auto z = outer(x, y);
+  auto expected = array(
+      {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}, {4, 3});
+  CHECK(array_equal(z, expected).item<bool>());
+
+  x = ones({5});
+  y = linspace(-2., 2., 5);
+  z = outer(x, y);
+  expected = array(
+      {-2., -1., 0., 1., 2., -2., -1., 0., 1., 2., -2., -1., 0.,
+       1.,  2., -2., -1., 0., 1., 2., -2., -1., 0., 1., 2.},
+      {5, 5});
+  CHECK(array_equal(z, expected).item<bool>());
+}
+
+TEST_CASE("inner") {
+  CHECK_THROWS_AS(
+      inner(reshape(arange(5.), {1, 5}), reshape(arange(6.), {2, 3})),
+      std::invalid_argument);
+  auto x = array({1., 2., 3.});
+  auto y = array({0., 1., 0.});
+  auto z = inner(x, y);
+  CHECK_EQ(z.item<float>(), 2.f);
+
+  x = reshape(arange(24.), {2, 3, 4});
+  y = arange(4.);
+  z = inner(x, y);
+  auto expected = array({14., 38., 62., 86., 110., 134.}, {2, 3});
+  CHECK(array_equal(z, expected).item<bool>());
+
+  x = reshape(arange(2.), {1, 1, 2});
+  y = reshape(arange(6.), {3, 2});
+  z = inner(x, y);
+  expected = array({1., 3., 5.}, {1, 1, 3});
+  CHECK(array_equal(z, expected).item<bool>());
+
+  z = inner(eye(2), array(7.));
+  expected = array({7., 0., 0., 7.}, {2, 2});
+  CHECK(array_equal(z, expected).item<bool>());
+}