// Copyright © 2024 Apple Inc. #include #include #include #include #include "mlx/einsum.h" #include "mlx/ops.h" namespace mlx::core { namespace { // The MLX einsum implementation is based on NumPy (which is based on // opt_einsum): // https://github.com/numpy/numpy/blob/1d49c7f7ff527c696fc26ab2278ad51632a66660/numpy/_core/einsumfunc.py#L743 // https://github.com/dgasmith/opt_einsum using CharSet = std::unordered_set; // A helper struct to hold the string and set // representation of a subscript to avoid needing // to recompute the set struct Subscript { Subscript(std::string str, CharSet set) : str(std::move(str)), set(std::move(set)) {}; std::string str; CharSet set; }; struct PathInfo { size_t naive_cost; size_t naive_scaling; size_t optimized_cost; size_t optimized_scaling; size_t largest_term; }; struct PathNode { PathNode( std::vector inputs, Subscript output, std::vector positions) : inputs(std::move(inputs)), output(std::move(output)), positions(std::move(positions)) {}; std::vector inputs; Subscript output; std::vector positions; }; // Parse the comma separated subscripts into a vector of strings. If the // output subscripts are missing they are inferred. // // For example: // "ij,jk -> ik" becomes {{"ij", "jk"}, "ik"} // "ij,jk" becomes {{"ij", "jk"}, "ik"} std::pair, std::string> parse(std::string subscripts) { std::string lhs, rhs; // Start by removing all white space subscripts.erase( std::remove(subscripts.begin(), subscripts.end(), ' '), subscripts.end()); if (auto pos = subscripts.find("->"); pos != std::string::npos) { // Explicit mode lhs = subscripts.substr(0, pos); rhs = subscripts.substr(pos + 2); } else { // Implicit mode: // - repeats are summed // - ellipses are placed in the beginning of the output // - remaining output axes are ordered alphabetically lhs = subscripts; std::unordered_map temp; for (auto& c : subscripts) { if (c == ',') { continue; } if (c == '.' && rhs.empty()) { rhs += "..."; continue; } auto inserted = temp.insert({c, 0}); inserted.first->second++; } for (auto& k : temp) { if (k.second == 1) { rhs += k.first; } } std::sort(rhs.begin(), rhs.end()); } std::vector input_list; std::stringstream ss(lhs); std::string token; while (getline(ss, token, ',')) { input_list.push_back(token); } return {input_list, rhs}; } // Check if two sets are disjoint bool disjoint(const CharSet& x, const CharSet& y) { for (auto& c : x) { if (y.find(c) != y.end()) { return false; } } return true; } template size_t term_size(const T& term, std::unordered_map dict) { size_t size = 1; for (auto c : term) { size *= dict[c]; } return size; } size_t flop_count( const CharSet& term, bool inner, int num_terms, std::unordered_map dict) { size_t size = term_size(term, dict); auto op_factor = 1; if ((num_terms - 1) > op_factor) { op_factor = num_terms - 1; } if (inner) { op_factor += 1; } return size * op_factor; } std::pair compute_cost_and_scaling( const std::vector& inputs, const Subscript& output, std::unordered_map dim_map) { CharSet contractions; for (auto& in : inputs) { contractions.insert(in.set.begin(), in.set.end()); } bool inner = false; for (auto c : contractions) { if (output.set.find(c) == output.set.end()) { inner = true; break; } } auto cost = flop_count(contractions, inner, inputs.size(), dim_map); return {cost, contractions.size()}; } std::tuple, size_t, int> greedy_path( std::vector inputs, const Subscript& output, std::unordered_map dim_map, size_t cost_limit, size_t memory_limit) { // Helper struct for building the greedy path struct Contraction { Contraction( size_t size, size_t cost, CharSet output, int dims, int x, int y) : size(size), cost(cost), output(std::move(output)), dims(dims), x(x), y(y) {}; int64_t size; // Size difference, can be negative size_t cost; CharSet output; int dims; // Number of dimensions in the contraction int x; int y; }; // Start by iterating over all possible combinations std::vector> pos_pairs; for (int i = 0; i < inputs.size(); ++i) { for (int j = i + 1; j < inputs.size(); ++j) { pos_pairs.emplace_back(i, j); } } std::vector path; std::vector possible_contractions; size_t path_cost = 0; int path_scaling = 0; auto num_in = inputs.size(); for (int i = 0; i < num_in - 1; ++i) { auto add_contraction = [&](int p1, int p2) { CharSet new_term; CharSet contractions(inputs[p1].set.begin(), inputs[p1].set.end()); contractions.insert(inputs[p2].set.begin(), inputs[p2].set.end()); for (int i = 0; i < inputs.size(); i++) { if (i == p1 || i == p2) { continue; } auto& in = inputs[i].set; for (auto c : in) { if (contractions.find(c) != contractions.end()) { new_term.insert(c); } } } for (auto c : output.set) { if (contractions.find(c) != contractions.end()) { new_term.insert(c); } } // Ignore if: // - The size of the new result is greater than the memory limit // - The cost is larger than the naive cost auto new_size = term_size(new_term, dim_map); if (new_size > memory_limit) { return; } int64_t removed_size = term_size(inputs[p1].set, dim_map) + term_size(inputs[p2].set, dim_map) - new_size; bool inner = contractions.size() > new_term.size(); auto cost = flop_count(contractions, inner, 2, dim_map); if (path_cost + cost > cost_limit) { return; } possible_contractions.emplace_back( removed_size, cost, std::move(new_term), contractions.size(), p1, p2); }; for (auto& [p1, p2] : pos_pairs) { // Ignore outer products if (!disjoint(inputs[p1].set, inputs[p2].set)) { add_contraction(p1, p2); } } // If there's nothing in the contraction list, // go over the pairs again without ignoring outer products if (possible_contractions.empty()) { for (auto& [p1, p2] : pos_pairs) { add_contraction(p1, p2); } } if (possible_contractions.empty()) { // Default to naive einsum for the remaining inputs std::vector positions(inputs.size()); std::iota(positions.begin(), positions.end(), 0); auto [cost, scale] = compute_cost_and_scaling(inputs, output, dim_map); path.emplace_back(std::move(inputs), output, std::move(positions)); path_cost += cost; path_scaling = std::max(scale, path_scaling); break; } // Find the best contraction auto& best = *std::min_element( possible_contractions.begin(), possible_contractions.end(), [](const auto& x, const auto& y) { return x.size > y.size || (x.size == y.size && x.cost < y.cost); }); path_scaling = std::max(best.dims, path_scaling); // Construct the output subscripts std::string out_str(best.output.begin(), best.output.end()); // TODO, sorting by dimension size seems suboptimal? std::sort(out_str.begin(), out_str.end(), [&dim_map](auto x, auto y) { return dim_map[x] < dim_map[y]; }); Subscript new_output(std::move(out_str), std::move(best.output)); // Add the chosen contraction to the path { std::vector in_terms; in_terms.push_back(std::move(inputs[best.x])); in_terms.push_back(std::move(inputs[best.y])); path.emplace_back( std::move(in_terms), new_output, std::vector{best.x, best.y}); } // Remove used terms inputs.erase(inputs.begin() + best.y); inputs.erase(inputs.begin() + best.x); // Add the new result inputs.push_back(std::move(new_output)); // Update the existing contractions based on the selected one std::vector updated_contractions; for (auto& contraction : possible_contractions) { // Drop contractions which contain either selected term if (contraction.x == best.x || contraction.x == best.y || contraction.y == best.x || contraction.y == best.y) { continue; } // Update the positions of other contractions int x = contraction.x - (contraction.x > best.x) - (contraction.x > best.y); int y = contraction.y - (contraction.y > best.x) - (contraction.y > best.y); contraction.x = x; contraction.y = y; updated_contractions.push_back(std::move(contraction)); } pos_pairs.clear(); for (int i = 0; i < inputs.size() - 1; ++i) { pos_pairs.emplace_back(i, inputs.size() - 1); } path_cost += best.cost; possible_contractions = std::move(updated_contractions); } return {path, path_cost, path_scaling}; } // Assumes inputs have already have had repeats and single axis sums collapsed bool can_dot(const std::vector& inputs, const Subscript& output) { if (inputs.size() != 2) { return false; } for (auto c : inputs[0].set) { // Use batched tensordot if anything is being contracted if (output.set.find(c) == output.set.end()) { return true; } } return false; } array batch_tensordot( array a, array b, std::vector a_contract, std::vector a_batch, std::vector a_concat, std::vector b_contract, std::vector b_batch, std::vector b_concat, StreamOrDevice s) { // Broadcast contracting dimensions { auto a_shape = a.shape(); auto b_shape = b.shape(); for (int i = 0; i < a_contract.size(); ++i) { auto d = std::max(a.shape(a_contract[i]), b.shape(b_contract[i])); a_shape[a_contract[i]] = d; b_shape[b_contract[i]] = d; } a = broadcast_to(a, a_shape, s); b = broadcast_to(b, b_shape, s); } auto transpose_reshape = [&s]( const array& x, const std::vector& i, const std::vector& j, const std::vector& k) { std::vector reorder(i.begin(), i.end()); reorder.insert(reorder.end(), j.begin(), j.end()); reorder.insert(reorder.end(), k.begin(), k.end()); int size1 = 1; for (auto s : j) { size1 *= x.shape(s); } int size2 = 1; for (auto s : k) { size2 *= x.shape(s); } Shape shape; for (auto ax : i) { shape.push_back(x.shape(ax)); } shape.push_back(size1); shape.push_back(size2); return reshape(transpose(x, reorder, s), std::move(shape), s); }; Shape out_shape; for (auto ax : a_batch) { out_shape.push_back(a.shape(ax)); } for (auto ax : a_concat) { out_shape.push_back(a.shape(ax)); } for (auto ax : b_concat) { out_shape.push_back(b.shape(ax)); } a = transpose_reshape(a, a_batch, a_concat, a_contract); b = transpose_reshape(b, b_batch, b_contract, b_concat); return reshape(matmul(a, b, s), std::move(out_shape), s); } // Collapse repeated subscripts and return the resulting array. The subscript // is also updated in place. For example: // - Given an input with shape (4, 4) and subscript "ii", returns // the diagonal of shape (4,) and updates the subscript to "i". // - Given an input with shape (4, 2, 4, 2) and subscript "ijij", // returns an output with shape (4, 2) and updates the subscript // to "ij". array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) { // Build a list of (repeat chars, num repeats) auto& str = subscript.str; std::vector> repeats; std::string new_str; { std::string repeat_str; std::string no_repeat_str; std::unordered_map counts; for (int i = 0; i < str.size(); ++i) { auto [it, _] = counts.insert({str[i], 0}); it->second++; } for (auto& v : counts) { if (v.second > 1) { repeats.emplace_back(v.first, v.second); repeat_str += v.first; } } for (auto& c : str) { if (counts[c] == 1) { no_repeat_str += c; } } new_str = repeat_str + no_repeat_str; } // Build the inputs for gather auto slice_sizes = in.shape(); std::vector axes; std::vector indices; int n_expand = repeats.size(); for (auto [c, v] : repeats) { for (int i = 0; i < str.size(); ++i) { if (str[i] == c) { slice_sizes[i] = 1; axes.push_back(i); } } Shape idx_shape(n_expand--, 1); idx_shape[0] = in.shape(axes.back()); auto idx = reshape( arange(static_cast(in.shape(axes.back())), s), idx_shape, s); for (int i = 0; i < v; ++i) { indices.push_back(idx); } } in = gather(in, indices, axes, slice_sizes, s); // Update subscript string with removed dups str = new_str; // Squeeze singleton dimensions left over from the gather for (auto& ax : axes) { ax += indices[0].ndim(); } return squeeze(in, axes, s); } // Collapse repeat indices and sum single dimensions. // For example: // - "aa" becomes "a" // - "ij,jk->k" becoms "j,jk->k" void preprocess_einsum_inputs( std::vector& inputs, const Subscript& output, const std::vector& positions, std::vector& operands, StreamOrDevice s) { // Collapse repeat indices for (int i = 0; i < inputs.size(); ++i) { auto& in = inputs[i]; if (in.set.size() < in.str.size()) { operands[positions[i]] = collapse_repeats(operands[positions[i]], in, s); } } // Sum indices that are only in a single input { std::unordered_map counts; for (auto& in : inputs) { for (auto c : in.set) { auto inserted = counts.insert({c, 0}); inserted.first->second++; } } for (auto c : output.set) { auto inserted = counts.insert({c, 0}); inserted.first->second++; } for (int i = 0; i < inputs.size(); ++i) { auto& in = inputs[i]; std::vector sum_axes; for (int ax = 0; ax < in.str.size(); ++ax) { if (counts[in.str[ax]] == 1) { sum_axes.push_back(ax); } } if (!sum_axes.empty()) { operands[positions[i]] = sum(operands[positions[i]], sum_axes, false, s); } for (auto it = sum_axes.rbegin(); it != sum_axes.rend(); ++it) { in.set.erase(in.str[*it]); in.str.erase(in.str.begin() + *it); } } } } array einsum_naive( std::vector inputs, const Subscript& output, const std::vector& positions, std::vector operands, StreamOrDevice s) { // Map each character to an axis std::unordered_map char_to_ax; for (auto& in : inputs) { for (auto c : in.str) { char_to_ax.insert({c, char_to_ax.size()}); } } // Expand and transpose inputs as needed for (int i = 0; i < inputs.size(); ++i) { int pos = positions[i]; auto& op = operands[pos]; // Add missing dimensions at the end if (op.ndim() != char_to_ax.size()) { auto shape = op.shape(); shape.insert(shape.end(), char_to_ax.size() - shape.size(), 1); op = reshape(op, std::move(shape), s); } // Transpose: // - Build a vector of (char, ax) pairs for the current input // - Sort the vector by the canonical axis in char_to_ax // - Extract the sorted axis to get transpose order std::vector> str_ax; for (auto c : inputs[i].str) { str_ax.emplace_back(c, str_ax.size()); } for (auto [c, ax] : char_to_ax) { if (inputs[i].set.find(c) == inputs[i].set.end()) { str_ax.emplace_back(c, str_ax.size()); } } std::sort( str_ax.begin(), str_ax.end(), [&char_to_ax](const auto& x, const auto& y) { return char_to_ax[x.first] < char_to_ax[y.first]; }); // Skip the transpose if not needed if (std::is_sorted( str_ax.begin(), str_ax.end(), [](const auto& x, const auto& y) { return x.second < y.second; })) { continue; } std::vector reorder; for (auto [c, ax] : str_ax) { reorder.push_back(ax); } op = transpose(op, reorder, s); } // Multiply and sum auto out = operands[positions[0]]; for (int i = 1; i < positions.size(); ++i) { out = multiply(out, operands[positions[i]], s); } std::vector sum_axes; for (auto [c, ax] : char_to_ax) { if (output.set.find(c) == output.set.end()) { sum_axes.push_back(ax); } } if (!sum_axes.empty()) { out = sum(out, sum_axes, false, s); } // Transpose output if needed std::vector reorder; for (auto c : output.str) { reorder.push_back(char_to_ax[c]); } for (auto& r : reorder) { int offset = 0; for (auto s : sum_axes) { if (r > s) { offset++; } } r -= offset; } return transpose(out, reorder, s); } std::pair, PathInfo> einsum_path_helper( const std::string& subscripts, const std::vector& operands, const std::string& fn_name) { if (operands.size() == 0) { std::ostringstream msg; msg << "[" << fn_name << "] At least one operand is required."; throw std::invalid_argument(msg.str()); } auto [in_subscripts, out_subscript] = parse(subscripts); if (operands.size() != in_subscripts.size()) { std::ostringstream msg; msg << "[" << fn_name << "] Number of operands, " << operands.size() << ", does not match number of input subscripts, " << in_subscripts.size(); throw std::invalid_argument(msg.str()); } // Expand ellipses // 1. Collect all the characters we can use for the missing axes. // 2. Go over each subscript and check if all the characters are either // alphanumeric or an ellipsis. // 3. Expand the ellipsis with as many characters from the unused ones as // necessary. We use the last N characters effectively prepending with // singleton dims for inputs with fewer dimensions. // 4. For the output use the maximum size of ellipsis that we encountered in // the input. CharSet used_chars(subscripts.begin(), subscripts.end()); std::string remaining_chars; remaining_chars.reserve(52 - used_chars.size()); for (char c = 'a'; c <= 'z'; c++) { if (used_chars.find(c) == used_chars.end()) { remaining_chars += c; } } for (char c = 'A'; c <= 'Z'; c++) { if (used_chars.find(c) == used_chars.end()) { remaining_chars += c; } } int max_ellipsis_length = 0; auto check_letters_and_expand_ellipsis = [&](auto& subscript, const array* operand) { bool have_ellipsis = false; int cnt_before = 0, cnt_after = 0; for (int i = 0; i < subscript.size(); i++) { if (!isalpha(subscript[i])) { if (i + 2 >= subscript.size() || subscript[i] != '.' || subscript[i + 1] != '.' || subscript[i + 2] != '.') { std::ostringstream msg; msg << "[" << fn_name << "] Subscripts must be letters, but got '" << subscript[i] << "'."; throw std::invalid_argument(msg.str()); } if (have_ellipsis) { std::ostringstream msg; msg << "[" << fn_name << "] Only one ellipsis per subscript is allowed but found more in '" << subscript << "'."; throw std::invalid_argument(msg.str()); } have_ellipsis = true; i += 2; continue; } if (have_ellipsis) { cnt_after++; } else { cnt_before++; } } if (have_ellipsis) { int ellipsis_length; if (operand != nullptr) { ellipsis_length = operand->ndim() - cnt_before - cnt_after; max_ellipsis_length = std::max(ellipsis_length, max_ellipsis_length); } else { ellipsis_length = max_ellipsis_length; } subscript.replace( subscript.begin() + cnt_before, subscript.begin() + cnt_before + 3, remaining_chars.end() - ellipsis_length, remaining_chars.end()); } }; for (int i = 0; i < operands.size(); i++) { check_letters_and_expand_ellipsis(in_subscripts[i], &operands[i]); } check_letters_and_expand_ellipsis(out_subscript, nullptr); CharSet out_set(out_subscript.begin(), out_subscript.end()); if (out_set.size() != out_subscript.size()) { std::ostringstream msg; msg << "[" << fn_name << "] Repeat indices not allowed in output."; throw std::invalid_argument(msg.str()); } Subscript output(out_subscript, std::move(out_set)); std::unordered_map dim_map; std::vector inputs; for (int i = 0; i < in_subscripts.size(); ++i) { auto& in = in_subscripts[i]; CharSet in_set(in.begin(), in.end()); inputs.emplace_back(in, in_set); if (in.size() != operands[i].ndim()) { std::ostringstream msg; msg << "[" << fn_name << "] Invalid number of subscripts " << in.size() << " for input " << i << " with " << operands[i].ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } // Check repeat subscripts are valid if (in_set.size() < in.size()) { std::unordered_map local_dims; for (int j = 0; j < in.size(); ++j) { auto dim = operands[i].shape(j); auto inserted = local_dims.insert({in[j], dim}); if (!inserted.second) { if (inserted.first->second != dim) { std::ostringstream msg; msg << "[" << fn_name << "] Dimensions of repeated subscripts " << "do not have the same size (" << inserted.first->second << " != " << dim << ")."; throw std::invalid_argument(msg.str()); } } } } for (int j = 0; j < in.size(); j++) { auto c = in[j]; auto dim = operands[i].shape(j); auto inserted = dim_map.insert({c, dim}); auto& in_dim = inserted.first->second; if (dim != 1 && in_dim != 1 && in_dim != dim) { std::ostringstream msg; msg << "[" << fn_name << "] Cannot broadcast dimension " << j << " of input " << i << " with shape " << operands[i].shape() << " to size " << in_dim << "."; throw std::invalid_argument(msg.str()); } // Ensure the broadcasted size is used in_dim = std::max(in_dim, dim); } } size_t max_size = term_size(out_subscript, dim_map); for (auto& in : in_subscripts) { max_size = std::max(max_size, term_size(in, dim_map)); } PathInfo path_info; // Get the full naive cost std::tie(path_info.naive_cost, path_info.naive_scaling) = compute_cost_and_scaling(inputs, output, dim_map); // Calculate the path std::vector path; if (inputs.size() <= 2) { std::vector positions(in_subscripts.size()); std::iota(positions.begin(), positions.end(), 0); path.emplace_back( std::move(inputs), std::move(output), std::move(positions)); } else { std::tie(path, path_info.optimized_cost, path_info.optimized_scaling) = greedy_path(inputs, output, dim_map, path_info.naive_cost, max_size); // Set the final output subscript to the actual output path.back().output = std::move(output); } return {path, path_info}; } } // namespace std::pair>, std::string> einsum_path( const std::string& subscripts, const std::vector& operands) { auto [path, path_info] = einsum_path_helper(subscripts, operands, "einsum_path"); std::vector> pos_path; for (auto& p : path) { pos_path.push_back(p.positions); } std::ostringstream path_print; path_print << " Complete contraction: " << subscripts << "\n" << " Naive scaling: " << path_info.naive_scaling << "\n" << " Optimized scaling: " << path_info.optimized_scaling << "\n" << " Naive FLOP count: " << path_info.naive_cost << "\n" << " Optimized FLOP count: " << path_info.optimized_cost << "\n"; // TODO add more info here return {pos_path, path_print.str()}; } array einsum( const std::string& subscripts, const std::vector& operands, StreamOrDevice s /* = {} */) { auto [path, path_info] = einsum_path_helper(subscripts, operands, "einsum"); auto inputs = operands; for (auto& node : path) { preprocess_einsum_inputs( node.inputs, node.output, node.positions, inputs, s); if (can_dot(node.inputs, node.output)) { auto& in_a = node.inputs[0]; auto& in_b = node.inputs[1]; auto& out = node.output; std::vector a_contract; std::vector a_batch; std::vector a_concat; for (int i = 0; i < in_a.str.size(); ++i) { auto c = in_a.str[i]; if (out.set.find(c) == out.set.end()) { // Not in the output, contraction a_contract.push_back(i); } else if (in_b.set.find(c) != in_b.set.end()) { // Not a contraction but in both inputs, batch dim a_batch.push_back(i); } else { // Not a batch dim or contract dim, so concat dim a_concat.push_back(i); } } std::vector b_contract; std::vector b_batch; std::vector b_concat; for (auto a_i : a_contract) { b_contract.push_back(in_b.str.find(in_a.str[a_i])); } for (auto a_i : a_batch) { b_batch.push_back(in_b.str.find(in_a.str[a_i])); } for (int i = 0; i < in_b.str.size(); ++i) { auto c = in_b.str[i]; if (out.set.find(c) != out.set.end() && in_a.set.find(c) == in_a.set.end()) { b_concat.push_back(i); } } auto& a = inputs[node.positions[0]]; auto& b = inputs[node.positions[1]]; std::unordered_map char_map; for (auto i : a_batch) { char_map.insert({in_a.str[i], char_map.size()}); } for (auto i : a_concat) { char_map.insert({in_a.str[i], char_map.size()}); } for (auto i : b_concat) { char_map.insert({in_b.str[i], char_map.size()}); } inputs.emplace_back(batch_tensordot( a, b, std::move(a_contract), std::move(a_batch), std::move(a_concat), std::move(b_contract), std::move(b_batch), std::move(b_concat), s)); std::vector reorder; for (auto c : node.output.str) { reorder.push_back(char_map[c]); } inputs.back() = transpose(inputs.back(), reorder, s); } else { inputs.emplace_back( einsum_naive(node.inputs, node.output, node.positions, inputs, s)); } // Positions are always sorted increasing, so start from the back for (auto it = node.positions.rbegin(); it != node.positions.rend(); ++it) { inputs.erase(inputs.begin() + *it); } } return inputs.front(); } } // namespace mlx::core