mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
Einsum ellipsis (#1788)
This commit is contained in:
parent
e6a7ab9675
commit
72146fc4cd
@ -71,6 +71,7 @@ std::pair<std::vector<std::string>, std::string> parse(std::string subscripts) {
|
|||||||
} else {
|
} else {
|
||||||
// Implicit mode:
|
// Implicit mode:
|
||||||
// - repeats are summed
|
// - repeats are summed
|
||||||
|
// - ellipses are placed in the beginning of the output
|
||||||
// - remaining output axes are ordered alphabetically
|
// - remaining output axes are ordered alphabetically
|
||||||
lhs = subscripts;
|
lhs = subscripts;
|
||||||
std::unordered_map<char, int> temp;
|
std::unordered_map<char, int> temp;
|
||||||
@ -78,6 +79,11 @@ std::pair<std::vector<std::string>, std::string> parse(std::string subscripts) {
|
|||||||
if (c == ',') {
|
if (c == ',') {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
if (c == '.' && rhs.empty()) {
|
||||||
|
rhs += "...";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
auto inserted = temp.insert({c, 0});
|
auto inserted = temp.insert({c, 0});
|
||||||
inserted.first->second++;
|
inserted.first->second++;
|
||||||
}
|
}
|
||||||
@ -641,20 +647,83 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto check_letters = [&](const auto& subscript) {
|
// Expand ellipses
|
||||||
for (auto c : subscript) {
|
// 1. Collect all the characters we can use for the missing axes.
|
||||||
if (!isalpha(c)) {
|
// 2. Go over each subscript and check if all the characters are either
|
||||||
std::ostringstream msg;
|
// alphanumeric or an ellipsis.
|
||||||
msg << "[" << fn_name << "] Subscripts must be letters, but got '" << c
|
// 3. Expand the ellipsis with as many characters from the unused ones as
|
||||||
<< "'.";
|
// necessary. We use the last N characters effectively prepending with
|
||||||
throw std::invalid_argument(msg.str());
|
// singleton dims for inputs with fewer dimensions.
|
||||||
|
// 4. For the output use the maximum size of ellipsis that we encountered in
|
||||||
|
// the input.
|
||||||
|
CharSet used_chars(subscripts.begin(), subscripts.end());
|
||||||
|
std::string remaining_chars;
|
||||||
|
remaining_chars.reserve(52 - used_chars.size());
|
||||||
|
for (char c = 'a'; c <= 'z'; c++) {
|
||||||
|
if (used_chars.find(c) == used_chars.end()) {
|
||||||
|
remaining_chars += c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (char c = 'A'; c <= 'Z'; c++) {
|
||||||
|
if (used_chars.find(c) == used_chars.end()) {
|
||||||
|
remaining_chars += c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int max_ellipsis_length = 0;
|
||||||
|
auto check_letters_and_expand_ellipsis = [&](auto& subscript,
|
||||||
|
const array* operand) {
|
||||||
|
bool have_ellipsis = false;
|
||||||
|
int cnt_before = 0, cnt_after = 0;
|
||||||
|
for (int i = 0; i < subscript.size(); i++) {
|
||||||
|
if (!isalpha(subscript[i])) {
|
||||||
|
if (i + 2 >= subscript.size() || subscript[i] != '.' ||
|
||||||
|
subscript[i + 1] != '.' || subscript[i + 2] != '.') {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[" << fn_name << "] Subscripts must be letters, but got '"
|
||||||
|
<< subscript[i] << "'.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (have_ellipsis) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[" << fn_name
|
||||||
|
<< "] Only one ellipsis per subscript is allowed but found more in '"
|
||||||
|
<< subscript << "'.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
have_ellipsis = true;
|
||||||
|
i += 2;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (have_ellipsis) {
|
||||||
|
cnt_after++;
|
||||||
|
} else {
|
||||||
|
cnt_before++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (have_ellipsis) {
|
||||||
|
int ellipsis_length;
|
||||||
|
if (operand != nullptr) {
|
||||||
|
ellipsis_length = operand->ndim() - cnt_before - cnt_after;
|
||||||
|
max_ellipsis_length = std::max(ellipsis_length, max_ellipsis_length);
|
||||||
|
} else {
|
||||||
|
ellipsis_length = max_ellipsis_length;
|
||||||
|
}
|
||||||
|
subscript.replace(
|
||||||
|
subscript.begin() + cnt_before,
|
||||||
|
subscript.begin() + cnt_before + 3,
|
||||||
|
remaining_chars.end() - ellipsis_length,
|
||||||
|
remaining_chars.end());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
for (auto& in : in_subscripts) {
|
|
||||||
check_letters(in);
|
for (int i = 0; i < operands.size(); i++) {
|
||||||
|
check_letters_and_expand_ellipsis(in_subscripts[i], &operands[i]);
|
||||||
}
|
}
|
||||||
check_letters(out_subscript);
|
check_letters_and_expand_ellipsis(out_subscript, nullptr);
|
||||||
|
|
||||||
CharSet out_set(out_subscript.begin(), out_subscript.end());
|
CharSet out_set(out_subscript.begin(), out_subscript.end());
|
||||||
if (out_set.size() != out_subscript.size()) {
|
if (out_set.size() != out_subscript.size()) {
|
||||||
|
@ -313,6 +313,51 @@ class TestEinsum(mlx_tests.MLXTestCase):
|
|||||||
mx_out = mx.einsum(test_case, *inputs)
|
mx_out = mx.einsum(test_case, *inputs)
|
||||||
self.assertTrue(np.allclose(mx_out, np_out, rtol=1e-4, atol=1e-4))
|
self.assertTrue(np.allclose(mx_out, np_out, rtol=1e-4, atol=1e-4))
|
||||||
|
|
||||||
|
def test_ellipses(self):
|
||||||
|
size_dict = dict(zip("abcdefghij", [2, 3, 4, 5, 2, 3, 4, 5, 2, 3]))
|
||||||
|
|
||||||
|
def inputs_for_case(test_case):
|
||||||
|
inputs = test_case.split("->")[0].split(",")
|
||||||
|
return [
|
||||||
|
mx.random.uniform(shape=tuple(size_dict[c] for c in inp))
|
||||||
|
for inp in inputs
|
||||||
|
]
|
||||||
|
|
||||||
|
tests = [
|
||||||
|
("abc->ab", "...c->..."),
|
||||||
|
("abcd->ad", "a...d->..."),
|
||||||
|
("abij,abgj->abig", "...ij,...gj->...ig"),
|
||||||
|
("abij,abgj->abig", "...ij,...gj->..."),
|
||||||
|
("abhh->abh", "...hh->...h"),
|
||||||
|
("abhh->abh", "...hh->...h"),
|
||||||
|
("bch,abcj->abchj", "...h,...j->...hj"),
|
||||||
|
("bc,cd->bd", "...c,cd"),
|
||||||
|
("abc,acd->bd", "...bc,...cd"),
|
||||||
|
("abcd,c->abd", "...cd,c"),
|
||||||
|
("abcd,c->abd", "...cd,c..."),
|
||||||
|
("abcd,c->abd", "...cd,c...->d..."),
|
||||||
|
("abc,b->abc", "ab...,b...->ab..."),
|
||||||
|
("abc,b->abc", "ab...,...b->ab..."),
|
||||||
|
("abc,b->abc", "ab...,b->ab..."),
|
||||||
|
("ab,bc->ac", "ab...,b...->a..."),
|
||||||
|
("ab,bc->ac", "ab...,...bc->a...c"),
|
||||||
|
("ab,bc->ac", "ab,b...->a..."),
|
||||||
|
("abcdef,defg->abcg", "...def,defg->...g"),
|
||||||
|
]
|
||||||
|
for test_case in tests:
|
||||||
|
inputs = inputs_for_case(test_case[0])
|
||||||
|
np_out = np.einsum(test_case[1], *inputs)
|
||||||
|
mx_out = mx.einsum(test_case[1], *inputs)
|
||||||
|
self.assertTrue(np.allclose(mx_out, np_out, rtol=1e-4, atol=1e-4))
|
||||||
|
|
||||||
|
error_tests = [
|
||||||
|
("abc,abc->ab", "a...b...c,a...b...c->abc"),
|
||||||
|
]
|
||||||
|
for test_case in error_tests:
|
||||||
|
inputs = inputs_for_case(test_case[0])
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.einsum(test_case[1], *inputs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user