Einsum ellipsis (#1788)

This commit is contained in:
Angelos Katharopoulos
2025-01-25 01:28:03 -08:00
committed by GitHub
parent e6a7ab9675
commit 72146fc4cd
2 changed files with 124 additions and 10 deletions

View File

@@ -71,6 +71,7 @@ std::pair<std::vector<std::string>, std::string> parse(std::string subscripts) {
} else {
// Implicit mode:
// - repeats are summed
// - ellipses are placed in the beginning of the output
// - remaining output axes are ordered alphabetically
lhs = subscripts;
std::unordered_map<char, int> temp;
@@ -78,6 +79,11 @@ std::pair<std::vector<std::string>, std::string> parse(std::string subscripts) {
if (c == ',') {
continue;
}
if (c == '.' && rhs.empty()) {
rhs += "...";
continue;
}
auto inserted = temp.insert({c, 0});
inserted.first->second++;
}
@@ -641,20 +647,83 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
throw std::invalid_argument(msg.str());
}
auto check_letters = [&](const auto& subscript) {
for (auto c : subscript) {
if (!isalpha(c)) {
std::ostringstream msg;
msg << "[" << fn_name << "] Subscripts must be letters, but got '" << c
<< "'.";
throw std::invalid_argument(msg.str());
// Expand ellipses
// 1. Collect all the characters we can use for the missing axes.
// 2. Go over each subscript and check if all the characters are either
// alphanumeric or an ellipsis.
// 3. Expand the ellipsis with as many characters from the unused ones as
// necessary. We use the last N characters effectively prepending with
// singleton dims for inputs with fewer dimensions.
// 4. For the output use the maximum size of ellipsis that we encountered in
// the input.
CharSet used_chars(subscripts.begin(), subscripts.end());
std::string remaining_chars;
remaining_chars.reserve(52 - used_chars.size());
for (char c = 'a'; c <= 'z'; c++) {
if (used_chars.find(c) == used_chars.end()) {
remaining_chars += c;
}
}
for (char c = 'A'; c <= 'Z'; c++) {
if (used_chars.find(c) == used_chars.end()) {
remaining_chars += c;
}
}
int max_ellipsis_length = 0;
auto check_letters_and_expand_ellipsis = [&](auto& subscript,
const array* operand) {
bool have_ellipsis = false;
int cnt_before = 0, cnt_after = 0;
for (int i = 0; i < subscript.size(); i++) {
if (!isalpha(subscript[i])) {
if (i + 2 >= subscript.size() || subscript[i] != '.' ||
subscript[i + 1] != '.' || subscript[i + 2] != '.') {
std::ostringstream msg;
msg << "[" << fn_name << "] Subscripts must be letters, but got '"
<< subscript[i] << "'.";
throw std::invalid_argument(msg.str());
}
if (have_ellipsis) {
std::ostringstream msg;
msg << "[" << fn_name
<< "] Only one ellipsis per subscript is allowed but found more in '"
<< subscript << "'.";
throw std::invalid_argument(msg.str());
}
have_ellipsis = true;
i += 2;
continue;
}
if (have_ellipsis) {
cnt_after++;
} else {
cnt_before++;
}
}
if (have_ellipsis) {
int ellipsis_length;
if (operand != nullptr) {
ellipsis_length = operand->ndim() - cnt_before - cnt_after;
max_ellipsis_length = std::max(ellipsis_length, max_ellipsis_length);
} else {
ellipsis_length = max_ellipsis_length;
}
subscript.replace(
subscript.begin() + cnt_before,
subscript.begin() + cnt_before + 3,
remaining_chars.end() - ellipsis_length,
remaining_chars.end());
}
};
for (auto& in : in_subscripts) {
check_letters(in);
for (int i = 0; i < operands.size(); i++) {
check_letters_and_expand_ellipsis(in_subscripts[i], &operands[i]);
}
check_letters(out_subscript);
check_letters_and_expand_ellipsis(out_subscript, nullptr);
CharSet out_set(out_subscript.begin(), out_subscript.end());
if (out_set.size() != out_subscript.size()) {