mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:24:49 +08:00
Add Tensordot op (#344)
This commit is contained in:
parent
af66a09bde
commit
0782a4573a
@ -10,7 +10,7 @@ MLX was developed with contributions from the following individuals:
|
|||||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops.
|
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops.
|
||||||
- Juarez Bochi: Fixed bug in cross attention.
|
- Juarez Bochi: Fixed bug in cross attention.
|
||||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||||
- Diogo Da Cruz: Added tri, tril, triu and safetensor support
|
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot` and safetensor support
|
||||||
|
|
||||||
# Third-Party Software
|
# Third-Party Software
|
||||||
|
|
||||||
|
@ -104,6 +104,7 @@ Operations
|
|||||||
take_along_axis
|
take_along_axis
|
||||||
tan
|
tan
|
||||||
tanh
|
tanh
|
||||||
|
tensordot
|
||||||
transpose
|
transpose
|
||||||
tri
|
tri
|
||||||
tril
|
tril
|
||||||
|
90
mlx/ops.cpp
90
mlx/ops.cpp
@ -2793,4 +2793,94 @@ array dequantize(
|
|||||||
return w_full;
|
return w_full;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array tensordot(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
const int dims /* = 2 */,
|
||||||
|
StreamOrDevice s /* = {} */
|
||||||
|
) {
|
||||||
|
if (dims < 0) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[tensordot] dims must be greater or equal to 0.");
|
||||||
|
}
|
||||||
|
if (dims > std::min(a.ndim(), b.ndim())) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[tensordot] dims must be less than the number of dimensions of a and b.");
|
||||||
|
}
|
||||||
|
std::vector<int> adims;
|
||||||
|
std::vector<int> bdims;
|
||||||
|
for (int i = 0; i < dims; i++) {
|
||||||
|
bdims.emplace_back(i);
|
||||||
|
adims.emplace_back(-dims + i);
|
||||||
|
}
|
||||||
|
return tensordot(a, b, {adims, bdims}, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array tensordot(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
const std::pair<std::vector<int>, std::vector<int>>& dims,
|
||||||
|
StreamOrDevice s /* = {} */
|
||||||
|
) {
|
||||||
|
if (dims.first.size() != dims.second.size()) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[tensordot] dims[0] and dims[1] must have the same number of dimensions.");
|
||||||
|
}
|
||||||
|
if (a.dtype() != b.dtype()) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[tensordot] a and b must have the same dtype.");
|
||||||
|
}
|
||||||
|
int csize = 1;
|
||||||
|
auto x = a;
|
||||||
|
auto y = b;
|
||||||
|
for (int i = 0; i < dims.first.size(); i++) {
|
||||||
|
if (x.shape(dims.first.at(i)) == y.shape(dims.second.at(i))) {
|
||||||
|
csize *= x.shape(dims.first.at(i));
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[tensordot] a and b must have the same shape on the contracted axes.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<bool> cdims1(x.ndim(), false);
|
||||||
|
std::vector<bool> cdims2(y.ndim(), false);
|
||||||
|
for (const auto n : dims.first) {
|
||||||
|
int n_ = (n < 0) ? n + x.ndim() : n;
|
||||||
|
cdims1[n_] = true;
|
||||||
|
}
|
||||||
|
for (const auto n : dims.second) {
|
||||||
|
int n_ = (n < 0) ? n + y.ndim() : n;
|
||||||
|
cdims2[n_] = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> t1;
|
||||||
|
std::vector<int> t2;
|
||||||
|
std::vector<int> rshape;
|
||||||
|
int size1 = 1;
|
||||||
|
int size2 = 1;
|
||||||
|
for (int i = 0; i < a.ndim(); i++) {
|
||||||
|
if (!cdims1[i]) {
|
||||||
|
t1.emplace_back(i);
|
||||||
|
size1 *= a.shape(i);
|
||||||
|
rshape.emplace_back(a.shape(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const auto x : dims.first) {
|
||||||
|
t1.emplace_back(x);
|
||||||
|
}
|
||||||
|
for (const auto x : dims.second) {
|
||||||
|
t2.emplace_back(x);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < b.ndim(); i++) {
|
||||||
|
if (!cdims2[i]) {
|
||||||
|
t2.emplace_back(i);
|
||||||
|
size2 *= b.shape(i);
|
||||||
|
rshape.emplace_back(b.shape(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
x = reshape(transpose(x, t1, s), {size1, csize}, s);
|
||||||
|
y = reshape(transpose(y, t2, s), {csize, size2}, s);
|
||||||
|
return reshape(matmul(x, y, s), rshape, s);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
13
mlx/ops.h
13
mlx/ops.h
@ -1061,6 +1061,19 @@ array dequantize(
|
|||||||
int bits = 4,
|
int bits = 4,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** TensorDot returns a contraction of a and b over multiple dimensions. */
|
||||||
|
array tensordot(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
const int dims = 2,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
array tensordot(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
const std::pair<std::vector<int>, std::vector<int>>& dims,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Load array map from .safetensors file format */
|
/** Load array map from .safetensors file format */
|
||||||
std::unordered_map<std::string, array> load_safetensors(
|
std::unordered_map<std::string, array> load_safetensors(
|
||||||
std::shared_ptr<io::Reader> in_stream,
|
std::shared_ptr<io::Reader> in_stream,
|
||||||
|
@ -3194,4 +3194,44 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
result (array): The dequantized version of ``w``
|
result (array): The dequantized version of ``w``
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"tensordot",
|
||||||
|
[](const array& a,
|
||||||
|
const array& b,
|
||||||
|
const std::variant<int, std::vector<std::vector<int>>>& dims,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
if (auto pv = std::get_if<int>(&dims); pv) {
|
||||||
|
return tensordot(a, b, *pv, s);
|
||||||
|
} else {
|
||||||
|
auto x = std::get<std::vector<std::vector<int>>>(dims);
|
||||||
|
if (x.size() != 2) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[tensordot] dims must be a list of two lists.");
|
||||||
|
}
|
||||||
|
return tensordot(a, b, {x[0], x[1]}, s);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
"b"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
"dims"_a = 2,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
tensordot(a: array, b: array, /, dims: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Compute the tensor dot product along the specified axes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array
|
||||||
|
b (array): Input array
|
||||||
|
dims (int or list(list(int)), optional): The number of dimensions to
|
||||||
|
sum over. If an integer is provided, then sum over the last
|
||||||
|
``dims`` dimensions of ``a`` and the first ``dims`` dimensions of
|
||||||
|
``b``. If a list of lists is provided, then sum over the
|
||||||
|
corresponding dimensions of ``a`` and ``b``. (default: 2)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
result (array): The tensor dot product.
|
||||||
|
)pbdoc");
|
||||||
}
|
}
|
||||||
|
@ -1547,6 +1547,22 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
expected_3 = np.repeat(data_3, 2, axis=0)
|
expected_3 = np.repeat(data_3, 2, axis=0)
|
||||||
self.assertEqualArray(repeat_3, mx.array(expected_3))
|
self.assertEqualArray(repeat_3, mx.array(expected_3))
|
||||||
|
|
||||||
|
def test_tensordot(self):
|
||||||
|
x = mx.arange(60.0).reshape(3, 4, 5)
|
||||||
|
y = mx.arange(24.0).reshape(4, 3, 2)
|
||||||
|
z = mx.tensordot(x, y, dims=([1, 0], [0, 1]))
|
||||||
|
self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=([1, 0], [0, 1]))))
|
||||||
|
x = mx.random.normal((3, 4, 5))
|
||||||
|
y = mx.random.normal((4, 5, 6))
|
||||||
|
z = mx.tensordot(x, y, dims=2)
|
||||||
|
self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=2)))
|
||||||
|
x = mx.random.normal((3, 5, 4, 6))
|
||||||
|
y = mx.random.normal((6, 4, 5, 3))
|
||||||
|
z = mx.tensordot(x, y, dims=([2, 1, 3], [1, 2, 0]))
|
||||||
|
self.assertEqualArray(
|
||||||
|
z, mx.array(np.tensordot(x, y, axes=([2, 1, 3], [1, 2, 0])))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -2278,3 +2278,40 @@ TEST_CASE("test repeat") {
|
|||||||
// negative repeats
|
// negative repeats
|
||||||
CHECK_THROWS_AS(repeat(data_3, -3, 0), std::invalid_argument);
|
CHECK_THROWS_AS(repeat(data_3, -3, 0), std::invalid_argument);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("tensordot") {
|
||||||
|
auto x = reshape(arange(60.), {3, 4, 5});
|
||||||
|
auto y = reshape(arange(24.), {4, 3, 2});
|
||||||
|
auto z = tensordot(x, y, {{1, 0}, {0, 1}});
|
||||||
|
auto expected = array(
|
||||||
|
{4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}, {5, 2});
|
||||||
|
CHECK(array_equal(z, expected).item<bool>());
|
||||||
|
x = reshape(arange(360.), {3, 4, 5, 6});
|
||||||
|
y = reshape(arange(360.), {6, 4, 5, 3});
|
||||||
|
CHECK_THROWS_AS(
|
||||||
|
tensordot(x, y, {{2, 1, 3}, {1, 2, 0}}), std::invalid_argument);
|
||||||
|
x = reshape(arange(60.), {3, 4, 5});
|
||||||
|
y = reshape(arange(120.), {4, 5, 6});
|
||||||
|
z = tensordot(x, y, 2);
|
||||||
|
expected = array(
|
||||||
|
{14820.,
|
||||||
|
15010.,
|
||||||
|
15200.,
|
||||||
|
15390.,
|
||||||
|
15580.,
|
||||||
|
15770.,
|
||||||
|
37620.,
|
||||||
|
38210.,
|
||||||
|
38800.,
|
||||||
|
39390.,
|
||||||
|
39980.,
|
||||||
|
40570.,
|
||||||
|
60420.,
|
||||||
|
61410.,
|
||||||
|
62400.,
|
||||||
|
63390.,
|
||||||
|
64380.,
|
||||||
|
65370.},
|
||||||
|
{3, 6});
|
||||||
|
CHECK(array_equal(z, expected).item<bool>());
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user