mlx/mlx/einsum.cpp
2025-01-25 01:28:03 -08:00

930 lines
27 KiB
C++

// Copyright © 2024 Apple Inc.
#include <numeric>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include "mlx/einsum.h"
#include "mlx/ops.h"
namespace mlx::core {
namespace {
// The MLX einsum implementation is based on NumPy (which is based on
// opt_einsum):
// https://github.com/numpy/numpy/blob/1d49c7f7ff527c696fc26ab2278ad51632a66660/numpy/_core/einsumfunc.py#L743
// https://github.com/dgasmith/opt_einsum
using CharSet = std::unordered_set<char>;
// A helper struct to hold the string and set
// representation of a subscript to avoid needing
// to recompute the set
struct Subscript {
Subscript(std::string str, CharSet set)
: str(std::move(str)), set(std::move(set)) {};
std::string str;
CharSet set;
};
struct PathInfo {
size_t naive_cost;
size_t naive_scaling;
size_t optimized_cost;
size_t optimized_scaling;
size_t largest_term;
};
struct PathNode {
PathNode(
std::vector<Subscript> inputs,
Subscript output,
std::vector<int> positions)
: inputs(std::move(inputs)),
output(std::move(output)),
positions(std::move(positions)) {};
std::vector<Subscript> inputs;
Subscript output;
std::vector<int> positions;
};
// Parse the comma separated subscripts into a vector of strings. If the
// output subscripts are missing they are inferred.
//
// For example:
// "ij,jk -> ik" becomes {{"ij", "jk"}, "ik"}
// "ij,jk" becomes {{"ij", "jk"}, "ik"}
std::pair<std::vector<std::string>, std::string> parse(std::string subscripts) {
std::string lhs, rhs;
// Start by removing all white space
subscripts.erase(
std::remove(subscripts.begin(), subscripts.end(), ' '), subscripts.end());
if (auto pos = subscripts.find("->"); pos != std::string::npos) {
// Explicit mode
lhs = subscripts.substr(0, pos);
rhs = subscripts.substr(pos + 2);
} else {
// Implicit mode:
// - repeats are summed
// - ellipses are placed in the beginning of the output
// - remaining output axes are ordered alphabetically
lhs = subscripts;
std::unordered_map<char, int> temp;
for (auto& c : subscripts) {
if (c == ',') {
continue;
}
if (c == '.' && rhs.empty()) {
rhs += "...";
continue;
}
auto inserted = temp.insert({c, 0});
inserted.first->second++;
}
for (auto& k : temp) {
if (k.second == 1) {
rhs += k.first;
}
}
std::sort(rhs.begin(), rhs.end());
}
std::vector<std::string> input_list;
std::stringstream ss(lhs);
std::string token;
while (getline(ss, token, ',')) {
input_list.push_back(token);
}
return {input_list, rhs};
}
// Check if two sets are disjoint
bool disjoint(const CharSet& x, const CharSet& y) {
for (auto& c : x) {
if (y.find(c) != y.end()) {
return false;
}
}
return true;
}
template <typename T>
size_t term_size(const T& term, std::unordered_map<char, ShapeElem> dict) {
size_t size = 1;
for (auto c : term) {
size *= dict[c];
}
return size;
}
size_t flop_count(
const CharSet& term,
bool inner,
int num_terms,
std::unordered_map<char, ShapeElem> dict) {
size_t size = term_size(term, dict);
auto op_factor = 1;
if ((num_terms - 1) > op_factor) {
op_factor = num_terms - 1;
}
if (inner) {
op_factor += 1;
}
return size * op_factor;
}
std::pair<size_t, int> compute_cost_and_scaling(
const std::vector<Subscript>& inputs,
const Subscript& output,
std::unordered_map<char, ShapeElem> dim_map) {
CharSet contractions;
for (auto& in : inputs) {
contractions.insert(in.set.begin(), in.set.end());
}
bool inner = false;
for (auto c : contractions) {
if (output.set.find(c) == output.set.end()) {
inner = true;
break;
}
}
auto cost = flop_count(contractions, inner, inputs.size(), dim_map);
return {cost, contractions.size()};
}
std::tuple<std::vector<PathNode>, size_t, int> greedy_path(
std::vector<Subscript> inputs,
const Subscript& output,
std::unordered_map<char, ShapeElem> dim_map,
size_t cost_limit,
size_t memory_limit) {
// Helper struct for building the greedy path
struct Contraction {
Contraction(
size_t size,
size_t cost,
CharSet output,
int dims,
int x,
int y)
: size(size),
cost(cost),
output(std::move(output)),
dims(dims),
x(x),
y(y) {};
int64_t size; // Size difference, can be negative
size_t cost;
CharSet output;
int dims; // Number of dimensions in the contraction
int x;
int y;
};
// Start by iterating over all possible combinations
std::vector<std::pair<int, int>> pos_pairs;
for (int i = 0; i < inputs.size(); ++i) {
for (int j = i + 1; j < inputs.size(); ++j) {
pos_pairs.emplace_back(i, j);
}
}
std::vector<PathNode> path;
std::vector<Contraction> possible_contractions;
size_t path_cost = 0;
int path_scaling = 0;
auto num_in = inputs.size();
for (int i = 0; i < num_in - 1; ++i) {
auto add_contraction = [&](int p1, int p2) {
CharSet new_term;
CharSet contractions(inputs[p1].set.begin(), inputs[p1].set.end());
contractions.insert(inputs[p2].set.begin(), inputs[p2].set.end());
for (int i = 0; i < inputs.size(); i++) {
if (i == p1 || i == p2) {
continue;
}
auto& in = inputs[i].set;
for (auto c : in) {
if (contractions.find(c) != contractions.end()) {
new_term.insert(c);
}
}
}
for (auto c : output.set) {
if (contractions.find(c) != contractions.end()) {
new_term.insert(c);
}
}
// Ignore if:
// - The size of the new result is greater than the memory limit
// - The cost is larger than the naive cost
auto new_size = term_size(new_term, dim_map);
if (new_size > memory_limit) {
return;
}
int64_t removed_size = term_size(inputs[p1].set, dim_map) +
term_size(inputs[p2].set, dim_map) - new_size;
bool inner = contractions.size() > new_term.size();
auto cost = flop_count(contractions, inner, 2, dim_map);
if (path_cost + cost > cost_limit) {
return;
}
possible_contractions.emplace_back(
removed_size, cost, std::move(new_term), contractions.size(), p1, p2);
};
for (auto& [p1, p2] : pos_pairs) {
// Ignore outer products
if (!disjoint(inputs[p1].set, inputs[p2].set)) {
add_contraction(p1, p2);
}
}
// If there's nothing in the contraction list,
// go over the pairs again without ignoring outer products
if (possible_contractions.empty()) {
for (auto& [p1, p2] : pos_pairs) {
add_contraction(p1, p2);
}
}
if (possible_contractions.empty()) {
// Default to naive einsum for the remaining inputs
std::vector<int> positions(inputs.size());
std::iota(positions.begin(), positions.end(), 0);
auto [cost, scale] = compute_cost_and_scaling(inputs, output, dim_map);
path.emplace_back(std::move(inputs), output, std::move(positions));
path_cost += cost;
path_scaling = std::max(scale, path_scaling);
break;
}
// Find the best contraction
auto& best = *std::min_element(
possible_contractions.begin(),
possible_contractions.end(),
[](const auto& x, const auto& y) {
return x.size > y.size || (x.size == y.size && x.cost < y.cost);
});
path_scaling = std::max(best.dims, path_scaling);
// Construct the output subscripts
std::string out_str(best.output.begin(), best.output.end());
// TODO, sorting by dimension size seems suboptimal?
std::sort(out_str.begin(), out_str.end(), [&dim_map](auto x, auto y) {
return dim_map[x] < dim_map[y];
});
Subscript new_output(std::move(out_str), std::move(best.output));
// Add the chosen contraction to the path
{
std::vector<Subscript> in_terms;
in_terms.push_back(std::move(inputs[best.x]));
in_terms.push_back(std::move(inputs[best.y]));
path.emplace_back(
std::move(in_terms), new_output, std::vector<int>{best.x, best.y});
}
// Remove used terms
inputs.erase(inputs.begin() + best.y);
inputs.erase(inputs.begin() + best.x);
// Add the new result
inputs.push_back(std::move(new_output));
// Update the existing contractions based on the selected one
std::vector<Contraction> updated_contractions;
for (auto& contraction : possible_contractions) {
// Drop contractions which contain either selected term
if (contraction.x == best.x || contraction.x == best.y ||
contraction.y == best.x || contraction.y == best.y) {
continue;
}
// Update the positions of other contractions
int x =
contraction.x - (contraction.x > best.x) - (contraction.x > best.y);
int y =
contraction.y - (contraction.y > best.x) - (contraction.y > best.y);
contraction.x = x;
contraction.y = y;
updated_contractions.push_back(std::move(contraction));
}
pos_pairs.clear();
for (int i = 0; i < inputs.size() - 1; ++i) {
pos_pairs.emplace_back(i, inputs.size() - 1);
}
path_cost += best.cost;
possible_contractions = std::move(updated_contractions);
}
return {path, path_cost, path_scaling};
}
// Assumes inputs have already have had repeats and single axis sums collapsed
bool can_dot(const std::vector<Subscript>& inputs, const Subscript& output) {
if (inputs.size() != 2) {
return false;
}
for (auto c : inputs[0].set) {
// Use batched tensordot if anything is being contracted
if (output.set.find(c) == output.set.end()) {
return true;
}
}
return false;
}
array batch_tensordot(
array a,
array b,
std::vector<int> a_contract,
std::vector<int> a_batch,
std::vector<int> a_concat,
std::vector<int> b_contract,
std::vector<int> b_batch,
std::vector<int> b_concat,
StreamOrDevice s) {
// Broadcast contracting dimensions
{
auto a_shape = a.shape();
auto b_shape = b.shape();
for (int i = 0; i < a_contract.size(); ++i) {
auto d = std::max(a.shape(a_contract[i]), b.shape(b_contract[i]));
a_shape[a_contract[i]] = d;
b_shape[b_contract[i]] = d;
}
a = broadcast_to(a, a_shape, s);
b = broadcast_to(b, b_shape, s);
}
auto transpose_reshape = [&s](
const array& x,
const std::vector<int>& i,
const std::vector<int>& j,
const std::vector<int>& k) {
std::vector<int> reorder(i.begin(), i.end());
reorder.insert(reorder.end(), j.begin(), j.end());
reorder.insert(reorder.end(), k.begin(), k.end());
int size1 = 1;
for (auto s : j) {
size1 *= x.shape(s);
}
int size2 = 1;
for (auto s : k) {
size2 *= x.shape(s);
}
Shape shape;
for (auto ax : i) {
shape.push_back(x.shape(ax));
}
shape.push_back(size1);
shape.push_back(size2);
return reshape(transpose(x, reorder, s), std::move(shape), s);
};
Shape out_shape;
for (auto ax : a_batch) {
out_shape.push_back(a.shape(ax));
}
for (auto ax : a_concat) {
out_shape.push_back(a.shape(ax));
}
for (auto ax : b_concat) {
out_shape.push_back(b.shape(ax));
}
a = transpose_reshape(a, a_batch, a_concat, a_contract);
b = transpose_reshape(b, b_batch, b_contract, b_concat);
return reshape(matmul(a, b, s), std::move(out_shape), s);
}
// Collapse repeated subscripts and return the resulting array. The subscript
// is also updated in place. For example:
// - Given an input with shape (4, 4) and subscript "ii", returns
// the diagonal of shape (4,) and updates the subscript to "i".
// - Given an input with shape (4, 2, 4, 2) and subscript "ijij",
// returns an output with shape (4, 2) and updates the subscript
// to "ij".
array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) {
// Build a list of (repeat chars, num repeats)
auto& str = subscript.str;
std::vector<std::pair<char, int>> repeats;
std::string new_str;
{
std::string repeat_str;
std::string no_repeat_str;
std::unordered_map<char, int> counts;
for (int i = 0; i < str.size(); ++i) {
auto [it, _] = counts.insert({str[i], 0});
it->second++;
}
for (auto& v : counts) {
if (v.second > 1) {
repeats.emplace_back(v.first, v.second);
repeat_str += v.first;
}
}
for (auto& c : str) {
if (counts[c] == 1) {
no_repeat_str += c;
}
}
new_str = repeat_str + no_repeat_str;
}
// Build the inputs for gather
auto slice_sizes = in.shape();
std::vector<int> axes;
std::vector<array> indices;
int n_expand = repeats.size();
for (auto [c, v] : repeats) {
for (int i = 0; i < str.size(); ++i) {
if (str[i] == c) {
slice_sizes[i] = 1;
axes.push_back(i);
}
}
Shape idx_shape(n_expand--, 1);
idx_shape[0] = in.shape(axes.back());
auto idx = reshape(
arange(static_cast<ShapeElem>(in.shape(axes.back())), s), idx_shape, s);
for (int i = 0; i < v; ++i) {
indices.push_back(idx);
}
}
in = gather(in, indices, axes, slice_sizes, s);
// Update subscript string with removed dups
str = new_str;
// Squeeze singleton dimensions left over from the gather
for (auto& ax : axes) {
ax += indices[0].ndim();
}
return squeeze(in, axes, s);
}
// Collapse repeat indices and sum single dimensions.
// For example:
// - "aa" becomes "a"
// - "ij,jk->k" becoms "j,jk->k"
void preprocess_einsum_inputs(
std::vector<Subscript>& inputs,
const Subscript& output,
const std::vector<int>& positions,
std::vector<array>& operands,
StreamOrDevice s) {
// Collapse repeat indices
for (int i = 0; i < inputs.size(); ++i) {
auto& in = inputs[i];
if (in.set.size() < in.str.size()) {
operands[positions[i]] = collapse_repeats(operands[positions[i]], in, s);
}
}
// Sum indices that are only in a single input
{
std::unordered_map<char, int> counts;
for (auto& in : inputs) {
for (auto c : in.set) {
auto inserted = counts.insert({c, 0});
inserted.first->second++;
}
}
for (auto c : output.set) {
auto inserted = counts.insert({c, 0});
inserted.first->second++;
}
for (int i = 0; i < inputs.size(); ++i) {
auto& in = inputs[i];
std::vector<int> sum_axes;
for (int ax = 0; ax < in.str.size(); ++ax) {
if (counts[in.str[ax]] == 1) {
sum_axes.push_back(ax);
}
}
if (!sum_axes.empty()) {
operands[positions[i]] =
sum(operands[positions[i]], sum_axes, false, s);
}
for (auto it = sum_axes.rbegin(); it != sum_axes.rend(); ++it) {
in.set.erase(in.str[*it]);
in.str.erase(in.str.begin() + *it);
}
}
}
}
array einsum_naive(
std::vector<Subscript> inputs,
const Subscript& output,
const std::vector<int>& positions,
std::vector<array> operands,
StreamOrDevice s) {
// Map each character to an axis
std::unordered_map<char, int> char_to_ax;
for (auto& in : inputs) {
for (auto c : in.str) {
char_to_ax.insert({c, char_to_ax.size()});
}
}
// Expand and transpose inputs as needed
for (int i = 0; i < inputs.size(); ++i) {
int pos = positions[i];
auto& op = operands[pos];
// Add missing dimensions at the end
if (op.ndim() != char_to_ax.size()) {
auto shape = op.shape();
shape.insert(shape.end(), char_to_ax.size() - shape.size(), 1);
op = reshape(op, std::move(shape), s);
}
// Transpose:
// - Build a vector of (char, ax) pairs for the current input
// - Sort the vector by the canonical axis in char_to_ax
// - Extract the sorted axis to get transpose order
std::vector<std::pair<char, int>> str_ax;
for (auto c : inputs[i].str) {
str_ax.emplace_back(c, str_ax.size());
}
for (auto [c, ax] : char_to_ax) {
if (inputs[i].set.find(c) == inputs[i].set.end()) {
str_ax.emplace_back(c, str_ax.size());
}
}
std::sort(
str_ax.begin(),
str_ax.end(),
[&char_to_ax](const auto& x, const auto& y) {
return char_to_ax[x.first] < char_to_ax[y.first];
});
// Skip the transpose if not needed
if (std::is_sorted(
str_ax.begin(), str_ax.end(), [](const auto& x, const auto& y) {
return x.second < y.second;
})) {
continue;
}
std::vector<int> reorder;
for (auto [c, ax] : str_ax) {
reorder.push_back(ax);
}
op = transpose(op, reorder, s);
}
// Multiply and sum
auto out = operands[positions[0]];
for (int i = 1; i < positions.size(); ++i) {
out = multiply(out, operands[positions[i]], s);
}
std::vector<int> sum_axes;
for (auto [c, ax] : char_to_ax) {
if (output.set.find(c) == output.set.end()) {
sum_axes.push_back(ax);
}
}
if (!sum_axes.empty()) {
out = sum(out, sum_axes, false, s);
}
// Transpose output if needed
std::vector<int> reorder;
for (auto c : output.str) {
reorder.push_back(char_to_ax[c]);
}
for (auto& r : reorder) {
int offset = 0;
for (auto s : sum_axes) {
if (r > s) {
offset++;
}
}
r -= offset;
}
return transpose(out, reorder, s);
}
std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
const std::string& subscripts,
const std::vector<array>& operands,
const std::string& fn_name) {
if (operands.size() == 0) {
std::ostringstream msg;
msg << "[" << fn_name << "] At least one operand is required.";
throw std::invalid_argument(msg.str());
}
auto [in_subscripts, out_subscript] = parse(subscripts);
if (operands.size() != in_subscripts.size()) {
std::ostringstream msg;
msg << "[" << fn_name << "] Number of operands, " << operands.size()
<< ", does not match number of input subscripts, "
<< in_subscripts.size();
throw std::invalid_argument(msg.str());
}
// Expand ellipses
// 1. Collect all the characters we can use for the missing axes.
// 2. Go over each subscript and check if all the characters are either
// alphanumeric or an ellipsis.
// 3. Expand the ellipsis with as many characters from the unused ones as
// necessary. We use the last N characters effectively prepending with
// singleton dims for inputs with fewer dimensions.
// 4. For the output use the maximum size of ellipsis that we encountered in
// the input.
CharSet used_chars(subscripts.begin(), subscripts.end());
std::string remaining_chars;
remaining_chars.reserve(52 - used_chars.size());
for (char c = 'a'; c <= 'z'; c++) {
if (used_chars.find(c) == used_chars.end()) {
remaining_chars += c;
}
}
for (char c = 'A'; c <= 'Z'; c++) {
if (used_chars.find(c) == used_chars.end()) {
remaining_chars += c;
}
}
int max_ellipsis_length = 0;
auto check_letters_and_expand_ellipsis = [&](auto& subscript,
const array* operand) {
bool have_ellipsis = false;
int cnt_before = 0, cnt_after = 0;
for (int i = 0; i < subscript.size(); i++) {
if (!isalpha(subscript[i])) {
if (i + 2 >= subscript.size() || subscript[i] != '.' ||
subscript[i + 1] != '.' || subscript[i + 2] != '.') {
std::ostringstream msg;
msg << "[" << fn_name << "] Subscripts must be letters, but got '"
<< subscript[i] << "'.";
throw std::invalid_argument(msg.str());
}
if (have_ellipsis) {
std::ostringstream msg;
msg << "[" << fn_name
<< "] Only one ellipsis per subscript is allowed but found more in '"
<< subscript << "'.";
throw std::invalid_argument(msg.str());
}
have_ellipsis = true;
i += 2;
continue;
}
if (have_ellipsis) {
cnt_after++;
} else {
cnt_before++;
}
}
if (have_ellipsis) {
int ellipsis_length;
if (operand != nullptr) {
ellipsis_length = operand->ndim() - cnt_before - cnt_after;
max_ellipsis_length = std::max(ellipsis_length, max_ellipsis_length);
} else {
ellipsis_length = max_ellipsis_length;
}
subscript.replace(
subscript.begin() + cnt_before,
subscript.begin() + cnt_before + 3,
remaining_chars.end() - ellipsis_length,
remaining_chars.end());
}
};
for (int i = 0; i < operands.size(); i++) {
check_letters_and_expand_ellipsis(in_subscripts[i], &operands[i]);
}
check_letters_and_expand_ellipsis(out_subscript, nullptr);
CharSet out_set(out_subscript.begin(), out_subscript.end());
if (out_set.size() != out_subscript.size()) {
std::ostringstream msg;
msg << "[" << fn_name << "] Repeat indices not allowed in output.";
throw std::invalid_argument(msg.str());
}
Subscript output(out_subscript, std::move(out_set));
std::unordered_map<char, ShapeElem> dim_map;
std::vector<Subscript> inputs;
for (int i = 0; i < in_subscripts.size(); ++i) {
auto& in = in_subscripts[i];
CharSet in_set(in.begin(), in.end());
inputs.emplace_back(in, in_set);
if (in.size() != operands[i].ndim()) {
std::ostringstream msg;
msg << "[" << fn_name << "] Invalid number of subscripts " << in.size()
<< " for input " << i << " with " << operands[i].ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
// Check repeat subscripts are valid
if (in_set.size() < in.size()) {
std::unordered_map<char, ShapeElem> local_dims;
for (int j = 0; j < in.size(); ++j) {
auto dim = operands[i].shape(j);
auto inserted = local_dims.insert({in[j], dim});
if (!inserted.second) {
if (inserted.first->second != dim) {
std::ostringstream msg;
msg << "[" << fn_name << "] Dimensions of repeated subscripts "
<< "do not have the same size (" << inserted.first->second
<< " != " << dim << ").";
throw std::invalid_argument(msg.str());
}
}
}
}
for (int j = 0; j < in.size(); j++) {
auto c = in[j];
auto dim = operands[i].shape(j);
auto inserted = dim_map.insert({c, dim});
auto& in_dim = inserted.first->second;
if (dim != 1 && in_dim != 1 && in_dim != dim) {
std::ostringstream msg;
msg << "[" << fn_name << "] Cannot broadcast dimension " << j
<< " of input " << i << " with shape " << operands[i].shape()
<< " to size " << in_dim << ".";
throw std::invalid_argument(msg.str());
}
// Ensure the broadcasted size is used
in_dim = std::max(in_dim, dim);
}
}
size_t max_size = term_size(out_subscript, dim_map);
for (auto& in : in_subscripts) {
max_size = std::max(max_size, term_size(in, dim_map));
}
PathInfo path_info;
// Get the full naive cost
std::tie(path_info.naive_cost, path_info.naive_scaling) =
compute_cost_and_scaling(inputs, output, dim_map);
// Calculate the path
std::vector<PathNode> path;
if (inputs.size() <= 2) {
std::vector<int> positions(in_subscripts.size());
std::iota(positions.begin(), positions.end(), 0);
path.emplace_back(
std::move(inputs), std::move(output), std::move(positions));
} else {
std::tie(path, path_info.optimized_cost, path_info.optimized_scaling) =
greedy_path(inputs, output, dim_map, path_info.naive_cost, max_size);
// Set the final output subscript to the actual output
path.back().output = std::move(output);
}
return {path, path_info};
}
} // namespace
std::pair<std::vector<std::vector<int>>, std::string> einsum_path(
const std::string& subscripts,
const std::vector<array>& operands) {
auto [path, path_info] =
einsum_path_helper(subscripts, operands, "einsum_path");
std::vector<std::vector<int>> pos_path;
for (auto& p : path) {
pos_path.push_back(p.positions);
}
std::ostringstream path_print;
path_print << " Complete contraction: " << subscripts << "\n"
<< " Naive scaling: " << path_info.naive_scaling << "\n"
<< " Optimized scaling: " << path_info.optimized_scaling
<< "\n"
<< " Naive FLOP count: " << path_info.naive_cost << "\n"
<< " Optimized FLOP count: " << path_info.optimized_cost << "\n";
// TODO add more info here
return {pos_path, path_print.str()};
}
array einsum(
const std::string& subscripts,
const std::vector<array>& operands,
StreamOrDevice s /* = {} */) {
auto [path, path_info] = einsum_path_helper(subscripts, operands, "einsum");
auto inputs = operands;
for (auto& node : path) {
preprocess_einsum_inputs(
node.inputs, node.output, node.positions, inputs, s);
if (can_dot(node.inputs, node.output)) {
auto& in_a = node.inputs[0];
auto& in_b = node.inputs[1];
auto& out = node.output;
std::vector<int> a_contract;
std::vector<int> a_batch;
std::vector<int> a_concat;
for (int i = 0; i < in_a.str.size(); ++i) {
auto c = in_a.str[i];
if (out.set.find(c) == out.set.end()) {
// Not in the output, contraction
a_contract.push_back(i);
} else if (in_b.set.find(c) != in_b.set.end()) {
// Not a contraction but in both inputs, batch dim
a_batch.push_back(i);
} else {
// Not a batch dim or contract dim, so concat dim
a_concat.push_back(i);
}
}
std::vector<int> b_contract;
std::vector<int> b_batch;
std::vector<int> b_concat;
for (auto a_i : a_contract) {
b_contract.push_back(in_b.str.find(in_a.str[a_i]));
}
for (auto a_i : a_batch) {
b_batch.push_back(in_b.str.find(in_a.str[a_i]));
}
for (int i = 0; i < in_b.str.size(); ++i) {
auto c = in_b.str[i];
if (out.set.find(c) != out.set.end() &&
in_a.set.find(c) == in_a.set.end()) {
b_concat.push_back(i);
}
}
auto& a = inputs[node.positions[0]];
auto& b = inputs[node.positions[1]];
std::unordered_map<char, int> char_map;
for (auto i : a_batch) {
char_map.insert({in_a.str[i], char_map.size()});
}
for (auto i : a_concat) {
char_map.insert({in_a.str[i], char_map.size()});
}
for (auto i : b_concat) {
char_map.insert({in_b.str[i], char_map.size()});
}
inputs.emplace_back(batch_tensordot(
a,
b,
std::move(a_contract),
std::move(a_batch),
std::move(a_concat),
std::move(b_contract),
std::move(b_batch),
std::move(b_concat),
s));
std::vector<int> reorder;
for (auto c : node.output.str) {
reorder.push_back(char_map[c]);
}
inputs.back() = transpose(inputs.back(), reorder, s);
} else {
inputs.emplace_back(
einsum_naive(node.inputs, node.output, node.positions, inputs, s));
}
// Positions are always sorted increasing, so start from the back
for (auto it = node.positions.rbegin(); it != node.positions.rend(); ++it) {
inputs.erase(inputs.begin() + *it);
}
}
return inputs.front();
}
} // namespace mlx::core