mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00

Einsum (#1269)

* einsum initial
* fix comma break
* sum axis was wrong
* small cleanups
* python binding
* changed bindings to resemble numpy
* remove todo comment
* comment changes
* add count of operands/inputs
* fail fast if operands list is empty
* ignore comma if no output
* einsum path matching numpy
* getting somewhere with path
* remove print
* it passes the first test
* moved einsum tests to separate file
* separated einsum path
* moved einsum naive
* remove space from equation
* fast fail if no operands passed
* update tests and remove printf
* small cleanup
* some more cleanups
* removed python helper file
* ack
* utilize std for finding min in vector
* duplicate def
* remove the tuple as it was unreadable
* moved einsum_naive back to ops
* remaining isn't needed
* avoid creating another set
* cleanup
* greedy path, start of naive einsum
* more einsum
* fix some bugs
* some more fixes, tests pass
* benchmark
* some simplify
* fix einsum and test

  Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

* add a bunch more tests and fix a bunch more bugs
* some docs nits

---------

Co-authored-by: dc-dc-dc <dgcruz983@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

This commit is contained in:
parent 7f914365fd
commit baf9fa5f42
@@ -10,7 +10,7 @@ MLX was developed with contributions from the following individuals:
 - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
-- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
+- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
 - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
 - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
 - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
benchmarks/python/einsum_bench.py (new file, 84 lines)
@@ -0,0 +1,84 @@
# Copyright © 2024 Apple Inc.

import time

import mlx.core as mx
import numpy as np


def timeit(fn, its=100, args=[]):
    for _ in range(5):
        fn(*args)
    tic = time.perf_counter()
    for _ in range(its):
        fn(*args)
    toc = time.perf_counter()
    return 1e3 * (toc - tic) / its


def time_little_einsum_path():
    subscripts = "ik,kj->ij"
    x = mx.ones((32, 32))
    y = mx.ones((32, 32))
    mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))

    x = np.array(x)
    y = np.array(y)
    np_time = timeit(np.einsum_path, args=(subscripts, x, y))
    print("Timing little einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")


def time_big_einsum_path():
    chars = list("abcdefgh")
    char_to_dim = {c: v for v, c in enumerate(chars)}

    num_inputs = 10
    inputs = []
    subscripts = []
    for _ in range(num_inputs):
        subscript = np.random.choice(chars, size=5, replace=False).tolist()
        subscripts.append("".join(subscript))
        inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
    subscripts = ",".join(subscripts)

    np_time = timeit(np.einsum_path, args=(subscripts, *inputs))

    inputs = [mx.array(x) for x in inputs]
    mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
    print("Timing big einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")


def time_attention():
    def regular_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        queries, keys, values = x, x, x
        scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
        scores = mx.softmax(scores, axis=-1)
        output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
        mx.eval(output)

    def einsum_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        queries, keys, values = x, x, x
        scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
        scores = mx.softmax(scores, axis=-1)
        output = mx.einsum("ijtu,iujk->itjk", scores, values)
        mx.eval(output)

    x = mx.random.uniform(shape=(8, 512, 32, 128))

    regular_time = timeit(regular_attention, args=(x,))
    ein_time = timeit(einsum_attention, args=(x,))
    print("Timing einsum attention...")
    print(f"Regular ... {regular_time:.3f} ms")
    print(f"Einsum ... {ein_time:.3f} ms")


if __name__ == "__main__":
    time_little_einsum_path()
    time_big_einsum_path()
    time_attention()
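As a quick correctness check to go with the timings, the equation the first benchmark uses is a plain matrix multiply and should agree with @ exactly (a minimal sketch, shapes arbitrary):

import mlx.core as mx

a = mx.random.uniform(shape=(4, 8))
b = mx.random.uniform(shape=(8, 3))
# "ik,kj->ij" contracts k, i.e. an ordinary matmul.
assert mx.allclose(mx.einsum("ik,kj->ij", a, b), a @ b)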
@@ -57,6 +57,8 @@ Operations
    diagonal
    divide
    divmod
+   einsum
+   einsum_path
    equal
    erf
    erfinv
@@ -6,6 +6,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
mlx/einsum.cpp (new file, 859 lines)
@@ -0,0 +1,859 @@
// Copyright © 2024 Apple Inc.
#include <numeric>
#include <sstream>
#include <unordered_map>
#include <unordered_set>

#include "mlx/einsum.h"
#include "mlx/ops.h"

namespace mlx::core {

namespace {

// The MLX einsum implementation is based on NumPy (which is based on
// opt_einsum):
// https://github.com/numpy/numpy/blob/1d49c7f7ff527c696fc26ab2278ad51632a66660/numpy/_core/einsumfunc.py#L743
// https://github.com/dgasmith/opt_einsum

using CharSet = std::unordered_set<char>;

// A helper struct to hold the string and set
// representation of a subscript to avoid needing
// to recompute the set
struct Subscript {
  Subscript(std::string str, CharSet set)
      : str(std::move(str)), set(std::move(set)) {};
  std::string str;
  CharSet set;
};

struct PathInfo {
  size_t naive_cost;
  size_t naive_scaling;
  size_t optimized_cost;
  size_t optimized_scaling;
  size_t largest_term;
};

struct PathNode {
  PathNode(
      std::vector<Subscript> inputs,
      Subscript output,
      std::vector<int> positions)
      : inputs(std::move(inputs)),
        output(std::move(output)),
        positions(std::move(positions)) {};

  std::vector<Subscript> inputs;
  Subscript output;

  std::vector<int> positions;
};

// Parse the comma separated subscripts into a vector of strings. If the
// output subscripts are missing they are inferred.
//
// For example:
//   "ij,jk -> ik" becomes {{"ij", "jk"}, "ik"}
//   "ij,jk" becomes {{"ij", "jk"}, "ik"}
std::pair<std::vector<std::string>, std::string> parse(std::string subscripts) {
  std::string lhs, rhs;

  // Start by removing all white space
  subscripts.erase(
      std::remove(subscripts.begin(), subscripts.end(), ' '), subscripts.end());

  if (auto pos = subscripts.find("->"); pos != std::string::npos) {
    // Explicit mode
    lhs = subscripts.substr(0, pos);
    rhs = subscripts.substr(pos + 2);
  } else {
    // Implicit mode:
    // - repeats are summed
    // - remaining output axes are ordered alphabetically
    lhs = subscripts;
    std::unordered_map<char, int> temp;
    for (auto& c : subscripts) {
      if (c == ',') {
        continue;
      }
      auto inserted = temp.insert({c, 0});
      inserted.first->second++;
    }
    for (auto& k : temp) {
      if (k.second == 1) {
        rhs += k.first;
      }
    }
    std::sort(rhs.begin(), rhs.end());
  }
  std::vector<std::string> input_list;
  std::stringstream ss(lhs);
  std::string token;
  while (getline(ss, token, ',')) {
    input_list.push_back(token);
  }
  return {input_list, rhs};
}
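The implicit-mode rules above can be checked from Python once the bindings added later in this commit land; a small illustration (shapes are arbitrary):

import mlx.core as mx

a = mx.ones((2, 3))
b = mx.ones((3, 4))
# 'j' appears twice so it is summed; 'i' and 'k' survive and are
# ordered alphabetically, so "ij,jk" is parsed as "ij,jk->ik".
assert mx.allclose(mx.einsum("ij,jk", a, b), mx.einsum("ij,jk->ik", a, b))
# With a single input, out-of-order subscripts imply a transpose:
# "ba" infers the output "ab".
assert mx.einsum("ba", a).shape == (3, 2)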
// Check if two sets are disjoint
bool disjoint(const CharSet& x, const CharSet& y) {
  for (auto& c : x) {
    if (y.find(c) != y.end()) {
      return false;
    }
  }
  return true;
}

template <typename T>
size_t term_size(const T& term, std::unordered_map<char, int> dict) {
  size_t size = 1;
  for (auto c : term) {
    size *= dict[c];
  }
  return size;
}

size_t flop_count(
    const CharSet& term,
    bool inner,
    int num_terms,
    std::unordered_map<char, int> dict) {
  size_t size = term_size(term, dict);
  auto op_factor = 1;
  if ((num_terms - 1) > op_factor) {
    op_factor = num_terms - 1;
  }
  if (inner) {
    op_factor += 1;
  }
  return size * op_factor;
}

std::pair<size_t, int> compute_cost_and_scaling(
    const std::vector<Subscript>& inputs,
    const Subscript& output,
    std::unordered_map<char, int> dim_map) {
  CharSet contractions;
  for (auto& in : inputs) {
    contractions.insert(in.set.begin(), in.set.end());
  }

  bool inner = false;
  for (auto c : contractions) {
    if (output.set.find(c) == output.set.end()) {
      inner = true;
      break;
    }
  }
  auto cost = flop_count(contractions, inner, inputs.size(), dim_map);
  return {cost, contractions.size()};
}
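The cost model is simple: a contraction touches term_size(term) elements, scaled by an operand factor that grows with the number of terms and pays one extra pass when an index is summed away. A Python sketch of the same arithmetic, with a worked example for "ij,jk->ik" where every dimension is 32:

def flop_count(term, inner, num_terms, dims):
    size = 1
    for c in term:
        size *= dims[c]
    op_factor = max(num_terms - 1, 1)
    if inner:  # some index is summed out
        op_factor += 1
    return size * op_factor

# Two operands keep op_factor at 1 and the summed index j (inner=True)
# bumps it to 2, so the cost is 32**3 * 2 = 65536.
assert flop_count("ijk", True, 2, {"i": 32, "j": 32, "k": 32}) == 65536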
std::tuple<std::vector<PathNode>, size_t, int> greedy_path(
    std::vector<Subscript> inputs,
    const Subscript& output,
    std::unordered_map<char, int> dim_map,
    size_t cost_limit,
    size_t memory_limit) {
  // Helper struct for building the greedy path
  struct Contraction {
    Contraction(
        size_t size,
        size_t cost,
        CharSet output,
        int dims,
        int x,
        int y)
        : size(size),
          cost(cost),
          output(std::move(output)),
          dims(dims),
          x(x),
          y(y) {};

    int64_t size; // Size difference, can be negative
    size_t cost;
    CharSet output;
    int dims; // Number of dimensions in the contraction
    int x;
    int y;
  };

  // Start by iterating over all possible combinations
  std::vector<std::pair<int, int>> pos_pairs;
  for (int i = 0; i < inputs.size(); ++i) {
    for (int j = i + 1; j < inputs.size(); ++j) {
      pos_pairs.emplace_back(i, j);
    }
  }

  std::vector<PathNode> path;
  std::vector<Contraction> possible_contractions;
  size_t path_cost = 0;
  int path_scaling = 0;
  auto num_in = inputs.size();
  for (int i = 0; i < num_in - 1; ++i) {
    auto add_contraction = [&](int p1, int p2) {
      CharSet new_term;
      CharSet contractions(inputs[p1].set.begin(), inputs[p1].set.end());
      contractions.insert(inputs[p2].set.begin(), inputs[p2].set.end());
      for (int i = 0; i < inputs.size(); i++) {
        if (i == p1 || i == p2) {
          continue;
        }
        auto& in = inputs[i].set;
        for (auto c : in) {
          if (contractions.find(c) != contractions.end()) {
            new_term.insert(c);
          }
        }
      }
      for (auto c : output.set) {
        if (contractions.find(c) != contractions.end()) {
          new_term.insert(c);
        }
      }

      // Ignore if:
      // - The size of the new result is greater than the memory limit
      // - The cost is larger than the naive cost
      auto new_size = term_size(new_term, dim_map);
      if (new_size > memory_limit) {
        return;
      }
      int64_t removed_size = term_size(inputs[p1].set, dim_map) +
          term_size(inputs[p2].set, dim_map) - new_size;

      bool inner = contractions.size() > new_term.size();
      auto cost = flop_count(contractions, inner, 2, dim_map);
      if (path_cost + cost > cost_limit) {
        return;
      }
      possible_contractions.emplace_back(
          removed_size, cost, std::move(new_term), contractions.size(), p1, p2);
    };

    for (auto& [p1, p2] : pos_pairs) {
      // Ignore outer products
      if (!disjoint(inputs[p1].set, inputs[p2].set)) {
        add_contraction(p1, p2);
      }
    }

    // If there's nothing in the contraction list,
    // go over the pairs again without ignoring outer products
    if (possible_contractions.empty()) {
      for (auto& [p1, p2] : pos_pairs) {
        add_contraction(p1, p2);
      }
    }

    if (possible_contractions.empty()) {
      // Default to naive einsum for the remaining inputs
      std::vector<int> positions(inputs.size());
      std::iota(positions.begin(), positions.end(), 0);
      auto [cost, scale] = compute_cost_and_scaling(inputs, output, dim_map);
      path.emplace_back(std::move(inputs), output, std::move(positions));

      path_cost += cost;
      path_scaling = std::max(scale, path_scaling);
      break;
    }

    // Find the best contraction
    auto& best = *std::min_element(
        possible_contractions.begin(),
        possible_contractions.end(),
        [](const auto& x, const auto& y) {
          return x.size > y.size || (x.size == y.size && x.cost < y.cost);
        });
    path_scaling = std::max(best.dims, path_scaling);

    // Construct the output subscripts
    std::string out_str(best.output.begin(), best.output.end());
    // TODO, sorting by dimension size seems suboptimal?
    std::sort(out_str.begin(), out_str.end(), [&dim_map](auto x, auto y) {
      return dim_map[x] < dim_map[y];
    });
    Subscript new_output(std::move(out_str), std::move(best.output));

    // Add the chosen contraction to the path
    {
      std::vector<Subscript> in_terms;
      in_terms.push_back(std::move(inputs[best.x]));
      in_terms.push_back(std::move(inputs[best.y]));
      path.emplace_back(
          std::move(in_terms), new_output, std::vector<int>{best.x, best.y});
    }
    // Remove used terms
    inputs.erase(inputs.begin() + best.y);
    inputs.erase(inputs.begin() + best.x);

    // Add the new result
    inputs.push_back(std::move(new_output));

    // Update the existing contractions based on the selected one
    std::vector<Contraction> updated_contractions;
    for (auto& contraction : possible_contractions) {
      // Drop contractions which contain either selected term
      if (contraction.x == best.x || contraction.x == best.y ||
          contraction.y == best.x || contraction.y == best.y) {
        continue;
      }

      // Update the positions of other contractions
      int x =
          contraction.x - (contraction.x > best.x) - (contraction.x > best.y);
      int y =
          contraction.y - (contraction.y > best.x) - (contraction.y > best.y);
      contraction.x = x;
      contraction.y = y;
      updated_contractions.push_back(std::move(contraction));
    }

    pos_pairs.clear();
    for (int i = 0; i < inputs.size() - 1; ++i) {
      pos_pairs.emplace_back(i, inputs.size() - 1);
    }
    path_cost += best.cost;

    possible_contractions = std::move(updated_contractions);
  }
  return {path, path_cost, path_scaling};
}
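The greedy step prefers the candidate that removes the most intermediate size, breaking ties by flop cost, and the first pass skips disjoint pairs because an outer product only ever grows the intermediate. Illustrative arithmetic, assuming every dimension in "ab,bc,cd->ad" is 32:

# Contracting the pair that shares b touches {a, b, c}: 32**3 * 2 flops,
# and the 32x32 intermediate "ac" is no bigger than what it replaced.
pairwise = 32**3 * 2
# An outer product of "ab" with "cd" touches {a, b, c, d} with nothing
# summed: 32**4 flops, and it creates a 32**4-element intermediate.
outer = 32**4
assert pairwise < outer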
// Assumes inputs have already had repeats and single axis sums collapsed
bool can_dot(const std::vector<Subscript>& inputs, const Subscript& output) {
  if (inputs.size() != 2) {
    return false;
  }

  for (auto c : inputs[0].set) {
    // Use batched tensordot if anything is being contracted
    if (output.set.find(c) == output.set.end()) {
      return true;
    }
  }
  return false;
}

array batch_tensordot(
    array a,
    array b,
    std::vector<int> a_contract,
    std::vector<int> a_batch,
    std::vector<int> a_concat,
    std::vector<int> b_contract,
    std::vector<int> b_batch,
    std::vector<int> b_concat,
    StreamOrDevice s) {
  // Broadcast contracting dimensions
  {
    auto a_shape = a.shape();
    auto b_shape = b.shape();
    for (int i = 0; i < a_contract.size(); ++i) {
      auto d = std::max(a.shape(a_contract[i]), b.shape(b_contract[i]));
      a_shape[a_contract[i]] = d;
      b_shape[b_contract[i]] = d;
    }
    a = broadcast_to(a, a_shape, s);
    b = broadcast_to(b, b_shape, s);
  }
  auto transpose_reshape = [&s](
                               const array& x,
                               const std::vector<int>& i,
                               const std::vector<int>& j,
                               const std::vector<int>& k) {
    std::vector<int> reorder(i.begin(), i.end());
    reorder.insert(reorder.end(), j.begin(), j.end());
    reorder.insert(reorder.end(), k.begin(), k.end());

    int size1 = 1;
    for (auto s : j) {
      size1 *= x.shape(s);
    }

    int size2 = 1;
    for (auto s : k) {
      size2 *= x.shape(s);
    }

    std::vector<int> shape;
    for (auto ax : i) {
      shape.push_back(x.shape(ax));
    }
    shape.push_back(size1);
    shape.push_back(size2);

    return reshape(transpose(x, reorder, s), std::move(shape), s);
  };

  std::vector<int> out_shape;
  for (auto ax : a_batch) {
    out_shape.push_back(a.shape(ax));
  }
  for (auto ax : a_concat) {
    out_shape.push_back(a.shape(ax));
  }
  for (auto ax : b_concat) {
    out_shape.push_back(b.shape(ax));
  }

  a = transpose_reshape(a, a_batch, a_concat, a_contract);
  b = transpose_reshape(b, b_batch, b_contract, b_concat);

  return reshape(matmul(a, b, s), std::move(out_shape), s);
}
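batch_tensordot reduces every pairwise contraction to one matmul: batch axes move to the front, the remaining axes of each operand are flattened into (concat, contract) and (contract, concat) matrices, and the product is reshaped back. A sketch of the same idea for "bij,bjk->bik", where the axes already sit in the right order so no transpose is needed (shapes are illustrative):

import mlx.core as mx

a = mx.random.uniform(shape=(5, 2, 3))  # axes: b (batch), i (concat), j (contract)
b = mx.random.uniform(shape=(5, 3, 4))  # axes: b (batch), j (contract), k (concat)
out = mx.matmul(a, b)                   # axes: b, i, k
assert mx.allclose(out, mx.einsum("bij,bjk->bik", a, b))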
// Collapse repeated subscripts and return the resulting array. The subscript
// is also updated in place. For example:
// - Given an input with shape (4, 4) and subscript "ii", returns
//   the diagonal of shape (4,) and updates the subscript to "i".
// - Given an input with shape (4, 2, 4, 2) and subscript "ijij",
//   returns an output with shape (4, 2) and updates the subscript
//   to "ij".
array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) {
  // Build a list of (repeat chars, num repeats)
  auto& str = subscript.str;
  std::vector<std::pair<char, int>> repeats;
  std::string new_str;
  {
    std::string repeat_str;
    std::string no_repeat_str;
    std::unordered_map<char, int> counts;
    for (int i = 0; i < str.size(); ++i) {
      auto [it, _] = counts.insert({str[i], 0});
      it->second++;
    }

    for (auto& v : counts) {
      if (v.second > 1) {
        repeats.emplace_back(v.first, v.second);
        repeat_str += v.first;
      }
    }
    for (auto& c : str) {
      if (counts[c] == 1) {
        no_repeat_str += c;
      }
    }
    new_str = repeat_str + no_repeat_str;
  }

  // Build the inputs for gather
  auto slice_sizes = in.shape();
  std::vector<int> axes;
  std::vector<array> indices;
  int n_expand = repeats.size();
  for (auto [c, v] : repeats) {
    for (int i = 0; i < str.size(); ++i) {
      if (str[i] == c) {
        slice_sizes[i] = 1;
        axes.push_back(i);
      }
    }
    std::vector<int> idx_shape(n_expand--, 1);
    idx_shape[0] = in.shape(axes.back());
    auto idx = reshape(arange(in.shape(axes.back()), s), idx_shape, s);
    for (int i = 0; i < v; ++i) {
      indices.push_back(idx);
    }
  }

  in = gather(in, indices, axes, slice_sizes, s);

  // Update subscript string with removed dups
  str = new_str;

  // Squeeze singleton dimensions left over from the gather
  for (auto& ax : axes) {
    ax += indices[0].ndim();
  }

  return squeeze(in, axes, s);
}
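In other words, repeated subscripts select the (generalized) diagonal before any contraction happens; checked from Python:

import mlx.core as mx

x = mx.arange(16).reshape(4, 4)
# "ii->i" is the main diagonal ...
assert mx.allclose(mx.einsum("ii->i", x), mx.diagonal(x))
# ... and "ii" in implicit mode sums it, i.e. the trace.
assert mx.einsum("ii", x).item() == mx.diagonal(x).sum().item()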
// Collapse repeat indices and sum single dimensions.
// For example:
// - "aa" becomes "a"
// - "ij,jk->k" becomes "j,jk->k"
void preprocess_einsum_inputs(
    std::vector<Subscript>& inputs,
    const Subscript& output,
    const std::vector<int>& positions,
    std::vector<array>& operands,
    StreamOrDevice s) {
  // Collapse repeat indices
  for (int i = 0; i < inputs.size(); ++i) {
    auto& in = inputs[i];
    if (in.set.size() < in.str.size()) {
      operands[positions[i]] = collapse_repeats(operands[positions[i]], in, s);
    }
  }

  // Sum indices that are only in a single input
  {
    std::unordered_map<char, int> counts;
    for (auto& in : inputs) {
      for (auto c : in.set) {
        auto inserted = counts.insert({c, 0});
        inserted.first->second++;
      }
    }
    for (auto c : output.set) {
      auto inserted = counts.insert({c, 0});
      inserted.first->second++;
    }
    for (int i = 0; i < inputs.size(); ++i) {
      auto& in = inputs[i];
      std::vector<int> sum_axes;
      for (int ax = 0; ax < in.str.size(); ++ax) {
        if (counts[in.str[ax]] == 1) {
          sum_axes.push_back(ax);
        }
      }
      if (!sum_axes.empty()) {
        operands[positions[i]] =
            sum(operands[positions[i]], sum_axes, false, s);
      }
      for (auto it = sum_axes.rbegin(); it != sum_axes.rend(); ++it) {
        in.set.erase(in.str[*it]);
        in.str.erase(in.str.begin() + *it);
      }
    }
  }
}
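Summing lone indices early shrinks operands before the expensive step; the equivalence behind the "ij,jk->k" example is just a reordering of sums, which can be verified numerically:

import mlx.core as mx

a = mx.random.uniform(shape=(3, 4))
b = mx.random.uniform(shape=(4, 5))
# i appears only in the first input and not in the output, so
# sum_ij a[i,j] * b[j,k] == sum_j (sum_i a[i,j]) * b[j,k].
pre_summed = mx.einsum("j,jk->k", a.sum(axis=0), b)
assert mx.allclose(mx.einsum("ij,jk->k", a, b), pre_summed)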
array einsum_naive(
    std::vector<Subscript> inputs,
    const Subscript& output,
    const std::vector<int>& positions,
    std::vector<array> operands,
    StreamOrDevice s) {
  // Map each character to an axis
  std::unordered_map<char, int> char_to_ax;
  for (auto& in : inputs) {
    for (auto c : in.str) {
      char_to_ax.insert({c, char_to_ax.size()});
    }
  }

  // Expand and transpose inputs as needed
  for (int i = 0; i < inputs.size(); ++i) {
    int pos = positions[i];
    auto& op = operands[pos];

    // Add missing dimensions at the end
    if (op.ndim() != char_to_ax.size()) {
      auto shape = op.shape();
      shape.insert(shape.end(), char_to_ax.size() - shape.size(), 1);
      op = reshape(op, std::move(shape), s);
    }

    // Transpose:
    // - Build a vector of (char, ax) pairs for the current input
    // - Sort the vector by the canonical axis in char_to_ax
    // - Extract the sorted axis to get transpose order
    std::vector<std::pair<char, int>> str_ax;
    for (auto c : inputs[i].str) {
      str_ax.emplace_back(c, str_ax.size());
    }
    for (auto [c, ax] : char_to_ax) {
      if (inputs[i].set.find(c) == inputs[i].set.end()) {
        str_ax.emplace_back(c, str_ax.size());
      }
    }
    std::sort(
        str_ax.begin(),
        str_ax.end(),
        [&char_to_ax](const auto& x, const auto& y) {
          return char_to_ax[x.first] < char_to_ax[y.first];
        });

    // Skip the transpose if not needed
    if (std::is_sorted(
            str_ax.begin(), str_ax.end(), [](const auto& x, const auto& y) {
              return x.second < y.second;
            })) {
      continue;
    }

    std::vector<int> reorder;
    for (auto [c, ax] : str_ax) {
      reorder.push_back(ax);
    }
    op = transpose(op, reorder, s);
  }

  // Multiply and sum
  auto out = operands[positions[0]];
  for (int i = 1; i < positions.size(); ++i) {
    out = multiply(out, operands[positions[i]], s);
  }
  std::vector<int> sum_axes;
  for (auto [c, ax] : char_to_ax) {
    if (output.set.find(c) == output.set.end()) {
      sum_axes.push_back(ax);
    }
  }
  if (!sum_axes.empty()) {
    out = sum(out, sum_axes, false, s);
  }

  // Transpose output if needed
  std::vector<int> reorder;
  for (auto c : output.str) {
    reorder.push_back(char_to_ax[c]);
  }
  for (auto& r : reorder) {
    int offset = 0;
    for (auto s : sum_axes) {
      if (r > s) {
        offset++;
      }
    }
    r -= offset;
  }
  return transpose(out, reorder, s);
}
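The naive fallback is the classic broadcast-multiply-sum formulation: give every operand all of the axes, take an elementwise product, sum out the non-output axes, and transpose into the requested order. A Python sketch for "ij,jk->ik":

import mlx.core as mx

a = mx.random.uniform(shape=(2, 3))
b = mx.random.uniform(shape=(3, 4))
prod = mx.expand_dims(a, 2) * mx.expand_dims(b, 0)  # axes (i, j, k)
naive = prod.sum(axis=1)                            # sum out j
assert mx.allclose(naive, mx.einsum("ij,jk->ik", a, b))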
std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
    const std::string& subscripts,
    const std::vector<array>& operands,
    const std::string& fn_name) {
  if (operands.size() == 0) {
    std::ostringstream msg;
    msg << "[" << fn_name << "] At least one operand is required.";
    throw std::invalid_argument(msg.str());
  }

  auto [in_subscripts, out_subscript] = parse(subscripts);

  if (operands.size() != in_subscripts.size()) {
    std::ostringstream msg;
    msg << "[" << fn_name << "] Number of operands, " << operands.size()
        << ", does not match number of input subscripts, "
        << in_subscripts.size();
    throw std::invalid_argument(msg.str());
  }

  auto check_letters = [&](const auto& subscript) {
    for (auto c : subscript) {
      if (!isalpha(c)) {
        std::ostringstream msg;
        msg << "[" << fn_name << "] Subscripts must be letters, but got '" << c
            << "'.";
        throw std::invalid_argument(msg.str());
      }
    }
  };
  for (auto& in : in_subscripts) {
    check_letters(in);
  }
  check_letters(out_subscript);

  CharSet out_set(out_subscript.begin(), out_subscript.end());
  if (out_set.size() != out_subscript.size()) {
    std::ostringstream msg;
    msg << "[" << fn_name << "] Repeat indices not allowed in output.";
    throw std::invalid_argument(msg.str());
  }
  Subscript output(out_subscript, std::move(out_set));

  std::unordered_map<char, int> dim_map;
  std::vector<Subscript> inputs;
  for (int i = 0; i < in_subscripts.size(); ++i) {
    auto& in = in_subscripts[i];
    CharSet in_set(in.begin(), in.end());
    inputs.emplace_back(in, in_set);

    if (in.size() != operands[i].ndim()) {
      std::ostringstream msg;
      msg << "[" << fn_name << "] Invalid number of subscripts " << in.size()
          << " for input " << i << " with " << operands[i].ndim()
          << " dimensions.";
      throw std::invalid_argument(msg.str());
    }

    // Check repeat subscripts are valid
    if (in_set.size() < in.size()) {
      std::unordered_map<char, int> local_dims;
      for (int j = 0; j < in.size(); ++j) {
        auto dim = operands[i].shape(j);
        auto inserted = local_dims.insert({in[j], dim});
        if (!inserted.second) {
          if (inserted.first->second != dim) {
            std::ostringstream msg;
            msg << "[" << fn_name << "] Dimensions of repeated subscripts "
                << "do not have the same size (" << inserted.first->second
                << " != " << dim << ").";
            throw std::invalid_argument(msg.str());
          }
        }
      }
    }

    for (int j = 0; j < in.size(); j++) {
      auto c = in[j];
      auto dim = operands[i].shape(j);
      auto inserted = dim_map.insert({c, dim});
      auto& in_dim = inserted.first->second;
      if (dim != 1 && in_dim != 1 && in_dim != dim) {
        std::ostringstream msg;
        msg << "[" << fn_name << "] Cannot broadcast dimension " << j
            << " of input " << i << " with shape " << operands[i].shape()
            << " to size " << in_dim << ".";
        throw std::invalid_argument(msg.str());
      }
      // Ensure the broadcasted size is used
      in_dim = std::max(in_dim, dim);
    }
  }

  size_t max_size = term_size(out_subscript, dim_map);
  for (auto& in : in_subscripts) {
    max_size = std::max(max_size, term_size(in, dim_map));
  }

  PathInfo path_info;

  // Get the full naive cost
  std::tie(path_info.naive_cost, path_info.naive_scaling) =
      compute_cost_and_scaling(inputs, output, dim_map);

  // Calculate the path
  std::vector<PathNode> path;
  if (inputs.size() <= 2) {
    std::vector<int> positions(in_subscripts.size());
    std::iota(positions.begin(), positions.end(), 0);
    path.emplace_back(
        std::move(inputs), std::move(output), std::move(positions));
  } else {
    std::tie(path, path_info.optimized_cost, path_info.optimized_scaling) =
        greedy_path(inputs, output, dim_map, path_info.naive_cost, max_size);
    // Set the final output subscript to the actual output
    path.back().output = std::move(output);
  }
  return {path, path_info};
}

} // namespace
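All of this validation runs before any graph is built. From Python, the std::invalid_argument should surface as a ValueError (assuming the usual nanobind exception translation); two of the checks, sketched:

import mlx.core as mx

x = mx.ones((2, 3))
try:
    mx.einsum("ijk", x)  # three subscripts for a 2-D input
except ValueError as e:
    print(e)
try:
    mx.einsum("ij->ii", x)  # repeated indices are rejected in the output
except ValueError as e:
    print(e)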
std::pair<std::vector<std::vector<int>>, std::string> einsum_path(
    const std::string& subscripts,
    const std::vector<array>& operands) {
  auto [path, path_info] =
      einsum_path_helper(subscripts, operands, "einsum_path");

  std::vector<std::vector<int>> pos_path;
  for (auto& p : path) {
    pos_path.push_back(p.positions);
  }

  std::ostringstream path_print;
  path_print << " Complete contraction: " << subscripts << "\n"
             << " Naive scaling: " << path_info.naive_scaling << "\n"
             << " Optimized scaling: " << path_info.optimized_scaling << "\n"
             << " Naive FLOP count: " << path_info.naive_cost << "\n"
             << " Optimized FLOP count: " << path_info.optimized_cost << "\n";
  // TODO add more info here
  return {pos_path, path_print.str()};
}
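From Python this is the mx.einsum_path entry point: the returned positions are pairwise contractions over the shrinking operand list, and the string is the report assembled above. The exact path depends on the shapes:

import mlx.core as mx

x = mx.ones((8, 16))
y = mx.ones((16, 32))
z = mx.ones((32, 8))
path, info = mx.einsum_path("ij,jk,kl->il", x, y, z)
print(path)  # e.g. [(0, 1), (0, 1)] for these shapes
print(info)  # naive vs. optimized scaling and FLOP counts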
array einsum(
    const std::string& subscripts,
    const std::vector<array>& operands,
    StreamOrDevice s /* = {} */) {
  auto [path, path_info] = einsum_path_helper(subscripts, operands, "einsum");
  auto inputs = operands;
  for (auto& node : path) {
    preprocess_einsum_inputs(
        node.inputs, node.output, node.positions, inputs, s);

    if (can_dot(node.inputs, node.output)) {
      auto& in_a = node.inputs[0];
      auto& in_b = node.inputs[1];
      auto& out = node.output;

      std::vector<int> a_contract;
      std::vector<int> a_batch;
      std::vector<int> a_concat;
      for (int i = 0; i < in_a.str.size(); ++i) {
        auto c = in_a.str[i];
        if (out.set.find(c) == out.set.end()) {
          // Not in the output, contraction
          a_contract.push_back(i);
        } else if (in_b.set.find(c) != in_b.set.end()) {
          // Not a contraction but in both inputs, batch dim
          a_batch.push_back(i);
        } else {
          // Not a batch dim or contract dim, so concat dim
          a_concat.push_back(i);
        }
      }

      std::vector<int> b_contract;
      std::vector<int> b_batch;
      std::vector<int> b_concat;
      for (auto a_i : a_contract) {
        b_contract.push_back(in_b.str.find(in_a.str[a_i]));
      }
      for (auto a_i : a_batch) {
        b_batch.push_back(in_b.str.find(in_a.str[a_i]));
      }
      for (int i = 0; i < in_b.str.size(); ++i) {
        auto c = in_b.str[i];
        if (out.set.find(c) != out.set.end() &&
            in_a.set.find(c) == in_a.set.end()) {
          b_concat.push_back(i);
        }
      }

      auto& a = inputs[node.positions[0]];
      auto& b = inputs[node.positions[1]];

      std::unordered_map<char, int> char_map;
      for (auto i : a_batch) {
        char_map.insert({in_a.str[i], char_map.size()});
      }
      for (auto i : a_concat) {
        char_map.insert({in_a.str[i], char_map.size()});
      }
      for (auto i : b_concat) {
        char_map.insert({in_b.str[i], char_map.size()});
      }
      inputs.emplace_back(batch_tensordot(
          a,
          b,
          std::move(a_contract),
          std::move(a_batch),
          std::move(a_concat),
          std::move(b_contract),
          std::move(b_batch),
          std::move(b_concat),
          s));

      std::vector<int> reorder;
      for (auto c : node.output.str) {
        reorder.push_back(char_map[c]);
      }
      inputs.back() = transpose(inputs.back(), reorder, s);

    } else {
      inputs.emplace_back(
          einsum_naive(node.inputs, node.output, node.positions, inputs, s));
    }

    // Positions are always sorted increasing, so start from the back
    for (auto it = node.positions.rbegin(); it != node.positions.rend(); ++it) {
      inputs.erase(inputs.begin() + *it);
    }
  }
  return inputs.front();
}

} // namespace mlx::core
mlx/einsum.h (new file, 22 lines)
@@ -0,0 +1,22 @@
// Copyright © 2024 Apple Inc.
#pragma once

#include <string>
#include <tuple>
#include <vector>

#include "mlx/array.h"
#include "mlx/utils.h"

namespace mlx::core {

std::pair<std::vector<std::vector<int>>, std::string> einsum_path(
    const std::string& subscripts,
    const std::vector<array>& operands);

array einsum(
    const std::string& subscripts,
    const std::vector<array>& operands,
    StreamOrDevice s = {});

} // namespace mlx::core
@@ -8,6 +8,7 @@
 #include "mlx/device.h"
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/ops.h"
+#include "mlx/einsum.h"
 #include "mlx/fast.h"
 #include "mlx/fft.h"
 #include "mlx/io.h"
@@ -444,8 +444,9 @@ array laplace(
   auto samples = uniform(low, high, shape, dtype, key, stream);
   // Use inverse CDF to generate Laplacian noise
   samples = multiply(
-      sign(samples),
-      log1p(multiply(array(-1.0f, dtype), abs(samples))),
+      sign(samples, stream),
+      log1p(
+          multiply(array(-1.0f, dtype), abs(samples, stream), stream), stream),
       stream);

   if (scale != 1.0) {
@@ -12,6 +12,7 @@
 #include <nanobind/stl/variant.h>
 #include <nanobind/stl/vector.h>

+#include "mlx/einsum.h"
 #include "mlx/ops.h"
 #include "mlx/utils.h"
 #include "python/src/load.h"
@@ -40,15 +41,6 @@ double scalar_to_double(Scalar s) {
 }

 void init_ops(nb::module_& m) {
-  // TODO, remove deprecation errors in a future release
-  m.def("block_sparse_mm", [](nb::args, nb::kwargs) {
-    throw std::invalid_argument(
-        "block_sparse_mm is deprecated. Please use gather_mm which has the same signature");
-  });
-  m.def("block_sparse_qmm", [](nb::args, nb::kwargs) {
-    throw std::invalid_argument(
-        "block_sparse_qmm is deprecated. Please use gather_qmm which has the same signature");
-  });
   m.def(
       "reshape",
       &reshape,
@@ -1238,7 +1230,8 @@ void init_ops(nb::module_& m) {
       a (array): Input array.

     Returns:
-      array: The unchanged input ``a`` but without gradient flowing
+      array:
+        The unchanged input ``a`` but without gradient flowing
         through it.
   )pbdoc");
   m.def(
@@ -2936,6 +2929,9 @@ void init_ops(nb::module_& m) {
       reverse (bool): Perform the cumulative sum in reverse.
       inclusive (bool): The i-th element of the output includes the i-th
         element of the input.
+
+    Returns:
+      array: The output array.
   )pbdoc");
   m.def(
       "cumprod",
@@ -2969,6 +2965,9 @@ void init_ops(nb::module_& m) {
       reverse (bool): Perform the cumulative product in reverse.
       inclusive (bool): The i-th element of the output includes the i-th
         element of the input.
+
+    Returns:
+      array: The output array.
   )pbdoc");
   m.def(
       "cummax",
@@ -3002,6 +3001,9 @@ void init_ops(nb::module_& m) {
       reverse (bool): Perform the cumulative maximum in reverse.
       inclusive (bool): The i-th element of the output includes the i-th
         element of the input.
+
+    Returns:
+      array: The output array.
   )pbdoc");
   m.def(
       "cummin",
@@ -3035,6 +3037,9 @@ void init_ops(nb::module_& m) {
       reverse (bool): Perform the cumulative minimum in reverse.
       inclusive (bool): The i-th element of the output includes the i-th
         element of the input.
+
+    Returns:
+      array: The output array.
   )pbdoc");
   m.def(
       "conj",
@@ -3052,6 +3057,9 @@ void init_ops(nb::module_& m) {

     Args:
       a (array): Input array
+
+    Returns:
+      array: The output array.
   )pbdoc");
   m.def(
       "conjugate",
@@ -3069,6 +3077,9 @@ void init_ops(nb::module_& m) {

     Args:
       a (array): Input array
+
+    Returns:
+      array: The output array.
   )pbdoc");
   m.def(
       "convolve",
@@ -3492,14 +3503,11 @@ void init_ops(nb::module_& m) {
     Args:
       file (file, str): File in which the array is saved.
-      format (str, optional): Format of the file. If ``None``, the
-        format
-        is inferred from the file extension. Supported formats:
-        ``npy``,
-        ``npz``, and ``safetensors``. Default: ``None``.
-      return_metadata (bool, optional): Load the metadata for formats
-        which
-        support matadata. The metadata will be returned as an
-        additional dictionary.
+      format (str, optional): Format of the file. If ``None``, the
+        format is inferred from the file extension. Supported formats:
+        ``npy``, ``npz``, and ``safetensors``. Default: ``None``.
+      return_metadata (bool, optional): Load the metadata for formats
+        which support metadata. The metadata will be returned as an
+        additional dictionary. Default: ``False``.

     Returns:
       array or dict:
         A single array if loading from a ``.npy`` file or a dict
@@ -3551,9 +3559,9 @@ void init_ops(nb::module_& m) {
     Args:
       file (file, str): File in which the array is saved.
-      arrays (dict(str, array)): The dictionary of names to arrays to
-        be saved. metadata (dict(str, Union[array, str, list(str)])):
-        The dictionary of
-        metadata to be saved. The values can be a scalar or 1D
-        obj:`array`, a :obj:`str`, or a :obj:`list` of :obj:`str`.
+      arrays (dict(str, array)): The dictionary of names to arrays to
+        be saved.
+      metadata (dict(str, Union[array, str, list(str)])): The dictionary
+        of metadata to be saved. The values can be a scalar or 1D
+        obj:`array`, a :obj:`str`, or a :obj:`list` of :obj:`str`.
   )pbdoc");
   m.def(
@@ -3643,11 +3651,11 @@ void init_ops(nb::module_& m) {
       biases (array): The biases to use per ``group_size`` elements of ``w``
       transpose (bool, optional): Defines whether to multiply with the
         transposed ``w`` or not, namely whether we are performing
-        ``x @ w.T`` or ``x @ w``. (default: ``True``)
+        ``x @ w.T`` or ``x @ w``. Default: ``True``.
       group_size (int, optional): The size of the group in ``w`` that
-        shares a scale and bias. (default: ``64``)
+        shares a scale and bias. Default: ``64``.
       bits (int, optional): The number of bits occupied by each element in
-        ``w``. (default: ``4``)
+        ``w``. Default: ``4``.

     Returns:
       array: The result of the multiplication of ``x`` with ``w``.
@@ -3700,9 +3708,9 @@ void init_ops(nb::module_& m) {
     Args:
       w (array): Matrix to be quantized
       group_size (int, optional): The size of the group in ``w`` that shares a
-        scale and bias. (default: ``64``)
+        scale and bias. Default: ``64``.
       bits (int, optional): The number of bits occupied by each element of
-        ``w`` in the returned quantized matrix. (default: ``4``)
+        ``w`` in the returned quantized matrix. Default: ``4``.

     Returns:
       tuple: A tuple containing
@@ -3740,9 +3748,9 @@ void init_ops(nb::module_& m) {
       scales (array): The scales to use per ``group_size`` elements of ``w``
       biases (array): The biases to use per ``group_size`` elements of ``w``
       group_size (int, optional): The size of the group in ``w`` that shares a
-        scale and bias. (default: ``64``)
+        scale and bias. Default: ``64``.
       bits (int, optional): The number of bits occupied by each element in
-        ``w``. (default: ``4``)
+        ``w``. Default: ``4``.

     Returns:
       array: The dequantized version of ``w``
@@ -3779,15 +3787,15 @@ void init_ops(nb::module_& m) {
       w (array): Quantized matrix packed in unsigned integers
       scales (array): The scales to use per ``group_size`` elements of ``w``
       biases (array): The biases to use per ``group_size`` elements of ``w``
-      lhs_indices (array, optional): Integer indices for ``x`` (default: ``None``)
-      rhs_indices (array, optional): Integer indices for ``w`` (default: ``None``)
+      lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
+      rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
       transpose (bool, optional): Defines whether to multiply with the
         transposed ``w`` or not, namely whether we are performing
-        ``x @ w.T`` or ``x @ w``. (default: ``True``)
+        ``x @ w.T`` or ``x @ w``. Default: ``True``.
       group_size (int, optional): The size of the group in ``w`` that
-        shares a scale and bias. (default: ``64``)
+        shares a scale and bias. Default: ``64``.
       bits (int, optional): The number of bits occupied by each element in
-        ``w``. (default: ``4``)
+        ``w``. Default: ``4``.

     Returns:
       array: The result of the multiplication of ``x`` with ``w``
@@ -3827,7 +3835,7 @@ void init_ops(nb::module_& m) {
         sum over. If an integer is provided, then sum over the last
         ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
         ``b``. If a list of lists is provided, then sum over the
-        corresponding dimensions of ``a`` and ``b``. (default: 2)
+        corresponding dimensions of ``a`` and ``b``. Default: 2.

     Returns:
       array: The tensor dot product.
@@ -3958,11 +3966,13 @@ void init_ops(nb::module_& m) {
     Args:
       a (array): Input array or scalar.
       b (array): Input array or scalar.
-      block_size (int): Size of blocks to be masked. Must be ``32`` or ``64`` (default: ``64``)
-      mask_out (array, optional): Mask for output (default: ``None``)
-      mask_lhs (array, optional): Mask for a (default: ``None``)
-      mask_rhs (array, optional): Mask for b (default: ``None``)
+      block_size (int): Size of blocks to be masked. Must be ``32`` or ``64``. Default: ``64``.
+      mask_out (array, optional): Mask for output. Default: ``None``.
+      mask_lhs (array, optional): Mask for ``a``. Default: ``None``.
+      mask_rhs (array, optional): Mask for ``b``. Default: ``None``.
+
+    Returns:
+      array: The output array.
   )pbdoc");
   m.def(
       "gather_mm",
@@ -3996,9 +4006,11 @@ void init_ops(nb::module_& m) {
     Args:
       a (array): Input array.
       b (array): Input array.
-      lhs_indices (array, optional): Integer indices for ``a`` (default: ``None``)
-      rhs_indices (array, optional): Integer indices for ``b`` (default: ``None``)
+      lhs_indices (array, optional): Integer indices for ``a``. Default: ``None``
+      rhs_indices (array, optional): Integer indices for ``b``. Default: ``None``
+
+    Returns:
+      array: The output array.
   )pbdoc");
   m.def(
       "diagonal",
@@ -4406,4 +4418,57 @@ void init_ops(nb::module_& m) {
     Returns:
       array: The transformed array.
   )pbdoc");
+  m.def(
+      "einsum_path",
+      [](const std::string& equation, const nb::args& operands) {
+        auto arrays_list = nb::cast<std::vector<array>>(operands);
+        auto [path, str] = einsum_path(equation, arrays_list);
+        // Convert to list of tuples
+        std::vector<nb::tuple> tuple_path;
+        for (auto& p : path) {
+          tuple_path.push_back(nb::tuple(nb::cast(p)));
+        }
+        return std::make_pair(tuple_path, str);
+      },
+      "subscripts"_a,
+      "operands"_a,
+      nb::sig("def einsum_path(subscripts: str, *operands)"),
+      R"pbdoc(
+        Compute the contraction order for the given Einstein summation.
+
+        Args:
+          subscripts (str): The Einstein summation convention equation.
+          *operands (array): The input arrays.
+
+        Returns:
+          tuple(list(tuple(int, int)), str):
+            The einsum path and a string containing information about the
+            chosen path.
+      )pbdoc");
+  m.def(
+      "einsum",
+      [](const std::string& subscripts,
+         const nb::args& operands,
+         StreamOrDevice s) {
+        auto arrays_list = nb::cast<std::vector<array>>(operands);
+        return einsum(subscripts, arrays_list, s);
+      },
+      "subscripts"_a,
+      "operands"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def einsum(subscripts: str, *operands, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Perform the Einstein summation convention on the operands.
+
+        Args:
+          subscripts (str): The Einstein summation convention equation.
+          *operands (array): The input arrays.
+
+        Returns:
+          array: The output array.
+      )pbdoc");
 }
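Taken together, the two bindings mirror NumPy's pair: inspect the contraction order first, then run the equation. For example, with the attention-shaped arrays from the benchmark above:

import mlx.core as mx

q = mx.random.uniform(shape=(8, 512, 32, 128))
path, info = mx.einsum_path("itjk,iujk->ijtu", q, q)
print(info)
scores = mx.einsum("itjk,iujk->ijtu", q, q)
print(scores.shape)  # (8, 32, 512, 512)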
@ -99,7 +99,7 @@ void init_random(nb::module_& parent_module) {
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
key (array): Input key to split.
|
key (array): Input key to split.
|
||||||
num (int, optional): Number of sub keys. Default is 2.
|
num (int, optional): Number of sub keys. Default: ``2``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
array: The array of sub keys with ``num`` as its first dimension.
|
array: The array of sub keys with ``num`` as its first dimension.
|
||||||
@@ -137,11 +137,13 @@ void init_random(nb::module_& parent_module) {
           broadcastable to ``shape``.

        Args:
-           low (scalar or array, optional): Lower bound of the distribution. Default is ``0``.
-           high (scalar or array, optional): Upper bound of the distribution. Default is ``1``.
-           shape (list(int), optional): Shape of the output. Default is ``()``.
-           key (array, optional): A PRNG key. Default: ``None``.
-           dtype (Dtype, optional): Type of the output. Default is ``float32``.
+           low (scalar or array, optional): Lower bound of the distribution.
+             Default: ``0``.
+           high (scalar or array, optional): Upper bound of the distribution.
+             Default: ``1``.
+           shape (list(int), optional): Shape of the output. Default: ``()``.
+           dtype (Dtype, optional): Type of the output. Default: ``float32``.
+           key (array, optional): A PRNG key. Default: ``None``.

        Returns:
            array: The output array random values.
@@ -250,9 +252,9 @@ void init_random(nb::module_& parent_module) {
        Args:
            low (scalar or array): Lower bound of the interval.
            high (scalar or array): Upper bound of the interval.
-           shape (list(int), optional): Shape of the output. Defaults to ``()``.
-           dtype (Dtype, optional): Type of the output. Defaults to ``int32``.
-           key (array, optional): A PRNG key. Default: None.
+           shape (list(int), optional): Shape of the output. Default: ``()``.
+           dtype (Dtype, optional): Type of the output. Default: ``int32``.
+           key (array, optional): A PRNG key. Default: ``None``.

        Returns:
            array: The array of random integers.
@@ -286,10 +288,10 @@ void init_random(nb::module_& parent_module) {

        Args:
            p (float or array, optional): Parameter of the Bernoulli
-             distribution. Default is 0.5.
-           shape (list(int), optional): Shape of the output. The default
-             shape is ``p.shape``.
-           key (array, optional): A PRNG key. Default: None.
+             distribution. Default: ``0.5``.
+           shape (list(int), optional): Shape of the output.
+             Default: ``p.shape``.
+           key (array, optional): A PRNG key. Default: ``None``.

        Returns:
            array: The array of random integers.
@@ -331,10 +333,10 @@ void init_random(nb::module_& parent_module) {
            lower (scalar or array): Lower bound of the domain.
            upper (scalar or array): Upper bound of the domain.
            shape (list(int), optional): The shape of the output.
-             Default is ``()``.
+             Default: ``()``.
            dtype (Dtype, optional): The data type of the output.
-             Default is ``float32``.
-           key (array, optional): A PRNG key. Default: None.
+             Default: ``float32``.
+           key (array, optional): A PRNG key. Default: ``None``.

        Returns:
            array: The output array of random values.
@@ -362,7 +364,7 @@ void init_random(nb::module_& parent_module) {

        Args:
            shape (list(int)): The shape of the output.
-           key (array, optional): A PRNG key. Default: None.
+           key (array, optional): A PRNG key. Default: ``None``.

        Returns:
            array: The :class:`array` with shape ``shape`` and
@@ -407,14 +409,14 @@ void init_random(nb::module_& parent_module) {
        Args:
            logits (array): The *unnormalized* categorical distribution(s).
            axis (int, optional): The axis which specifies the distribution.
-             Default is ``-1``.
+             Default: ``-1``.
            shape (list(int), optional): The shape of the output. This must
              be broadcast compatible with ``logits.shape`` with the ``axis``
              dimension removed. Default: ``None``
            num_samples (int, optional): The number of samples to draw from each
              of the categorical distributions in ``logits``. The output will have
              ``num_samples`` in the last dimension. Default: ``None``.
-           key (array, optional): A PRNG key. Default: None.
+           key (array, optional): A PRNG key. Default: ``None``.

        Returns:
            array: The ``shape``-sized output array with type ``uint32``.
@@ -442,11 +444,12 @@ void init_random(nb::module_& parent_module) {
        Sample numbers from a Laplace distribution.

        Args:
-           shape (list(int), optional): Shape of the output. Default is ``()``.
-           dtype (Dtype, optional): Type of the output. Default is ``float32``.
-           loc (float, optional): Mean of the distribution. Default is ``0.0``.
-           scale (float, optional): The scale "b" of the Laplace distribution. Default is ``1.0``.
-           key (array, optional): A PRNG key. Default: None.
+           shape (list(int), optional): Shape of the output. Default: ``()``.
+           dtype (Dtype, optional): Type of the output. Default: ``float32``.
+           loc (float, optional): Mean of the distribution. Default: ``0.0``.
+           scale (float, optional): The scale "b" of the Laplace distribution.
+             Default: ``1.0``.
+           key (array, optional): A PRNG key. Default: ``None``.

        Returns:
            array: The output array of random values.
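The touched docstrings all describe existing mlx.core.random functions; a
quick sketch exercising the documented signatures and defaults (the concrete
values and shapes are illustrative assumptions, not from the diff):

    import mlx.core as mx

    key = mx.random.key(0)       # an explicit PRNG key; key=None uses global state
    keys = mx.random.split(key)  # num defaults to 2, so keys has 2 sub keys

    u = mx.random.uniform(low=0, high=1, shape=(2, 3), key=keys[0])
    r = mx.random.randint(low=0, high=10, shape=(4,), dtype=mx.int32)
    b = mx.random.bernoulli(p=0.5, shape=(8,))
    t = mx.random.truncated_normal(lower=-1.0, upper=1.0, shape=(5,))
    c = mx.random.categorical(mx.zeros((3, 4)), axis=-1, num_samples=2)
    lap = mx.random.laplace(shape=(2, 2), loc=0.0, scale=1.0, key=keys[1])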
python/tests/test_einsum.py (new file, 318 lines)
@@ -0,0 +1,318 @@
# Copyright © 2024 Apple Inc.

import unittest

import mlx.core as mx
import mlx_tests
import numpy as np


class TestEinsum(mlx_tests.MLXTestCase):

    def test_simple_path(self):
        a = mx.zeros((5, 5))
        path = mx.einsum_path("ii", a)
        self.assertEqual(path[0], [(0,)])

        path = mx.einsum_path("ij->i", a)
        self.assertEqual(path[0], [(0,)])

        path = mx.einsum_path("ii->i", a)
        self.assertEqual(path[0], [(0,)])

        a = mx.zeros((5, 8))
        b = mx.zeros((8, 3))
        path = mx.einsum_path("ij,jk", a, b)
        self.assertEqual(path[0], [(0, 1)])
        path = mx.einsum_path("ij,jk -> ijk", a, b)
        self.assertEqual(path[0], [(0, 1)])

        a = mx.zeros((5, 8))
        b = mx.zeros((8, 3))
        c = mx.zeros((3, 7))
        path = mx.einsum_path("ij,jk,kl", a, b, c)
        self.assertEqual(path[0], [(0, 1), (0, 1)])

        a = mx.zeros((5, 8))
        b = mx.zeros((8, 10))
        c = mx.zeros((10, 7))
        path = mx.einsum_path("ij,jk,kl", a, b, c)
        self.assertEqual(path[0], [(1, 2), (0, 1)])

    def test_longer_paths(self):
        chars = "abcdefghijklmopqABC"
        sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4]
        dim_dict = {c: s for c, s in zip(chars, sizes)}
        cases = [
            "eb,cb,fb->cef",
            "dd,fb,be,cdb->cef",
            "dd,fb,be,cdb->cef",
            "bca,cdb,dbf,afc->",
            "dcc,fce,ea,dbf->ab",
            "dcc,fce,ea,dbf->ab",
        ]

        for case in cases:
            subscripts = case[: case.find("->")].split(",")
            inputs = []
            for s in subscripts:
                shape = [dim_dict[c] for c in s]
                inputs.append(np.ones(shape))
            np_path = np.einsum_path(case, *inputs)

            inputs = [mx.array(i) for i in inputs]
            mx_path = mx.einsum_path(case, *inputs)
            self.assertEqual(np_path[0][1:], mx_path[0])

    def test_simple_einsum(self):
        a = mx.arange(4 * 4).reshape(4, 4)
        a_mx = mx.einsum("ii->i", a)
        a_np = np.einsum("ii->i", a)
        self.assertTrue(np.array_equal(a_mx, a_np))

        a = mx.arange(2 * 2 * 2).reshape(2, 2, 2)
        a_mx = mx.einsum("iii->i", a)
        a_np = np.einsum("iii->i", a)
        self.assertTrue(np.array_equal(a_mx, a_np))

        a = mx.arange(2 * 2 * 3 * 3).reshape(2, 2, 3, 3)
        a_mx = mx.einsum("iijj->ij", a)
        a_np = np.einsum("iijj->ij", a)
        self.assertTrue(np.array_equal(a_mx, a_np))

        a = mx.arange(2 * 2 * 3 * 3).reshape(2, 3, 2, 3)
        a_mx = mx.einsum("ijij->ij", a)
        a_np = np.einsum("ijij->ij", a)
        self.assertTrue(np.array_equal(a_mx, a_np))

        # Test some simple reductions
        a = mx.arange(2 * 2).reshape(2, 2)
        a_mx = mx.einsum("ii", a)
        a_np = np.einsum("ii", a)
        self.assertTrue(np.array_equal(a_mx, a_np))

        a = mx.arange(2 * 4).reshape(2, 4)
        a_mx = mx.einsum("ij->", a)
        a_np = np.einsum("ij->", a)
        self.assertTrue(np.array_equal(a_mx, a_np))

        a = mx.arange(2 * 4).reshape(2, 4)
        a_mx = mx.einsum("ij->i", a)
        a_np = np.einsum("ij->i", a)
        self.assertTrue(np.array_equal(a_mx, a_np))

        a = mx.arange(2 * 4).reshape(2, 4)
        a_mx = mx.einsum("ij->j", a)
        a_np = np.einsum("ij->j", a)
        self.assertTrue(np.array_equal(a_mx, a_np))

        a = mx.arange(2 * 2 * 2).reshape(2, 2, 2)
        a_mx = mx.einsum("iii->", a)
        a_np = np.einsum("iii->", a)
        self.assertTrue(np.array_equal(a_mx, a_np))

        a = mx.arange(2 * 2 * 3 * 3).reshape(2, 3, 2, 3)
        a_mx = mx.einsum("ijij->j", a)
        a_np = np.einsum("ijij->j", a)
        self.assertTrue(np.array_equal(a_mx, a_np))

        # Test some simple transposes
        a = mx.arange(2 * 4).reshape(2, 4)
        a_mx = mx.einsum("ij", a)
        a_np = np.einsum("ij", a)
        self.assertTrue(np.array_equal(a_mx, a_np))

        a = mx.arange(2 * 4).reshape(2, 4)
        a_mx = mx.einsum("ij->ji", a)
        a_np = np.einsum("ij->ji", a)
        self.assertTrue(np.array_equal(a_mx, a_np))

        a = mx.arange(2 * 3 * 4).reshape(2, 3, 4)
        a_mx = mx.einsum("ijk->jki", a)
        a_np = np.einsum("ijk->jki", a)
        self.assertTrue(np.array_equal(a_mx, a_np))

    def test_two_input_einsum(self):

        # Matmul
        a = mx.full((2, 8), 1.0)
        b = mx.full((8, 2), 1.0)
        a_mx = mx.einsum("ik,kj", a, b)
        a_np = np.einsum("ik,kj", a, b)
        self.assertTrue(np.array_equal(a_mx, a_np))

        # Matmul + transpose
        a = mx.full((2, 8), 1.0)
        b = mx.full((8, 3), 1.0)
        a_mx = mx.einsum("ik,kj->ji", a, b)
        a_np = np.einsum("ik,kj->ji", a, b)
        self.assertTrue(np.array_equal(a_mx, a_np))

        # Inner product
        a = mx.full((4,), 1.0)
        b = mx.full((4,), 1.0)
        a_mx = mx.einsum("i,i", a, b)
        a_np = np.einsum("i,i", a, b)
        self.assertTrue(np.array_equal(a_mx, a_np))

        # Outer product
        a = mx.full((4,), 0.5)
        b = mx.full((6,), 2.0)
        a_mx = mx.einsum("i,j->ij", a, b)
        a_np = np.einsum("i,j->ij", a, b)
        self.assertTrue(np.array_equal(a_mx, a_np))

        # Elementwise multiply
        a = mx.full((2, 8), 1.0)
        b = mx.full((2, 8), 1.0)
        a_mx = mx.einsum("ij,ij->ij", a, b)
        a_np = np.einsum("ij,ij->ij", a, b)
        self.assertTrue(np.array_equal(a_mx, a_np))

        # Medley
        a = mx.full((2, 8, 3, 5), 1.0)
        b = mx.full((3, 7, 5, 2), 1.0)
        a_mx = mx.einsum("abcd,fgda->bfca", a, b)
        a_np = np.einsum("abcd,fgda->bfca", a, b)
        self.assertTrue(np.array_equal(a_mx, a_np))

    def test_sum_first(self):
        a = mx.full((5, 8), 1.0)
        b = mx.full((8, 2), 1.0)
        a_mx = mx.einsum("ab,bc->c", a, b)
        a_np = np.einsum("ab,bc->c", a, b)
        self.assertTrue(np.array_equal(a_mx, a_np))

    def test_broadcasting(self):
        a = mx.full((5, 1), 1.0)
        b = mx.full((8, 2), 1.0)
        a_mx = mx.einsum("ab,bc->c", a, b)
        return
        a_np = np.einsum("ab,bc->c", a, b)
        self.assertTrue(np.array_equal(a_mx, a_np))

        a = mx.random.uniform(shape=(5, 1, 3, 1))
        b = mx.random.uniform(shape=(1, 7, 1, 2))
        a_mx = mx.einsum("abcd,cdab->abcd", a, b)
        a_np = np.einsum("abcd,cdab->abcd", a, b)
        self.assertTrue(np.allclose(a_mx, a_np))

    def test_attention(self):
        q = mx.random.uniform(shape=(2, 3, 4, 5))
        k = mx.random.uniform(shape=(2, 3, 4, 5))
        v = mx.random.uniform(shape=(2, 3, 4, 5))

        s = mx.einsum("itjk,iujk->ijtu", q, k)
        out_mx = mx.einsum("ijtu,iujk->itjk", s, v)

        s = np.einsum("itjk,iujk->ijtu", q, k)
        out_np = np.einsum("ijtu,iujk->itjk", s, v)

        self.assertTrue(np.allclose(out_mx, out_np))

    def test_multi_input_einsum(self):
        a = mx.ones((3, 4, 5))
        out_mx = mx.einsum("ijk,lmk,ijf->lf", a, a, a)
        out_np = np.einsum("ijk,lmk,ijf->lf", a, a, a)
        self.assertTrue(np.allclose(out_mx, out_np))

    def test_opt_einsum_test_cases(self):
        # Test cases from
        # https://github.com/dgasmith/opt_einsum/blob/c826bb7df16f470a69f7bf90598fc27586209d11/opt_einsum/tests/test_contract.py#L11
        tests = [
            # Test Hadamard-like products
            "a,ab,abc->abc",
            "a,b,ab->ab",
            # Test index-transformations
            "ea,fb,gc,hd,abcd->efgh",
            "ea,fb,abcd,gc,hd->efgh",
            "abcd,ea,fb,gc,hd->efgh",
            # Test complex contractions
            "acdf,jbje,gihb,hfac,gfac,gifabc,hfac",
            "cd,bdhe,aidb,hgca,gc,hgibcd,hgac",
            "abhe,hidj,jgba,hiab,gab",
            "bde,cdh,agdb,hica,ibd,hgicd,hiac",
            "chd,bde,agbc,hiad,hgc,hgi,hiad",
            "chd,bde,agbc,hiad,bdi,cgh,agdb",
            "bdhe,acad,hiab,agac,hibd",
            # Test collapse
            "ab,ab,c->",
            "ab,ab,c->c",
            "ab,ab,cd,cd->",
            "ab,ab,cd,cd->ac",
            "ab,ab,cd,cd->cd",
            "ab,ab,cd,cd,ef,ef->",
            # Test outer products
            "ab,cd,ef->abcdef",
            "ab,cd,ef->acdf",
            "ab,cd,de->abcde",
            "ab,cd,de->be",
            "ab,bcd,cd->abcd",
            "ab,bcd,cd->abd",
            # Random test cases that have previously failed
            "eb,cb,fb->cef",
            "dd,fb,be,cdb->cef",
            "bca,cdb,dbf,afc->",
            "dcc,fce,ea,dbf->ab",
            "fdf,cdd,ccd,afe->ae",
            "abcd,ad",
            "ed,fcd,ff,bcf->be",
            "baa,dcf,af,cde->be",
            "bd,db,eac->ace",
            "fff,fae,bef,def->abd",
            "efc,dbc,acf,fd->abe",
            # Inner products
            "ab,ab",
            "ab,ba",
            "abc,abc",
            "abc,bac",
            "abc,cba",
            # GEMM test cases
            "ab,bc",
            "ab,cb",
            "ba,bc",
            "ba,cb",
            "abcd,cd",
            "abcd,ab",
            "abcd,cdef",
            "abcd,cdef->feba",
            "abcd,efdc",
            # Inner then dot
            "aab,bc->ac",
            "ab,bcc->ac",
            "aab,bcc->ac",
            "baa,bcc->ac",
            "aab,ccb->ac",
            # Randomly built test cases
            "aab,fa,df,ecc->bde",
            "ecb,fef,bad,ed->ac",
            "bcf,bbb,fbf,fc->",
            "bb,ff,be->e",
            "bcb,bb,fc,fff->",
            "fbb,dfd,fc,fc->",
            "afd,ba,cc,dc->bf",
            "adb,bc,fa,cfc->d",
            "bbd,bda,fc,db->acf",
            "dba,ead,cad->bce",
            "aef,fbc,dca->bde",
        ]

        size_dict = dict(zip("abcdefghij", [2, 3, 4, 5, 2, 3, 4, 5, 2, 3]))

        def inputs_for_case(test_case):
            inputs = test_case.split("->")[0].split(",")
            return [
                mx.random.uniform(shape=tuple(size_dict[c] for c in inp))
                for inp in inputs
            ]

        for test_case in tests:
            inputs = inputs_for_case(test_case)
            np_out = np.einsum(test_case, *inputs)
            mx_out = mx.einsum(test_case, *inputs)
            self.assertTrue(np.allclose(mx_out, np_out, rtol=1e-4, atol=1e-4))


if __name__ == "__main__":
    unittest.main()
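To make the path assertions above concrete: each (i, j) tuple names the
positions, within the current operand list, of the two arrays contracted at
that step; following NumPy's einsum_path convention (which the tests compare
against), the contracted operands are removed and the intermediate result is
appended to the list. A small worked trace using the shapes from
test_simple_path:

    import mlx.core as mx

    a = mx.zeros((5, 8))   # ij
    b = mx.zeros((8, 10))  # jk
    c = mx.zeros((10, 7))  # kl

    path, info = mx.einsum_path("ij,jk,kl", a, b, c)
    assert path == [(1, 2), (0, 1)]
    # Step 1: contract operands 1 and 2 (b and c), leaving [a, bc].
    # Step 2: contract operands 0 and 1 (a and bc) to form the result.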
@@ -26,6 +26,7 @@ target_sources(tests PRIVATE
   custom_vjp_tests.cpp
   creations_tests.cpp
   device_tests.cpp
+  einsum_tests.cpp
   eval_tests.cpp
   fft_tests.cpp
   load_tests.cpp
tests/einsum_tests.cpp (new file, 76 lines)
@@ -0,0 +1,76 @@
// Copyright © 2024 Apple Inc.

#include "doctest/doctest.h"
#include "mlx/mlx.h"

using namespace mlx::core;

TEST_CASE("test einsum path") {
  std::vector<std::vector<int>> expected = {{1, 2}, {0, 1}};
  auto path =
      einsum_path("ij,jk,kl", {ones({2, 2}), ones({2, 4}), ones({4, 2})}).first;
  CHECK_EQ(path, expected);

  expected = {{0}};
  path = einsum_path("jki", {ones({2, 3, 4})}).first;
  CHECK_EQ(path, expected);

  expected = {{0, 1}};
  path = einsum_path("i,i", {ones({2}), ones({1})}).first;
  CHECK_EQ(path, expected);

  expected = {{0, 1}};
  path = einsum_path("ij,jk", {ones({2, 2}), ones({2, 2})}).first;
  CHECK_EQ(path, expected);

  expected = {{0, 1}};
  path = einsum_path("ijk,jil->kl", {ones({3, 4, 5}), ones({4, 3, 2})}).first;
  CHECK_EQ(path, expected);

  expected = {{0, 3}, {1, 3}, {0, 2}, {0, 1}};
  path = einsum_path(
             "ijk,ilm,njm,nlk,abc->",
             {ones({2, 6, 8}),
              ones({2, 4, 5}),
              ones({3, 6, 5}),
              ones({3, 4, 8}),
              ones({9, 4, 7})})
             .first;
  CHECK_EQ(path, expected);

  expected = {{0, 2}, {0, 3}, {0, 2}, {0, 1}};
  path = einsum_path(
             "ea,fb,abcd,gc,hd->efgh",
             {ones({10, 10}),
              ones({10, 10}),
              ones({10, 10, 10, 10}),
              ones({10, 10}),
              ones({10, 10})})
             .first;
  CHECK_EQ(path, expected);
}

TEST_CASE("test einsum") {
  CHECK_THROWS(einsum("i,j", {array({1.0})}));
  CHECK_THROWS(einsum("ijk", {full({2, 2}, 2.0f)}));
  CHECK_THROWS(einsum("", {}));
  CHECK_THROWS(einsum("ij", {array({1, 2})}));
  CHECK_THROWS(einsum("", {array({1, 2})}));
  CHECK_THROWS(einsum("i,ij", {array({1, 2}), array({2, 3})}));
  CHECK_THROWS(einsum("i,i", {array({1, 2}), array({2, 3, 4})}));
  CHECK_THROWS(einsum("i->ii", {array({1, 2})}));
  CHECK_THROWS(einsum("12", {zeros({4, 4})}));
  CHECK_THROWS(einsum("ii->i", {zeros({3, 2})}));

  auto x = einsum("jki", {full({2, 3, 4}, 3.0f)});
  auto expected = full({4, 2, 3}, 3.0f);
  CHECK_EQ(allclose(x, expected).item<bool>(), true);

  x = einsum("ij,jk->ik", {full({2, 2}, 2.0f), full({2, 2}, 3.0f)});
  expected = array({12.0f, 12.0f, 12.0f, 12.0f}, {2, 2});
  CHECK_EQ(allclose(x, expected).item<bool>(), true);

  x = einsum("i,j->ij", {full({2}, 15.0f), full({4}, 20.0f)});
  expected = full({2, 4}, 300.0f);
  CHECK_EQ(allclose(x, expected).item<bool>(), true);
}