diff --git a/mlx/einsum.cpp b/mlx/einsum.cpp
index 2858f9110..cfc9fb2ad 100644
--- a/mlx/einsum.cpp
+++ b/mlx/einsum.cpp
@@ -71,6 +71,7 @@ std::pair, std::string> parse(std::string subscripts) {
   } else {
     // Implicit mode:
     // - repeats are summed
+    // - ellipses are placed in the beginning of the output
     // - remaining output axes are ordered alphabetically
     lhs = subscripts;
     std::unordered_map temp;
@@ -78,6 +79,11 @@ std::pair, std::string> parse(std::string subscripts) {
       if (c == ',') {
         continue;
       }
+      if (c == '.' && rhs.empty()) {
+        rhs += "...";
+        continue;
+      }
+
       auto inserted = temp.insert({c, 0});
       inserted.first->second++;
     }
@@ -641,20 +647,102 @@ std::pair, PathInfo> einsum_path_helper(
     throw std::invalid_argument(msg.str());
   }
 
-  auto check_letters = [&](const auto& subscript) {
-    for (auto c : subscript) {
-      if (!isalpha(c)) {
-        std::ostringstream msg;
-        msg << "[" << fn_name << "] Subscripts must be letters, but got '" << c
-            << "'.";
-        throw std::invalid_argument(msg.str());
+  // Expand ellipses
+  // 1. Collect all the characters we can use for the missing axes.
+  // 2. Go over each subscript and check if all the characters are either
+  //    letters or an ellipsis.
+  // 3. Expand the ellipsis with as many characters from the unused ones as
+  //    necessary. We use the last N characters effectively prepending with
+  //    singleton dims for inputs with fewer dimensions.
+  // 4. For the output use the maximum size of ellipsis that we encountered in
+  //    the input.
+  CharSet used_chars(subscripts.begin(), subscripts.end());
+  std::string remaining_chars;
+  // At most 52 letters can be unused. Do not subtract used_chars.size():
+  // used_chars may also hold non-letter characters (e.g. the '.' of an
+  // ellipsis), so the unsigned difference could wrap around.
+  remaining_chars.reserve(52);
+  for (char c = 'a'; c <= 'z'; c++) {
+    if (used_chars.find(c) == used_chars.end()) {
+      remaining_chars += c;
+    }
+  }
+  for (char c = 'A'; c <= 'Z'; c++) {
+    if (used_chars.find(c) == used_chars.end()) {
+      remaining_chars += c;
+    }
+  }
+  int max_ellipsis_length = 0;
+  // Validates one subscript and replaces its "..." in place with unused
+  // letters. `operand` is the matching input array, or nullptr for the output
+  // subscript, which reuses the longest ellipsis seen among the inputs.
+  // Throws std::invalid_argument on a non-letter character, on more than one
+  // ellipsis, or when the ellipsis cannot be expanded.
+  auto check_letters_and_expand_ellipsis = [&](auto& subscript,
+                                               const array* operand) {
+    bool have_ellipsis = false;
+    int cnt_before = 0, cnt_after = 0;
+    for (size_t i = 0; i < subscript.size(); i++) {
+      // Cast first: isalpha is undefined for negative char values.
+      if (!isalpha(static_cast<unsigned char>(subscript[i]))) {
+        if (i + 2 >= subscript.size() || subscript[i] != '.' ||
+            subscript[i + 1] != '.' || subscript[i + 2] != '.') {
+          std::ostringstream msg;
+          msg << "[" << fn_name << "] Subscripts must be letters, but got '"
+              << subscript[i] << "'.";
+          throw std::invalid_argument(msg.str());
+        }
+
+        if (have_ellipsis) {
+          std::ostringstream msg;
+          msg << "[" << fn_name
+              << "] Only one ellipsis per subscript is allowed but found more in '"
+              << subscript << "'.";
+          throw std::invalid_argument(msg.str());
+        }
+
+        have_ellipsis = true;
+        i += 2;
+        continue;
+      }
+
+      if (have_ellipsis) {
+        cnt_after++;
+      } else {
+        cnt_before++;
       }
     }
+
+    if (have_ellipsis) {
+      int ellipsis_length;
+      if (operand != nullptr) {
+        ellipsis_length =
+            static_cast<int>(operand->ndim()) - cnt_before - cnt_after;
+      } else {
+        ellipsis_length = max_ellipsis_length;
+      }
+      // The replacement below reads the last ellipsis_length characters of
+      // remaining_chars, so reject lengths outside [0, remaining_chars.size()].
+      if (ellipsis_length < 0 ||
+          ellipsis_length > static_cast<int>(remaining_chars.size())) {
+        std::ostringstream msg;
+        msg << "[" << fn_name << "] Invalid ellipsis length "
+            << ellipsis_length << " for subscript '" << subscript << "'.";
+        throw std::invalid_argument(msg.str());
+      }
+      max_ellipsis_length = std::max(ellipsis_length, max_ellipsis_length);
+      subscript.replace(
+          subscript.begin() + cnt_before,
+          subscript.begin() + cnt_before + 3,
+          remaining_chars.end() - ellipsis_length,
+          remaining_chars.end());
+    }
   };
-  for (auto& in : in_subscripts) {
-    check_letters(in);
+
+  for (size_t i = 0; i < operands.size(); i++) {
+    check_letters_and_expand_ellipsis(in_subscripts[i], &operands[i]);
   }
-  check_letters(out_subscript);
+  check_letters_and_expand_ellipsis(out_subscript, nullptr);
 
   CharSet
out_set(out_subscript.begin(), out_subscript.end());
   if (out_set.size() != out_subscript.size()) {
diff --git a/python/tests/test_einsum.py b/python/tests/test_einsum.py
index 919720c50..19ea8178e 100644
--- a/python/tests/test_einsum.py
+++ b/python/tests/test_einsum.py
@@ -313,6 +313,60 @@ class TestEinsum(mlx_tests.MLXTestCase):
             mx_out = mx.einsum(test_case, *inputs)
             self.assertTrue(np.allclose(mx_out, np_out, rtol=1e-4, atol=1e-4))
 
+    def test_ellipses(self):
+        """Compare einsum expressions containing "..." against NumPy.
+
+        Each case is (shape spec, expression): the letters before "->" in the
+        shape spec are looked up in size_dict to build the random inputs, and
+        the expression is what mx.einsum and np.einsum evaluate.
+        """
+        size_dict = dict(zip("abcdefghij", [2, 3, 4, 5, 2, 3, 4, 5, 2, 3]))
+
+        def inputs_for_case(test_case):
+            # One random array per comma-separated input of the shape spec.
+            inputs = test_case.split("->")[0].split(",")
+            return [
+                mx.random.uniform(shape=tuple(size_dict[c] for c in inp))
+                for inp in inputs
+            ]
+
+        tests = [
+            ("abc->ab", "...c->..."),
+            ("abcd->ad", "a...d->..."),
+            ("abij,abgj->abig", "...ij,...gj->...ig"),
+            ("abij,abgj->abig", "...ij,...gj->..."),
+            ("abhh->abh", "...hh->...h"),
+            ("bch,abcj->abchj", "...h,...j->...hj"),
+            ("bc,cd->bd", "...c,cd"),
+            ("abc,acd->bd", "...bc,...cd"),
+            ("abcd,c->abd", "...cd,c"),
+            ("abcd,c->abd", "...cd,c..."),
+            ("abcd,c->abd", "...cd,c...->d..."),
+            ("abc,b->abc", "ab...,b...->ab..."),
+            ("abc,b->abc", "ab...,...b->ab..."),
+            ("abc,b->abc", "ab...,b->ab..."),
+            ("ab,bc->ac", "ab...,b...->a..."),
+            ("ab,bc->ac", "ab...,...bc->a...c"),
+            ("ab,bc->ac", "ab,b...->a..."),
+            ("abcdef,defg->abcg", "...def,defg->...g"),
+        ]
+        for shape_spec, expr in tests:
+            with self.subTest(expr=expr):
+                inputs = inputs_for_case(shape_spec)
+                np_out = np.einsum(expr, *inputs)
+                mx_out = mx.einsum(expr, *inputs)
+                self.assertTrue(np.allclose(mx_out, np_out, rtol=1e-4, atol=1e-4))
+
+        # A subscript with more than one ellipsis is expected to raise.
+        error_tests = [
+            ("abc,abc->ab", "a...b...c,a...b...c->abc"),
+        ]
+        for shape_spec, expr in error_tests:
+            with self.subTest(expr=expr):
+                inputs = inputs_for_case(shape_spec)
+                with self.assertRaises(ValueError):
+                    mx.einsum(expr, *inputs)
+
 
 if __name__ == "__main__":
     unittest.main()