From baf9fa5f42e9c1e362cca799d58c3dce0cdaec88 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 25 Jul 2024 09:36:44 -0700 Subject: [PATCH] Einsum (#1269) * einsum initial * fix comma break * sum axis was wrong * small cleanups * python binding * changed bindings to resemble numpy * remove todo comment * comment changes * add count of operands/inputs * fail fast if operands list is empty * ignore comma if no output * einsum path matching numpy * getting somewhere with path * remove print * it passes the first test * moved einsum tests to seperate file * seperated einsum path * moved einsum naive * remove space from equation * fast fail if no operands passed * update tests and remove printf * small cleanup * some more cleanups * removed python helper file * ack * utilize std for finding min in vector * duplicate def * remove the tuple as it was unreadable * moved einsum_naive back to ops * remaining isn't needed * avoid creating another set * cleanup * greedy path, start of naive einsum * more einsum * fix some bugs * some more fixes, tests pass * benchmark * some simplify * fix einsum and test Co-authored-by: Angelos Katharopoulos * add a bunch more tests and fix a bunch more bugs * some docs nits --------- Co-authored-by: dc-dc-dc Co-authored-by: Angelos Katharopoulos --- ACKNOWLEDGMENTS.md | 2 +- benchmarks/python/einsum_bench.py | 84 +++ docs/src/python/ops.rst | 2 + mlx/CMakeLists.txt | 1 + mlx/einsum.cpp | 859 ++++++++++++++++++++++++++++++ mlx/einsum.h | 22 + mlx/mlx.h | 1 + mlx/random.cpp | 5 +- python/src/ops.cpp | 143 +++-- python/src/random.cpp | 49 +- python/tests/test_einsum.py | 318 +++++++++++ tests/CMakeLists.txt | 1 + tests/einsum_tests.cpp | 76 +++ 13 files changed, 1498 insertions(+), 65 deletions(-) create mode 100644 benchmarks/python/einsum_bench.py create mode 100644 mlx/einsum.cpp create mode 100644 mlx/einsum.h create mode 100644 python/tests/test_einsum.py create mode 100644 tests/einsum_tests.cpp diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index b3183f4f8..265eca97f 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -10,7 +10,7 @@ MLX was developed with contributions from the following individuals: - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. - Juarez Bochi: Fixed bug in cross attention. - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. -- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support. +- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`. - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``. - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops. - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays. diff --git a/benchmarks/python/einsum_bench.py b/benchmarks/python/einsum_bench.py new file mode 100644 index 000000000..a70311aec --- /dev/null +++ b/benchmarks/python/einsum_bench.py @@ -0,0 +1,84 @@ +# Copyright © 2024 Apple Inc. 
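+# Benchmarks comparing mx.einsum_path against np.einsum_path and an
+# einsum-based attention implementation against a matmul-based one.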
+ +import time + +import mlx.core as mx +import numpy as np + + +def timeit(fn, its=100, args=[]): + for _ in range(5): + fn(*args) + tic = time.perf_counter() + for _ in range(its): + fn(*args) + toc = time.perf_counter() + return 1e3 * (toc - tic) / its + + +def time_little_einsum_path(): + subscripts = "ik,kj->ij" + x = mx.ones((32, 32)) + y = mx.ones((32, 32)) + mx_time = timeit(mx.einsum_path, args=(subscripts, x, y)) + + x = np.array(x) + y = np.array(y) + np_time = timeit(np.einsum_path, args=(subscripts, x, y)) + print("Timing little einsum path...") + print(f"MLX ... {mx_time:.3f} ms") + print(f"NumPy... {np_time:.3f} ms") + + +def time_big_einsum_path(): + chars = list("abcdefgh") + char_to_dim = {c: v for v, c in enumerate(chars)} + + num_inputs = 10 + inputs = [] + subscripts = [] + for _ in range(num_inputs): + subscript = np.random.choice(chars, size=5, replace=False).tolist() + subscripts.append("".join(subscript)) + inputs.append(np.ones(list(char_to_dim[c] for c in subscript))) + subscripts = ",".join(subscripts) + + np_time = timeit(np.einsum_path, args=(subscripts, *inputs)) + + inputs = [mx.array(x) for x in inputs] + mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs)) + print("Timing big einsum path...") + print(f"MLX ... {mx_time:.3f} ms") + print(f"NumPy... {np_time:.3f} ms") + + +def time_attention(): + def regular_attention(x): + # shape [batch, sequence, num_heads, head_dim] + queries, keys, values = x, x, x + scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1) + scores = mx.softmax(scores, axis=-1) + output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2) + mx.eval(output) + + def einsum_attention(x): + # shape [batch, sequence, num_heads, head_dim] + queries, keys, values = x, x, x + scores = mx.einsum("itjk,iujk->ijtu", queries, keys) + scores = mx.softmax(scores, axis=-1) + output = mx.einsum("ijtu,iujk->itjk", scores, values) + mx.eval(output) + + x = mx.random.uniform(shape=(8, 512, 32, 128)) + + regular_time = timeit(regular_attention, args=(x,)) + ein_time = timeit(einsum_attention, args=(x,)) + print("Timing einsum attention...") + print(f"Regular ... {regular_time:.3f} ms") + print(f"Einsum ... {ein_time:.3f} ms") + + +if __name__ == "__main__": + time_little_einsum_path() + time_big_einsum_path() + time_attention() diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 0d75f7d62..6109a3e5f 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -57,6 +57,8 @@ Operations diagonal divide divmod + einsum + einsum_path equal erf erfinv diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 14c24896d..f62772571 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -6,6 +6,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp diff --git a/mlx/einsum.cpp b/mlx/einsum.cpp new file mode 100644 index 000000000..f809cc1e0 --- /dev/null +++ b/mlx/einsum.cpp @@ -0,0 +1,859 @@ +// Copyright © 2024 Apple Inc. 
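+// Implements einsum_path (contraction-order search) and einsum (contraction
+// execution via a naive multiply-sum or a batched tensordot).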
+#include +#include +#include +#include + +#include "mlx/einsum.h" +#include "mlx/ops.h" + +namespace mlx::core { + +namespace { + +// The MLX einsum implementation is based on NumPy (which is based on +// opt_einsum): +// https://github.com/numpy/numpy/blob/1d49c7f7ff527c696fc26ab2278ad51632a66660/numpy/_core/einsumfunc.py#L743 +// https://github.com/dgasmith/opt_einsum + +using CharSet = std::unordered_set; + +// A helper struct to hold the string and set +// representation of a subscript to avoid needing +// to recompute the set +struct Subscript { + Subscript(std::string str, CharSet set) + : str(std::move(str)), set(std::move(set)) {}; + std::string str; + CharSet set; +}; + +struct PathInfo { + size_t naive_cost; + size_t naive_scaling; + size_t optimized_cost; + size_t optimized_scaling; + size_t largest_term; +}; + +struct PathNode { + PathNode( + std::vector inputs, + Subscript output, + std::vector positions) + : inputs(std::move(inputs)), + output(std::move(output)), + positions(std::move(positions)) {}; + + std::vector inputs; + Subscript output; + + std::vector positions; +}; + +// Parse the comma separated subscripts into a vector of strings. If the +// output subscripts are missing they are inferred. +// +// For example: +// "ij,jk -> ik" becomes {{"ij", "jk"}, "ik"} +// "ij,jk" becomes {{"ij", "jk"}, "ik"} +std::pair, std::string> parse(std::string subscripts) { + std::string lhs, rhs; + + // Start by removing all white space + subscripts.erase( + std::remove(subscripts.begin(), subscripts.end(), ' '), subscripts.end()); + + if (auto pos = subscripts.find("->"); pos != std::string::npos) { + // Explicit mode + lhs = subscripts.substr(0, pos); + rhs = subscripts.substr(pos + 2); + } else { + // Implicit mode: + // - repeats are summed + // - remaining output axes are ordered alphabetically + lhs = subscripts; + std::unordered_map temp; + for (auto& c : subscripts) { + if (c == ',') { + continue; + } + auto inserted = temp.insert({c, 0}); + inserted.first->second++; + } + for (auto& k : temp) { + if (k.second == 1) { + rhs += k.first; + } + } + std::sort(rhs.begin(), rhs.end()); + } + std::vector input_list; + std::stringstream ss(lhs); + std::string token; + while (getline(ss, token, ',')) { + input_list.push_back(token); + } + return {input_list, rhs}; +} + +// Check if two sets are disjoint +bool disjoint(const CharSet& x, const CharSet& y) { + for (auto& c : x) { + if (y.find(c) != y.end()) { + return false; + } + } + return true; +} + +template +size_t term_size(const T& term, std::unordered_map dict) { + size_t size = 1; + for (auto c : term) { + size *= dict[c]; + } + return size; +} + +size_t flop_count( + const CharSet& term, + bool inner, + int num_terms, + std::unordered_map dict) { + size_t size = term_size(term, dict); + auto op_factor = 1; + if ((num_terms - 1) > op_factor) { + op_factor = num_terms - 1; + } + if (inner) { + op_factor += 1; + } + return size * op_factor; +} + +std::pair compute_cost_and_scaling( + const std::vector& inputs, + const Subscript& output, + std::unordered_map dim_map) { + CharSet contractions; + for (auto& in : inputs) { + contractions.insert(in.set.begin(), in.set.end()); + } + + bool inner = false; + for (auto c : contractions) { + if (output.set.find(c) == output.set.end()) { + inner = true; + break; + } + } + auto cost = flop_count(contractions, inner, inputs.size(), dim_map); + return {cost, contractions.size()}; +} + +std::tuple, size_t, int> greedy_path( + std::vector inputs, + const Subscript& output, + std::unordered_map 
dim_map, + size_t cost_limit, + size_t memory_limit) { + // Helper struct for building the greedy path + struct Contraction { + Contraction( + size_t size, + size_t cost, + CharSet output, + int dims, + int x, + int y) + : size(size), + cost(cost), + output(std::move(output)), + dims(dims), + x(x), + y(y) {}; + + int64_t size; // Size difference, can be negative + size_t cost; + CharSet output; + int dims; // Number of dimensions in the contraction + int x; + int y; + }; + + // Start by iterating over all possible combinations + std::vector> pos_pairs; + for (int i = 0; i < inputs.size(); ++i) { + for (int j = i + 1; j < inputs.size(); ++j) { + pos_pairs.emplace_back(i, j); + } + } + + std::vector path; + std::vector possible_contractions; + size_t path_cost = 0; + int path_scaling = 0; + auto num_in = inputs.size(); + for (int i = 0; i < num_in - 1; ++i) { + auto add_contraction = [&](int p1, int p2) { + CharSet new_term; + CharSet contractions(inputs[p1].set.begin(), inputs[p1].set.end()); + contractions.insert(inputs[p2].set.begin(), inputs[p2].set.end()); + for (int i = 0; i < inputs.size(); i++) { + if (i == p1 || i == p2) { + continue; + } + auto& in = inputs[i].set; + for (auto c : in) { + if (contractions.find(c) != contractions.end()) { + new_term.insert(c); + } + } + } + for (auto c : output.set) { + if (contractions.find(c) != contractions.end()) { + new_term.insert(c); + } + } + + // Ignore if: + // - The size of the new result is greater than the memory limit + // - The cost is larger than the naive cost + auto new_size = term_size(new_term, dim_map); + if (new_size > memory_limit) { + return; + } + int64_t removed_size = term_size(inputs[p1].set, dim_map) + + term_size(inputs[p2].set, dim_map) - new_size; + + bool inner = contractions.size() > new_term.size(); + auto cost = flop_count(contractions, inner, 2, dim_map); + if (path_cost + cost > cost_limit) { + return; + } + possible_contractions.emplace_back( + removed_size, cost, std::move(new_term), contractions.size(), p1, p2); + }; + + for (auto& [p1, p2] : pos_pairs) { + // Ignore outer products + if (!disjoint(inputs[p1].set, inputs[p2].set)) { + add_contraction(p1, p2); + } + } + + // If there's nothing in the contraction list, + // go over the pairs again without ignoring outer products + if (possible_contractions.empty()) { + for (auto& [p1, p2] : pos_pairs) { + add_contraction(p1, p2); + } + } + + if (possible_contractions.empty()) { + // Default to naive einsum for the remaining inputs + std::vector positions(inputs.size()); + std::iota(positions.begin(), positions.end(), 0); + auto [cost, scale] = compute_cost_and_scaling(inputs, output, dim_map); + path.emplace_back(std::move(inputs), output, std::move(positions)); + + path_cost += cost; + path_scaling = std::max(scale, path_scaling); + break; + } + + // Find the best contraction + auto& best = *std::min_element( + possible_contractions.begin(), + possible_contractions.end(), + [](const auto& x, const auto& y) { + return x.size > y.size || (x.size == y.size && x.cost < y.cost); + }); + path_scaling = std::max(best.dims, path_scaling); + + // Construct the output subscripts + std::string out_str(best.output.begin(), best.output.end()); + // TODO, sorting by dimension size seems suboptimal? 
+ std::sort(out_str.begin(), out_str.end(), [&dim_map](auto x, auto y) { + return dim_map[x] < dim_map[y]; + }); + Subscript new_output(std::move(out_str), std::move(best.output)); + + // Add the chosen contraction to the path + { + std::vector in_terms; + in_terms.push_back(std::move(inputs[best.x])); + in_terms.push_back(std::move(inputs[best.y])); + path.emplace_back( + std::move(in_terms), new_output, std::vector{best.x, best.y}); + } + // Remove used terms + inputs.erase(inputs.begin() + best.y); + inputs.erase(inputs.begin() + best.x); + + // Add the new result + inputs.push_back(std::move(new_output)); + + // Update the existing contractions based on the selected one + std::vector updated_contractions; + for (auto& contraction : possible_contractions) { + // Drop contractions which contain either selected term + if (contraction.x == best.x || contraction.x == best.y || + contraction.y == best.x || contraction.y == best.y) { + continue; + } + + // Update the positions of other contractions + int x = + contraction.x - (contraction.x > best.x) - (contraction.x > best.y); + int y = + contraction.y - (contraction.y > best.x) - (contraction.y > best.y); + contraction.x = x; + contraction.y = y; + updated_contractions.push_back(std::move(contraction)); + } + + pos_pairs.clear(); + for (int i = 0; i < inputs.size() - 1; ++i) { + pos_pairs.emplace_back(i, inputs.size() - 1); + } + path_cost += best.cost; + + possible_contractions = std::move(updated_contractions); + } + return {path, path_cost, path_scaling}; +} + +// Assumes inputs have already have had repeats and single axis sums collapsed +bool can_dot(const std::vector& inputs, const Subscript& output) { + if (inputs.size() != 2) { + return false; + } + + for (auto c : inputs[0].set) { + // Use batched tensordot if anything is being contracted + if (output.set.find(c) == output.set.end()) { + return true; + } + } + return false; +} + +array batch_tensordot( + array a, + array b, + std::vector a_contract, + std::vector a_batch, + std::vector a_concat, + std::vector b_contract, + std::vector b_batch, + std::vector b_concat, + StreamOrDevice s) { + // Broadcast contracting dimensions + { + auto a_shape = a.shape(); + auto b_shape = b.shape(); + for (int i = 0; i < a_contract.size(); ++i) { + auto d = std::max(a.shape(a_contract[i]), b.shape(b_contract[i])); + a_shape[a_contract[i]] = d; + b_shape[b_contract[i]] = d; + } + a = broadcast_to(a, a_shape, s); + b = broadcast_to(b, b_shape, s); + } + auto transpose_reshape = [&s]( + const array& x, + const std::vector& i, + const std::vector& j, + const std::vector& k) { + std::vector reorder(i.begin(), i.end()); + reorder.insert(reorder.end(), j.begin(), j.end()); + reorder.insert(reorder.end(), k.begin(), k.end()); + + int size1 = 1; + for (auto s : j) { + size1 *= x.shape(s); + } + + int size2 = 1; + for (auto s : k) { + size2 *= x.shape(s); + } + + std::vector shape; + for (auto ax : i) { + shape.push_back(x.shape(ax)); + } + shape.push_back(size1); + shape.push_back(size2); + + return reshape(transpose(x, reorder, s), std::move(shape), s); + }; + + std::vector out_shape; + for (auto ax : a_batch) { + out_shape.push_back(a.shape(ax)); + } + for (auto ax : a_concat) { + out_shape.push_back(a.shape(ax)); + } + for (auto ax : b_concat) { + out_shape.push_back(b.shape(ax)); + } + + a = transpose_reshape(a, a_batch, a_concat, a_contract); + b = transpose_reshape(b, b_batch, b_contract, b_concat); + + return reshape(matmul(a, b, s), std::move(out_shape), s); +} + +// Collapse repeated subscripts 
and return the resulting array. The subscript +// is also updated in place. For example: +// - Given an input with shape (4, 4) and subscript "ii", returns +// the diagonal of shape (4,) and updates the subscript to "i". +// - Given an input with shape (4, 2, 4, 2) and subscript "ijij", +// returns an output with shape (4, 2) and updates the subscript +// to "ij". +array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) { + // Build a list of (repeat chars, num repeats) + auto& str = subscript.str; + std::vector> repeats; + std::string new_str; + { + std::string repeat_str; + std::string no_repeat_str; + std::unordered_map counts; + for (int i = 0; i < str.size(); ++i) { + auto [it, _] = counts.insert({str[i], 0}); + it->second++; + } + + for (auto& v : counts) { + if (v.second > 1) { + repeats.emplace_back(v.first, v.second); + repeat_str += v.first; + } + } + for (auto& c : str) { + if (counts[c] == 1) { + no_repeat_str += c; + } + } + new_str = repeat_str + no_repeat_str; + } + + // Build the inputs for gather + auto slice_sizes = in.shape(); + std::vector axes; + std::vector indices; + int n_expand = repeats.size(); + for (auto [c, v] : repeats) { + for (int i = 0; i < str.size(); ++i) { + if (str[i] == c) { + slice_sizes[i] = 1; + axes.push_back(i); + } + } + std::vector idx_shape(n_expand--, 1); + idx_shape[0] = in.shape(axes.back()); + auto idx = reshape(arange(in.shape(axes.back()), s), idx_shape, s); + for (int i = 0; i < v; ++i) { + indices.push_back(idx); + } + } + + in = gather(in, indices, axes, slice_sizes, s); + + // Update subscript string with removed dups + str = new_str; + + // Squeeze singleton dimensions left over from the gather + for (auto& ax : axes) { + ax += indices[0].ndim(); + } + + return squeeze(in, axes, s); +} + +// Collapse repeat indices and sum single dimensions. 
+// For example: +// - "aa" becomes "a" +// - "ij,jk->k" becoms "j,jk->k" +void preprocess_einsum_inputs( + std::vector& inputs, + const Subscript& output, + const std::vector& positions, + std::vector& operands, + StreamOrDevice s) { + // Collapse repeat indices + for (int i = 0; i < inputs.size(); ++i) { + auto& in = inputs[i]; + if (in.set.size() < in.str.size()) { + operands[positions[i]] = collapse_repeats(operands[positions[i]], in, s); + } + } + + // Sum indices that are only in a single input + { + std::unordered_map counts; + for (auto& in : inputs) { + for (auto c : in.set) { + auto inserted = counts.insert({c, 0}); + inserted.first->second++; + } + } + for (auto c : output.set) { + auto inserted = counts.insert({c, 0}); + inserted.first->second++; + } + for (int i = 0; i < inputs.size(); ++i) { + auto& in = inputs[i]; + std::vector sum_axes; + for (int ax = 0; ax < in.str.size(); ++ax) { + if (counts[in.str[ax]] == 1) { + sum_axes.push_back(ax); + } + } + if (!sum_axes.empty()) { + operands[positions[i]] = + sum(operands[positions[i]], sum_axes, false, s); + } + for (auto it = sum_axes.rbegin(); it != sum_axes.rend(); ++it) { + in.set.erase(in.str[*it]); + in.str.erase(in.str.begin() + *it); + } + } + } +} + +array einsum_naive( + std::vector inputs, + const Subscript& output, + const std::vector& positions, + std::vector operands, + StreamOrDevice s) { + // Map each character to an axis + std::unordered_map char_to_ax; + for (auto& in : inputs) { + for (auto c : in.str) { + char_to_ax.insert({c, char_to_ax.size()}); + } + } + + // Expand and transpose inputs as needed + for (int i = 0; i < inputs.size(); ++i) { + int pos = positions[i]; + auto& op = operands[pos]; + + // Add missing dimensions at the end + if (op.ndim() != char_to_ax.size()) { + auto shape = op.shape(); + shape.insert(shape.end(), char_to_ax.size() - shape.size(), 1); + op = reshape(op, std::move(shape), s); + } + + // Transpose: + // - Build a vector of (char, ax) pairs for the current input + // - Sort the vector by the canonical axis in char_to_ax + // - Extract the sorted axis to get transpose order + std::vector> str_ax; + for (auto c : inputs[i].str) { + str_ax.emplace_back(c, str_ax.size()); + } + for (auto [c, ax] : char_to_ax) { + if (inputs[i].set.find(c) == inputs[i].set.end()) { + str_ax.emplace_back(c, str_ax.size()); + } + } + std::sort( + str_ax.begin(), + str_ax.end(), + [&char_to_ax](const auto& x, const auto& y) { + return char_to_ax[x.first] < char_to_ax[y.first]; + }); + + // Skip the transpose if not needed + if (std::is_sorted( + str_ax.begin(), str_ax.end(), [](const auto& x, const auto& y) { + return x.second < y.second; + })) { + continue; + } + + std::vector reorder; + for (auto [c, ax] : str_ax) { + reorder.push_back(ax); + } + op = transpose(op, reorder, s); + } + + // Multiply and sum + auto out = operands[positions[0]]; + for (int i = 1; i < positions.size(); ++i) { + out = multiply(out, operands[positions[i]], s); + } + std::vector sum_axes; + for (auto [c, ax] : char_to_ax) { + if (output.set.find(c) == output.set.end()) { + sum_axes.push_back(ax); + } + } + if (!sum_axes.empty()) { + out = sum(out, sum_axes, false, s); + } + + // Transpose output if needed + std::vector reorder; + for (auto c : output.str) { + reorder.push_back(char_to_ax[c]); + } + for (auto& r : reorder) { + int offset = 0; + for (auto s : sum_axes) { + if (r > s) { + offset++; + } + } + r -= offset; + } + return transpose(out, reorder, s); +} + +std::pair, PathInfo> einsum_path_helper( + const std::string& 
subscripts, + const std::vector& operands, + const std::string& fn_name) { + if (operands.size() == 0) { + std::ostringstream msg; + msg << "[" << fn_name << "] At least one operand is required."; + throw std::invalid_argument(msg.str()); + } + + auto [in_subscripts, out_subscript] = parse(subscripts); + + if (operands.size() != in_subscripts.size()) { + std::ostringstream msg; + msg << "[" << fn_name << "] Number of operands, " << operands.size() + << ", does not match number of input subscripts, " + << in_subscripts.size(); + throw std::invalid_argument(msg.str()); + } + + auto check_letters = [&](const auto& subscript) { + for (auto c : subscript) { + if (!isalpha(c)) { + std::ostringstream msg; + msg << "[" << fn_name << "] Subscripts must be letters, but got '" << c + << "'."; + throw std::invalid_argument(msg.str()); + } + } + }; + for (auto& in : in_subscripts) { + check_letters(in); + } + check_letters(out_subscript); + + CharSet out_set(out_subscript.begin(), out_subscript.end()); + if (out_set.size() != out_subscript.size()) { + std::ostringstream msg; + msg << "[" << fn_name << "] Repeat indices not allowed in output."; + throw std::invalid_argument(msg.str()); + } + Subscript output(out_subscript, std::move(out_set)); + + std::unordered_map dim_map; + std::vector inputs; + for (int i = 0; i < in_subscripts.size(); ++i) { + auto& in = in_subscripts[i]; + CharSet in_set(in.begin(), in.end()); + inputs.emplace_back(in, in_set); + + if (in.size() != operands[i].ndim()) { + std::ostringstream msg; + msg << "[" << fn_name << "] Invalid number of subscripts " << in.size() + << " for input " << i << " with " << operands[i].ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + // Check repeat subscripts are valid + if (in_set.size() < in.size()) { + std::unordered_map local_dims; + for (int j = 0; j < in.size(); ++j) { + auto dim = operands[i].shape(j); + auto inserted = local_dims.insert({in[j], dim}); + if (!inserted.second) { + if (inserted.first->second != dim) { + std::ostringstream msg; + msg << "[" << fn_name << "] Dimensions of repeated subscripts " + << "do not have the same size (" << inserted.first->second + << " != " << dim << ")."; + throw std::invalid_argument(msg.str()); + } + } + } + } + + for (int j = 0; j < in.size(); j++) { + auto c = in[j]; + auto dim = operands[i].shape(j); + auto inserted = dim_map.insert({c, dim}); + auto& in_dim = inserted.first->second; + if (dim != 1 && in_dim != 1 && in_dim != dim) { + std::ostringstream msg; + msg << "[" << fn_name << "] Cannot broadcast dimension " << j + << " of input " << i << " with shape " << operands[i].shape() + << " to size " << in_dim << "."; + throw std::invalid_argument(msg.str()); + } + // Ensure the broadcasted size is used + in_dim = std::max(in_dim, dim); + } + } + + size_t max_size = term_size(out_subscript, dim_map); + for (auto& in : in_subscripts) { + max_size = std::max(max_size, term_size(in, dim_map)); + } + + PathInfo path_info; + + // Get the full naive cost + std::tie(path_info.naive_cost, path_info.naive_scaling) = + compute_cost_and_scaling(inputs, output, dim_map); + + // Calculate the path + std::vector path; + if (inputs.size() <= 2) { + std::vector positions(in_subscripts.size()); + std::iota(positions.begin(), positions.end(), 0); + path.emplace_back( + std::move(inputs), std::move(output), std::move(positions)); + } else { + std::tie(path, path_info.optimized_cost, path_info.optimized_scaling) = + greedy_path(inputs, output, dim_map, path_info.naive_cost, max_size); + 
// Set the final output subscript to the actual output + path.back().output = std::move(output); + } + return {path, path_info}; +} + +} // namespace + +std::pair>, std::string> einsum_path( + const std::string& subscripts, + const std::vector& operands) { + auto [path, path_info] = + einsum_path_helper(subscripts, operands, "einsum_path"); + + std::vector> pos_path; + for (auto& p : path) { + pos_path.push_back(p.positions); + } + + std::ostringstream path_print; + path_print << " Complete contraction: " << subscripts << "\n" + << " Naive scaling: " << path_info.naive_scaling << "\n" + << " Optimized scaling: " << path_info.optimized_scaling + << "\n" + << " Naive FLOP count: " << path_info.naive_cost << "\n" + << " Optimized FLOP count: " << path_info.optimized_cost << "\n"; + // TODO add more info here + return {pos_path, path_print.str()}; +} + +array einsum( + const std::string& subscripts, + const std::vector& operands, + StreamOrDevice s /* = {} */) { + auto [path, path_info] = einsum_path_helper(subscripts, operands, "einsum"); + auto inputs = operands; + for (auto& node : path) { + preprocess_einsum_inputs( + node.inputs, node.output, node.positions, inputs, s); + + if (can_dot(node.inputs, node.output)) { + auto& in_a = node.inputs[0]; + auto& in_b = node.inputs[1]; + auto& out = node.output; + + std::vector a_contract; + std::vector a_batch; + std::vector a_concat; + for (int i = 0; i < in_a.str.size(); ++i) { + auto c = in_a.str[i]; + if (out.set.find(c) == out.set.end()) { + // Not in the output, contraction + a_contract.push_back(i); + } else if (in_b.set.find(c) != in_b.set.end()) { + // Not a contraction but in both inputs, batch dim + a_batch.push_back(i); + } else { + // Not a batch dim or contract dim, so concat dim + a_concat.push_back(i); + } + } + + std::vector b_contract; + std::vector b_batch; + std::vector b_concat; + for (auto a_i : a_contract) { + b_contract.push_back(in_b.str.find(in_a.str[a_i])); + } + for (auto a_i : a_batch) { + b_batch.push_back(in_b.str.find(in_a.str[a_i])); + } + for (int i = 0; i < in_b.str.size(); ++i) { + auto c = in_b.str[i]; + if (out.set.find(c) != out.set.end() && + in_a.set.find(c) == in_a.set.end()) { + b_concat.push_back(i); + } + } + + auto& a = inputs[node.positions[0]]; + auto& b = inputs[node.positions[1]]; + + std::unordered_map char_map; + for (auto i : a_batch) { + char_map.insert({in_a.str[i], char_map.size()}); + } + for (auto i : a_concat) { + char_map.insert({in_a.str[i], char_map.size()}); + } + for (auto i : b_concat) { + char_map.insert({in_b.str[i], char_map.size()}); + } + inputs.emplace_back(batch_tensordot( + a, + b, + std::move(a_contract), + std::move(a_batch), + std::move(a_concat), + std::move(b_contract), + std::move(b_batch), + std::move(b_concat), + s)); + + std::vector reorder; + for (auto c : node.output.str) { + reorder.push_back(char_map[c]); + } + inputs.back() = transpose(inputs.back(), reorder, s); + + } else { + inputs.emplace_back( + einsum_naive(node.inputs, node.output, node.positions, inputs, s)); + } + + // Positions are always sorted increasing, so start from the back + for (auto it = node.positions.rbegin(); it != node.positions.rend(); ++it) { + inputs.erase(inputs.begin() + *it); + } + } + return inputs.front(); +} + +} // namespace mlx::core diff --git a/mlx/einsum.h b/mlx/einsum.h new file mode 100644 index 000000000..f57e9a77c --- /dev/null +++ b/mlx/einsum.h @@ -0,0 +1,22 @@ +// Copyright © 2024 Apple Inc. 
+#pragma once + +#include +#include +#include + +#include "mlx/array.h" +#include "mlx/utils.h" + +namespace mlx::core { + +std::pair>, std::string> einsum_path( + const std::string& subscripts, + const std::vector& operands); + +array einsum( + const std::string& subscripts, + const std::vector& operands, + StreamOrDevice s = {}); + +} // namespace mlx::core diff --git a/mlx/mlx.h b/mlx/mlx.h index d8fe150ed..448319d3f 100644 --- a/mlx/mlx.h +++ b/mlx/mlx.h @@ -8,6 +8,7 @@ #include "mlx/device.h" #include "mlx/distributed/distributed.h" #include "mlx/distributed/ops.h" +#include "mlx/einsum.h" #include "mlx/fast.h" #include "mlx/fft.h" #include "mlx/io.h" diff --git a/mlx/random.cpp b/mlx/random.cpp index 590ca375e..8dae8964b 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -444,8 +444,9 @@ array laplace( auto samples = uniform(low, high, shape, dtype, key, stream); // Use inverse CDF to generate Laplacian noise samples = multiply( - sign(samples), - log1p(multiply(array(-1.0f, dtype), abs(samples))), + sign(samples, stream), + log1p( + multiply(array(-1.0f, dtype), abs(samples, stream), stream), stream), stream); if (scale != 1.0) { diff --git a/python/src/ops.cpp b/python/src/ops.cpp index f2b10a5dd..be396f308 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -12,6 +12,7 @@ #include #include +#include "mlx/einsum.h" #include "mlx/ops.h" #include "mlx/utils.h" #include "python/src/load.h" @@ -40,15 +41,6 @@ double scalar_to_double(Scalar s) { } void init_ops(nb::module_& m) { - // TODO, remove deprecation errors in a future release - m.def("block_sparse_mm", [](nb::args, nb::kwargs) { - throw std::invalid_argument( - "block_sparse_mm is deprecated. Please use gather_mm which has the same signature"); - }); - m.def("block_sparse_qmm", [](nb::args, nb::kwargs) { - throw std::invalid_argument( - "block_sparse_qmm is deprecated. Please use gather_qmm which has the same signature"); - }); m.def( "reshape", &reshape, @@ -1238,7 +1230,8 @@ void init_ops(nb::module_& m) { a (array): Input array. Returns: - array: The unchanged input ``a`` but without gradient flowing + array: + The unchanged input ``a`` but without gradient flowing through it. )pbdoc"); m.def( @@ -2936,6 +2929,9 @@ void init_ops(nb::module_& m) { reverse (bool): Perform the cumulative sum in reverse. inclusive (bool): The i-th element of the output includes the i-th element of the input. + + Returns: + array: The output array. )pbdoc"); m.def( "cumprod", @@ -2969,6 +2965,9 @@ void init_ops(nb::module_& m) { reverse (bool): Perform the cumulative product in reverse. inclusive (bool): The i-th element of the output includes the i-th element of the input. + + Returns: + array: The output array. )pbdoc"); m.def( "cummax", @@ -3002,6 +3001,9 @@ void init_ops(nb::module_& m) { reverse (bool): Perform the cumulative maximum in reverse. inclusive (bool): The i-th element of the output includes the i-th element of the input. + + Returns: + array: The output array. )pbdoc"); m.def( "cummin", @@ -3035,6 +3037,9 @@ void init_ops(nb::module_& m) { reverse (bool): Perform the cumulative minimum in reverse. inclusive (bool): The i-th element of the output includes the i-th element of the input. + + Returns: + array: The output array. )pbdoc"); m.def( "conj", @@ -3052,6 +3057,9 @@ void init_ops(nb::module_& m) { Args: a (array): Input array + + Returns: + array: The output array. )pbdoc"); m.def( "conjugate", @@ -3069,6 +3077,9 @@ void init_ops(nb::module_& m) { Args: a (array): Input array + + Returns: + array: The output array. 
)pbdoc"); m.def( "convolve", @@ -3492,14 +3503,11 @@ void init_ops(nb::module_& m) { Args: file (file, str): File in which the array is saved. format (str, optional): Format of the file. If ``None``, the - format - is inferred from the file extension. Supported formats: - ``npy``, - ``npz``, and ``safetensors``. Default: ``None``. + format is inferred from the file extension. Supported formats: + ``npy``, ``npz``, and ``safetensors``. Default: ``None``. return_metadata (bool, optional): Load the metadata for formats - which - support matadata. The metadata will be returned as an - additional dictionary. + which support matadata. The metadata will be returned as an + additional dictionary. Default: ``False``. Returns: array or dict: A single array if loading from a ``.npy`` file or a dict @@ -3551,9 +3559,9 @@ void init_ops(nb::module_& m) { Args: file (file, str): File in which the array is saved. arrays (dict(str, array)): The dictionary of names to arrays to - be saved. metadata (dict(str, Union[array, str, list(str)])): - The dictionary of - metadata to be saved. The values can be a scalar or 1D + be saved. + metadata (dict(str, Union[array, str, list(str)])): The dictionary + of metadata to be saved. The values can be a scalar or 1D obj:`array`, a :obj:`str`, or a :obj:`list` of :obj:`str`. )pbdoc"); m.def( @@ -3643,11 +3651,11 @@ void init_ops(nb::module_& m) { biases (array): The biases to use per ``group_size`` elements of ``w`` transpose (bool, optional): Defines whether to multiply with the transposed ``w`` or not, namely whether we are performing - ``x @ w.T`` or ``x @ w``. (default: ``True``) + ``x @ w.T`` or ``x @ w``. Default: ``True``. group_size (int, optional): The size of the group in ``w`` that - shares a scale and bias. (default: ``64``) + shares a scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element in - ``w``. (default: ``4``) + ``w``. Default: ``4``. Returns: array: The result of the multiplication of ``x`` with ``w``. @@ -3700,9 +3708,9 @@ void init_ops(nb::module_& m) { Args: w (array): Matrix to be quantized group_size (int, optional): The size of the group in ``w`` that shares a - scale and bias. (default: ``64``) + scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element of - ``w`` in the returned quantized matrix. (default: ``4``) + ``w`` in the returned quantized matrix. Default: ``4``. Returns: tuple: A tuple containing @@ -3740,9 +3748,9 @@ void init_ops(nb::module_& m) { scales (array): The scales to use per ``group_size`` elements of ``w`` biases (array): The biases to use per ``group_size`` elements of ``w`` group_size (int, optional): The size of the group in ``w`` that shares a - scale and bias. (default: ``64``) + scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element in - ``w``. (default: ``4``) + ``w``. Default: ``4``. Returns: array: The dequantized version of ``w`` @@ -3779,15 +3787,15 @@ void init_ops(nb::module_& m) { w (array): Quantized matrix packed in unsigned integers scales (array): The scales to use per ``group_size`` elements of ``w`` biases (array): The biases to use per ``group_size`` elements of ``w`` - lhs_indices (array, optional): Integer indices for ``x`` (default: ``None``) - rhs_indices (array, optional): Integer indices for ``w`` (default: ``None``) + lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``. + rhs_indices (array, optional): Integer indices for ``w``. 
Default: ``None``. transpose (bool, optional): Defines whether to multiply with the transposed ``w`` or not, namely whether we are performing - ``x @ w.T`` or ``x @ w``. (default: ``True``) + ``x @ w.T`` or ``x @ w``. Default: ``True``. group_size (int, optional): The size of the group in ``w`` that - shares a scale and bias. (default: ``64``) + shares a scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element in - ``w``. (default: ``4``) + ``w``. Default: ``4``. Returns: array: The result of the multiplication of ``x`` with ``w`` @@ -3827,7 +3835,7 @@ void init_ops(nb::module_& m) { sum over. If an integer is provided, then sum over the last ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of ``b``. If a list of lists is provided, then sum over the - corresponding dimensions of ``a`` and ``b``. (default: 2) + corresponding dimensions of ``a`` and ``b``. Default: 2. Returns: array: The tensor dot product. @@ -3958,11 +3966,13 @@ void init_ops(nb::module_& m) { Args: a (array): Input array or scalar. b (array): Input array or scalar. - block_size (int): Size of blocks to be masked. Must be ``32`` or ``64`` (default: ``64``) - mask_out (array, optional): Mask for output (default: ``None``) - mask_lhs (array, optional): Mask for a (default: ``None``) - mask_rhs (array, optional): Mask for b (default: ``None``) + block_size (int): Size of blocks to be masked. Must be ``32`` or ``64``. Default: ``64``. + mask_out (array, optional): Mask for output. Default: ``None``. + mask_lhs (array, optional): Mask for ``a``. Default: ``None``. + mask_rhs (array, optional): Mask for ``b``. Default: ``None``. + Returns: + array: The output array. )pbdoc"); m.def( "gather_mm", @@ -3996,9 +4006,11 @@ void init_ops(nb::module_& m) { Args: a (array): Input array. b (array): Input array. - lhs_indices (array, optional): Integer indices for ``a`` (default: ``None``) - rhs_indices (array, optional): Integer indices for ``b`` (default: ``None``) + lhs_indices (array, optional): Integer indices for ``a``. Default: ``None`` + rhs_indices (array, optional): Integer indices for ``b``. Default: ``None`` + Returns: + array: The output array. )pbdoc"); m.def( "diagonal", @@ -4406,4 +4418,57 @@ void init_ops(nb::module_& m) { Returns: array: The transformed array. )pbdoc"); + m.def( + "einsum_path", + [](const std::string& equation, const nb::args& operands) { + auto arrays_list = nb::cast>(operands); + auto [path, str] = einsum_path(equation, arrays_list); + // Convert to list of tuples + std::vector tuple_path; + for (auto& p : path) { + tuple_path.push_back(nb::tuple(nb::cast(p))); + } + return std::make_pair(tuple_path, str); + }, + "subscripts"_a, + "operands"_a, + nb::sig("def einsum_path(subscripts: str, *operands)"), + R"pbdoc( + + Compute the contraction order for the given Einstein summation. + + Args: + subscripts (str): The Einstein summation convention equation. + *operands (array): The input arrays. + + Returns: + tuple(list(tuple(int, int)), str): + The einsum path and a string containing information about the + chosen path. 
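+
+        Example:
+          With two operands the whole contraction is done in a single step, so
+          the returned path has a single entry covering both operands:
+
+          .. code-block:: python
+
+            a = mx.ones((5, 8))
+            b = mx.ones((8, 3))
+            path, info = mx.einsum_path("ij,jk->ik", a, b)
+            # path == [(0, 1)]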
+ )pbdoc"); + m.def( + "einsum", + [](const std::string& subscripts, + const nb::args& operands, + StreamOrDevice s) { + auto arrays_list = nb::cast>(operands); + return einsum(subscripts, arrays_list, s); + }, + "subscripts"_a, + "operands"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def einsum(subscripts: str, *operands, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + + Perform the Einstein summation convention on the operands. + + Args: + subscripts (str): The Einstein summation convention equation. + *operands (array): The input arrays. + + Returns: + array: The output array. + )pbdoc"); } diff --git a/python/src/random.cpp b/python/src/random.cpp index 1dd4ef56b..21e242524 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -99,7 +99,7 @@ void init_random(nb::module_& parent_module) { Args: key (array): Input key to split. - num (int, optional): Number of sub keys. Default is 2. + num (int, optional): Number of sub keys. Default: ``2``. Returns: array: The array of sub keys with ``num`` as its first dimension. @@ -137,11 +137,13 @@ void init_random(nb::module_& parent_module) { broadcastable to ``shape``. Args: - low (scalar or array, optional): Lower bound of the distribution. Default is ``0``. - high (scalar or array, optional): Upper bound of the distribution. Default is ``1``. - shape (list(int), optional): Shape of the output. Default is ``()``. + low (scalar or array, optional): Lower bound of the distribution. + Default: ``0``. + high (scalar or array, optional): Upper bound of the distribution. + Default: ``1``. + shape (list(int), optional): Shape of the output. Default:``()``. + dtype (Dtype, optional): Type of the output. Default: ``float32``. key (array, optional): A PRNG key. Default: ``None``. - dtype (Dtype, optional): Type of the output. Default is ``float32``. Returns: array: The output array random values. @@ -250,9 +252,9 @@ void init_random(nb::module_& parent_module) { Args: low (scalar or array): Lower bound of the interval. high (scalar or array): Upper bound of the interval. - shape (list(int), optional): Shape of the output. Defaults to ``()``. - dtype (Dtype, optional): Type of the output. Defaults to ``int32``. - key (array, optional): A PRNG key. Default: None. + shape (list(int), optional): Shape of the output. Default: ``()``. + dtype (Dtype, optional): Type of the output. Default: ``int32``. + key (array, optional): A PRNG key. Default: ``None``. Returns: array: The array of random integers. @@ -286,10 +288,10 @@ void init_random(nb::module_& parent_module) { Args: p (float or array, optional): Parameter of the Bernoulli - distribution. Default is 0.5. - shape (list(int), optional): Shape of the output. The default - shape is ``p.shape``. - key (array, optional): A PRNG key. Default: None. + distribution. Default: ``0.5``. + shape (list(int), optional): Shape of the output. + Default: ``p.shape``. + key (array, optional): A PRNG key. Default: ``None``. Returns: array: The array of random integers. @@ -331,10 +333,10 @@ void init_random(nb::module_& parent_module) { lower (scalar or array): Lower bound of the domain. upper (scalar or array): Upper bound of the domain. shape (list(int), optional): The shape of the output. - Default is ``()``. + Default:``()``. dtype (Dtype, optional): The data type of the output. - Default is ``float32``. - key (array, optional): A PRNG key. Default: None. + Default: ``float32``. + key (array, optional): A PRNG key. Default: ``None``. 
Returns: array: The output array of random values. @@ -362,7 +364,7 @@ void init_random(nb::module_& parent_module) { Args: shape (list(int)): The shape of the output. - key (array, optional): A PRNG key. Default: None. + key (array, optional): A PRNG key. Default: ``None``. Returns: array: The :class:`array` with shape ``shape`` and @@ -407,14 +409,14 @@ void init_random(nb::module_& parent_module) { Args: logits (array): The *unnormalized* categorical distribution(s). axis (int, optional): The axis which specifies the distribution. - Default is ``-1``. + Default: ``-1``. shape (list(int), optional): The shape of the output. This must be broadcast compatable with ``logits.shape`` with the ``axis`` dimension removed. Default: ``None`` num_samples (int, optional): The number of samples to draw from each of the categorical distributions in ``logits``. The output will have ``num_samples`` in the last dimension. Default: ``None``. - key (array, optional): A PRNG key. Default: None. + key (array, optional): A PRNG key. Default: ``None``. Returns: array: The ``shape``-sized output array with type ``uint32``. @@ -442,11 +444,12 @@ void init_random(nb::module_& parent_module) { Sample numbers from a Laplace distribution. Args: - shape (list(int), optional): Shape of the output. Default is ``()``. - dtype (Dtype, optional): Type of the output. Default is ``float32``. - loc (float, optional): Mean of the distribution. Default is ``0.0``. - scale (float, optional): The scale "b" of the Laplace distribution. Default is ``1.0``. - key (array, optional): A PRNG key. Default: None. + shape (list(int), optional): Shape of the output. Default: ``()``. + dtype (Dtype, optional): Type of the output. Default: ``float32``. + loc (float, optional): Mean of the distribution. Default: ``0.0``. + scale (float, optional): The scale "b" of the Laplace distribution. + Default:``1.0``. + key (array, optional): A PRNG key. Default: ``None``. Returns: array: The output array of random values. diff --git a/python/tests/test_einsum.py b/python/tests/test_einsum.py new file mode 100644 index 000000000..919720c50 --- /dev/null +++ b/python/tests/test_einsum.py @@ -0,0 +1,318 @@ +# Copyright © 2024 Apple Inc. 
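+# Correctness tests for mx.einsum and mx.einsum_path, checked against NumPy.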
+ +import unittest + +import mlx.core as mx +import mlx_tests +import numpy as np + + +class TestEinsum(mlx_tests.MLXTestCase): + + def test_simple_path(self): + a = mx.zeros((5, 5)) + path = mx.einsum_path("ii", a) + self.assertEqual(path[0], [(0,)]) + + path = mx.einsum_path("ij->i", a) + self.assertEqual(path[0], [(0,)]) + + path = mx.einsum_path("ii->i", a) + self.assertEqual(path[0], [(0,)]) + + a = mx.zeros((5, 8)) + b = mx.zeros((8, 3)) + path = mx.einsum_path("ij,jk", a, b) + self.assertEqual(path[0], [(0, 1)]) + path = mx.einsum_path("ij,jk -> ijk", a, b) + self.assertEqual(path[0], [(0, 1)]) + + a = mx.zeros((5, 8)) + b = mx.zeros((8, 3)) + c = mx.zeros((3, 7)) + path = mx.einsum_path("ij,jk,kl", a, b, c) + + self.assertEqual(path[0], [(0, 1), (0, 1)]) + + a = mx.zeros((5, 8)) + b = mx.zeros((8, 10)) + c = mx.zeros((10, 7)) + path = mx.einsum_path("ij,jk,kl", a, b, c) + self.assertEqual(path[0], [(1, 2), (0, 1)]) + + def test_longer_paths(self): + chars = "abcdefghijklmopqABC" + sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4] + dim_dict = {c: s for c, s in zip(chars, sizes)} + cases = [ + "eb,cb,fb->cef", + "dd,fb,be,cdb->cef", + "dd,fb,be,cdb->cef", + "bca,cdb,dbf,afc->", + "dcc,fce,ea,dbf->ab", + "dcc,fce,ea,dbf->ab", + ] + + for case in cases: + subscripts = case[: case.find("->")].split(",") + inputs = [] + for s in subscripts: + shape = [dim_dict[c] for c in s] + inputs.append(np.ones(shape)) + np_path = np.einsum_path(case, *inputs) + + inputs = [mx.array(i) for i in inputs] + mx_path = mx.einsum_path(case, *inputs) + self.assertEqual(np_path[0][1:], mx_path[0]) + + def test_simple_einsum(self): + a = mx.arange(4 * 4).reshape(4, 4) + a_mx = mx.einsum("ii->i", a) + a_np = np.einsum("ii->i", a) + self.assertTrue(np.array_equal(a_mx, a_np)) + + a = mx.arange(2 * 2 * 2).reshape(2, 2, 2) + a_mx = mx.einsum("iii->i", a) + a_np = np.einsum("iii->i", a) + self.assertTrue(np.array_equal(a_mx, a_np)) + + a = mx.arange(2 * 2 * 3 * 3).reshape(2, 2, 3, 3) + a_mx = mx.einsum("iijj->ij", a) + a_np = np.einsum("iijj->ij", a) + self.assertTrue(np.array_equal(a_mx, a_np)) + + a = mx.arange(2 * 2 * 3 * 3).reshape(2, 3, 2, 3) + a_mx = mx.einsum("ijij->ij", a) + a_np = np.einsum("ijij->ij", a) + self.assertTrue(np.array_equal(a_mx, a_np)) + + # Test some simple reductions + a = mx.arange(2 * 2).reshape(2, 2) + a_mx = mx.einsum("ii", a) + a_np = np.einsum("ii", a) + self.assertTrue(np.array_equal(a_mx, a_np)) + + a = mx.arange(2 * 4).reshape(2, 4) + a_mx = mx.einsum("ij->", a) + a_np = np.einsum("ij->", a) + self.assertTrue(np.array_equal(a_mx, a_np)) + + a = mx.arange(2 * 4).reshape(2, 4) + a_mx = mx.einsum("ij->i", a) + a_np = np.einsum("ij->i", a) + self.assertTrue(np.array_equal(a_mx, a_np)) + + a = mx.arange(2 * 4).reshape(2, 4) + a_mx = mx.einsum("ij->j", a) + a_np = np.einsum("ij->j", a) + self.assertTrue(np.array_equal(a_mx, a_np)) + + a = mx.arange(2 * 2 * 2).reshape(2, 2, 2) + a_mx = mx.einsum("iii->", a) + a_np = np.einsum("iii->", a) + self.assertTrue(np.array_equal(a_mx, a_np)) + + a = mx.arange(2 * 2 * 3 * 3).reshape(2, 3, 2, 3) + a_mx = mx.einsum("ijij->j", a) + a_np = np.einsum("ijij->j", a) + self.assertTrue(np.array_equal(a_mx, a_np)) + + # Test some simple transposes + a = mx.arange(2 * 4).reshape(2, 4) + a_mx = mx.einsum("ij", a) + a_np = np.einsum("ij", a) + self.assertTrue(np.array_equal(a_mx, a_np)) + + a = mx.arange(2 * 4).reshape(2, 4) + a_mx = mx.einsum("ij->ji", a) + a_np = np.einsum("ij->ji", a) + self.assertTrue(np.array_equal(a_mx, a_np)) + + a = 
mx.arange(2 * 3 * 4).reshape(2, 3, 4) + a_mx = mx.einsum("ijk->jki", a) + a_np = np.einsum("ijk->jki", a) + self.assertTrue(np.array_equal(a_mx, a_np)) + + def test_two_input_einsum(self): + + # Matmul + a = mx.full((2, 8), 1.0) + b = mx.full((8, 2), 1.0) + a_mx = mx.einsum("ik,kj", a, b) + a_np = np.einsum("ik,kj", a, b) + self.assertTrue(np.array_equal(a_mx, a_np)) + + # Matmul + transpose + a = mx.full((2, 8), 1.0) + b = mx.full((8, 3), 1.0) + a_mx = mx.einsum("ik,kj->ji", a, b) + a_np = np.einsum("ik,kj->ji", a, b) + self.assertTrue(np.array_equal(a_mx, a_np)) + + # Inner product + a = mx.full((4,), 1.0) + b = mx.full((4,), 1.0) + a_mx = mx.einsum("i,i", a, b) + a_np = np.einsum("i,i", a, b) + self.assertTrue(np.array_equal(a_mx, a_np)) + + # Outer product + a = mx.full((4,), 0.5) + b = mx.full((6,), 2.0) + a_mx = mx.einsum("i,j->ij", a, b) + a_np = np.einsum("i,j->ij", a, b) + self.assertTrue(np.array_equal(a_mx, a_np)) + + # Elementwise multiply + a = mx.full((2, 8), 1.0) + b = mx.full((2, 8), 1.0) + a_mx = mx.einsum("ij,ij->ij", a, b) + a_np = np.einsum("ij,ij->ij", a, b) + self.assertTrue(np.array_equal(a_mx, a_np)) + + # Medley + a = mx.full((2, 8, 3, 5), 1.0) + b = mx.full((3, 7, 5, 2), 1.0) + a_mx = mx.einsum("abcd,fgda->bfca", a, b) + a_np = np.einsum("abcd,fgda->bfca", a, b) + self.assertTrue(np.array_equal(a_mx, a_np)) + + def test_sum_first(self): + a = mx.full((5, 8), 1.0) + b = mx.full((8, 2), 1.0) + a_mx = mx.einsum("ab,bc->c", a, b) + a_np = np.einsum("ab,bc->c", a, b) + self.assertTrue(np.array_equal(a_mx, a_np)) + + def test_broadcasting(self): + a = mx.full((5, 1), 1.0) + b = mx.full((8, 2), 1.0) + a_mx = mx.einsum("ab,bc->c", a, b) + return + a_np = np.einsum("ab,bc->c", a, b) + self.assertTrue(np.array_equal(a_mx, a_np)) + + a = mx.random.uniform(shape=(5, 1, 3, 1)) + b = mx.random.uniform(shape=(1, 7, 1, 2)) + a_mx = mx.einsum("abcd,cdab->abcd", a, b) + a_np = np.einsum("abcd,cdab->abcd", a, b) + self.assertTrue(np.allclose(a_mx, a_np)) + + def test_attention(self): + q = mx.random.uniform(shape=(2, 3, 4, 5)) + k = mx.random.uniform(shape=(2, 3, 4, 5)) + v = mx.random.uniform(shape=(2, 3, 4, 5)) + + s = mx.einsum("itjk,iujk->ijtu", q, k) + out_mx = mx.einsum("ijtu,iujk->itjk", s, v) + + s = np.einsum("itjk,iujk->ijtu", q, k) + out_np = np.einsum("ijtu,iujk->itjk", s, v) + + self.assertTrue(np.allclose(out_mx, out_np)) + + def test_multi_input_einsum(self): + a = mx.ones((3, 4, 5)) + out_mx = mx.einsum("ijk,lmk,ijf->lf", a, a, a) + out_np = np.einsum("ijk,lmk,ijf->lf", a, a, a) + self.assertTrue(np.allclose(out_mx, out_np)) + + def test_opt_einsum_test_cases(self): + # Test cases from + # https://github.com/dgasmith/opt_einsum/blob/c826bb7df16f470a69f7bf90598fc27586209d11/opt_einsum/tests/test_contract.py#L11 + tests = [ + # Test hadamard-like products + "a,ab,abc->abc", + "a,b,ab->ab", + # Test index-transformations + "ea,fb,gc,hd,abcd->efgh", + "ea,fb,abcd,gc,hd->efgh", + "abcd,ea,fb,gc,hd->efgh", + # Test complex contractions + "acdf,jbje,gihb,hfac,gfac,gifabc,hfac", + "cd,bdhe,aidb,hgca,gc,hgibcd,hgac", + "abhe,hidj,jgba,hiab,gab", + "bde,cdh,agdb,hica,ibd,hgicd,hiac", + "chd,bde,agbc,hiad,hgc,hgi,hiad", + "chd,bde,agbc,hiad,bdi,cgh,agdb", + "bdhe,acad,hiab,agac,hibd", + # Test collapse + "ab,ab,c->", + "ab,ab,c->c", + "ab,ab,cd,cd->", + "ab,ab,cd,cd->ac", + "ab,ab,cd,cd->cd", + "ab,ab,cd,cd,ef,ef->", + # Test outer prodcuts + "ab,cd,ef->abcdef", + "ab,cd,ef->acdf", + "ab,cd,de->abcde", + "ab,cd,de->be", + "ab,bcd,cd->abcd", + "ab,bcd,cd->abd", + # Random test 
cases that have previously failed + "eb,cb,fb->cef", + "dd,fb,be,cdb->cef", + "bca,cdb,dbf,afc->", + "dcc,fce,ea,dbf->ab", + "fdf,cdd,ccd,afe->ae", + "abcd,ad", + "ed,fcd,ff,bcf->be", + "baa,dcf,af,cde->be", + "bd,db,eac->ace", + "fff,fae,bef,def->abd", + "efc,dbc,acf,fd->abe", + # Inner products + "ab,ab", + "ab,ba", + "abc,abc", + "abc,bac", + "abc,cba", + # GEMM test cases + "ab,bc", + "ab,cb", + "ba,bc", + "ba,cb", + "abcd,cd", + "abcd,ab", + "abcd,cdef", + "abcd,cdef->feba", + "abcd,efdc", + # Inner then dot + "aab,bc->ac", + "ab,bcc->ac", + "aab,bcc->ac", + "baa,bcc->ac", + "aab,ccb->ac", + # Randomly build test caes + "aab,fa,df,ecc->bde", + "ecb,fef,bad,ed->ac", + "bcf,bbb,fbf,fc->", + "bb,ff,be->e", + "bcb,bb,fc,fff->", + "fbb,dfd,fc,fc->", + "afd,ba,cc,dc->bf", + "adb,bc,fa,cfc->d", + "bbd,bda,fc,db->acf", + "dba,ead,cad->bce", + "aef,fbc,dca->bde", + ] + + size_dict = dict(zip("abcdefghij", [2, 3, 4, 5, 2, 3, 4, 5, 2, 3])) + + def inputs_for_case(test_case): + inputs = test_case.split("->")[0].split(",") + return [ + mx.random.uniform(shape=tuple(size_dict[c] for c in inp)) + for inp in inputs + ] + + for test_case in tests: + inputs = inputs_for_case(test_case) + np_out = np.einsum(test_case, *inputs) + mx_out = mx.einsum(test_case, *inputs) + self.assertTrue(np.allclose(mx_out, np_out, rtol=1e-4, atol=1e-4)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9ba77474b..42bf66580 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -26,6 +26,7 @@ target_sources(tests PRIVATE custom_vjp_tests.cpp creations_tests.cpp device_tests.cpp + einsum_tests.cpp eval_tests.cpp fft_tests.cpp load_tests.cpp diff --git a/tests/einsum_tests.cpp b/tests/einsum_tests.cpp new file mode 100644 index 000000000..834d533cf --- /dev/null +++ b/tests/einsum_tests.cpp @@ -0,0 +1,76 @@ +// Copyright © 2024 Apple Inc. 
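+// Tests for einsum_path contraction ordering and einsum results/error handling.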
+ +#include "doctest/doctest.h" +#include "mlx/mlx.h" + +using namespace mlx::core; + +TEST_CASE("test einsum path") { + std::vector> expected = {{1, 2}, {0, 1}}; + auto path = + einsum_path("ij,jk,kl", {ones({2, 2}), ones({2, 4}), ones({4, 2})}).first; + CHECK_EQ(path, expected); + + expected = {{0}}; + path = einsum_path("jki", {ones({2, 3, 4})}).first; + CHECK_EQ(path, expected); + + expected = {{0, 1}}; + path = einsum_path("i,i", {ones({2}), ones({1})}).first; + CHECK_EQ(path, expected); + + expected = {{0, 1}}; + path = einsum_path("ij,jk", {ones({2, 2}), ones({2, 2})}).first; + CHECK_EQ(path, expected); + + expected = {{0, 1}}; + path = einsum_path("ijk,jil->kl", {ones({3, 4, 5}), ones({4, 3, 2})}).first; + CHECK_EQ(path, expected); + + expected = {{0, 3}, {1, 3}, {0, 2}, {0, 1}}; + path = einsum_path( + "ijk,ilm,njm,nlk,abc->", + {ones({2, 6, 8}), + ones({2, 4, 5}), + ones({3, 6, 5}), + ones({3, 4, 8}), + ones({9, 4, 7})}) + .first; + CHECK_EQ(path, expected); + + expected = {{0, 2}, {0, 3}, {0, 2}, {0, 1}}; + path = einsum_path( + "ea,fb,abcd,gc,hd->efgh", + {ones({10, 10}), + ones({10, 10}), + ones({10, 10, 10, 10}), + ones({10, 10}), + ones({10, 10})}) + .first; + CHECK_EQ(path, expected); +} + +TEST_CASE("test einsum") { + CHECK_THROWS(einsum("i,j", {array({1.0})})); + CHECK_THROWS(einsum("ijk", {full({2, 2}, 2.0f)})); + CHECK_THROWS(einsum("", {})); + CHECK_THROWS(einsum("ij", {array({1, 2})})); + CHECK_THROWS(einsum("", {array({1, 2})})); + CHECK_THROWS(einsum("i,ij", {array({1, 2}), array({2, 3})})); + CHECK_THROWS(einsum("i,i", {array({1, 2}), array({2, 3, 4})})); + CHECK_THROWS(einsum("i->ii", {array({1, 2})})); + CHECK_THROWS(einsum("12", {zeros({4, 4})})); + CHECK_THROWS(einsum("ii->i", {zeros({3, 2})})); + + auto x = einsum("jki", {full({2, 3, 4}, 3.0f)}); + auto expected = full({4, 2, 3}, 3.0f); + CHECK_EQ(allclose(x, expected).item(), true); + + x = einsum("ij,jk->ik", {full({2, 2}, 2.0f), full({2, 2}, 3.0f)}); + expected = array({12.0f, 12.0f, 12.0f, 12.0f}, {2, 2}); + CHECK_EQ(allclose(x, expected).item(), true); + + x = einsum("i,j->ij", {full({2}, 15.0f), full({4}, 20.0f)}); + expected = full({2, 4}, 300.0f); + CHECK_EQ(allclose(x, expected).item(), true); +}