initial commit

This commit is contained in:
dc-dc-dc 2024-01-02 12:43:53 -05:00
parent 295ce9db09
commit 8ded7c8d37
3 changed files with 161 additions and 0 deletions

View File

@ -2793,4 +2793,104 @@ array dequantize(
return w_full; return w_full;
} }
array tensordot(
const array& a,
const array& b,
const int dims /* = 2 */,
StreamOrDevice s /* = {} */
) {
if (dims < 0) {
throw std::invalid_argument(
"[tensordot] dims must be greater or equal to 0.");
}
if (dims > std::min(a.ndim(), b.ndim())) {
throw std::invalid_argument(
"[tensordot] dims must be less than the number of dimensions of a and b.");
}
std::vector<int> adims;
std::vector<int> bdims;
for (int i = 0; i < dims; i++) {
bdims.emplace_back(i);
adims.emplace_back(-dims + i);
}
return tensordot(a, b, {adims, bdims}, s);
}
array tensordot(
const array& a,
const array& b,
const std::vector<std::vector<int>>& dims,
StreamOrDevice s /* = {} */
) {
if (dims.size() != 2) {
throw std::invalid_argument(
"[tensordot] dims must be a vector of two vectors.");
}
if (dims[0].size() != dims[1].size()) {
throw std::invalid_argument(
"[tensordot] dims[0] and dims[1] must have the same number of dimensions.");
}
if (a.dtype() != b.dtype()) {
throw std::invalid_argument(
"[tensordot] a and b must have the same dtype.");
}
int csize = 1;
auto x = a;
auto y = b;
for (int i = 0; i < dims[0].size(); i++) {
size_t xs = x.shape(dims[0].at(i));
size_t ys = y.shape(dims[1].at(i));
if (ys == 1) {
x = sum(x, dims[0].at(i), true, s);
} else if (xs == 1) {
y = sum(y, dims[1].at(i), true, s);
} else {
csize *= xs;
}
}
std::vector<bool> cdims1(x.ndim(), false);
std::vector<bool> cdims2(y.ndim(), false);
for (const auto n : dims[0]) {
int n_ = (n < 0) ? n + x.ndim() : n;
cdims1[n_] = true;
}
for (const auto n : dims[1]) {
int n_ = (n < 0) ? n + y.ndim() : n;
cdims2[n_] = true;
}
std::vector<int> t1;
t1.reserve(a.ndim());
std::vector<int> t2;
t2.reserve(b.ndim());
std::vector<int> rshape;
rshape.reserve(a.ndim() + b.ndim() * dims[0].size());
int size1 = 1;
int size2 = 1;
for (int i = 0; i < a.ndim(); i++) {
if (!cdims1[i]) {
t1.emplace_back(i);
size1 *= a.shape(i);
rshape.emplace_back(a.shape(i));
}
}
for (const auto x : dims[0]) {
t1.emplace_back(x);
}
for (const auto x : dims[1]) {
t2.emplace_back(x);
}
for (int i = 0; i < b.ndim(); i++) {
if (!cdims2[i]) {
t2.emplace_back(i);
size2 *= b.shape(i);
rshape.emplace_back(b.shape(i));
}
}
x = reshape(transpose(x, t1, s), {size1, csize}, s);
y = reshape(transpose(y, t2, s), {csize, size2}, s);
return reshape(matmul(x, y, s), rshape, s);
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -1061,6 +1061,19 @@ array dequantize(
int bits = 4, int bits = 4,
StreamOrDevice s = {}); StreamOrDevice s = {});
/** TensorDot returns a contraction of a and b over multiple dimensions. */
array tensordot(
const array& a,
const array& b,
const int dims = 2,
StreamOrDevice s = {});
array tensordot(
const array& a,
const array& b,
const std::vector<std::vector<int>>& dims,
StreamOrDevice s = {});
/** Load array map from .safetensors file format */ /** Load array map from .safetensors file format */
std::unordered_map<std::string, array> load_safetensors( std::unordered_map<std::string, array> load_safetensors(
std::shared_ptr<io::Reader> in_stream, std::shared_ptr<io::Reader> in_stream,

View File

@ -2278,3 +2278,51 @@ TEST_CASE("test repeat") {
// negative repeats // negative repeats
CHECK_THROWS_AS(repeat(data_3, -3, 0), std::invalid_argument); CHECK_THROWS_AS(repeat(data_3, -3, 0), std::invalid_argument);
} }
TEST_CASE("tensordot") {
auto x = reshape(arange(60.), {3, 4, 5});
auto y = reshape(arange(24.), {4, 3, 2});
auto z = tensordot(x, y, {{1, 0}, {0, 1}});
auto expected = array(
{4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}, {5, 2});
CHECK(array_equal(z, expected).item<bool>());
x = reshape(arange(360.), {3, 4, 5, 6});
y = reshape(arange(360.), {6, 4, 5, 3});
z = tensordot(x, y, {{2, 1, 3}, {1, 2, 0}});
expected = array(
{1326270,
1333410,
1340550,
3896670,
3918210,
3939750,
6467070,
6503010,
6538950},
{3, 3});
CHECK(array_equal(z, expected).item<bool>());
x = reshape(arange(60.), {3, 4, 5});
y = reshape(arange(120.), {4, 5, 6});
z = tensordot(x, y, 2);
expected = array(
{14820.,
15010.,
15200.,
15390.,
15580.,
15770.,
37620.,
38210.,
38800.,
39390.,
39980.,
40570.,
60420.,
61410.,
62400.,
63390.,
64380.,
65370.},
{3, 6});
CHECK(array_equal(z, expected).item<bool>());
}