* einsum initial

* fix comma break

* sum axis was wrong

* small cleanups

* python binding

* changed bindings to resemble numpy

* remove todo comment

* comment changes

* add count of operands/inputs

* fail fast if operands list is empty

* ignore comma if no output

* einsum path matching numpy

* getting somewhere with path

* remove print

* it passes the first test

* moved einsum tests to separate file

* separated einsum path

* moved einsum naive

* remove space from equation

* fail fast if no operands passed

* update tests and remove printf

* small cleanup

* some more cleanups

* removed python helper file

* ack

* utilize std for finding min in vector

* duplicate def

* remove the tuple as it was unreadable

* moved einsum_naive back to ops

* remaining isn't needed

* avoid creating another set

* cleanup

* greedy path, start of naive einsum

* more einsum

* fix some bugs

* some more fixes, tests pass

* benchmark

* some simplify

* fix einsum and test

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

* add a bunch more tests and fix a bunch more bugs

* some docs nits

---------

Co-authored-by: dc-dc-dc <dgcruz983@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Author: Awni Hannun
Date: 2024-07-25 09:36:44 -07:00
Committed by: GitHub
Parent: 7f914365fd
Commit: baf9fa5f42

13 changed files with 1498 additions and 65 deletions

@@ -12,6 +12,7 @@
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include "mlx/einsum.h"
#include "mlx/ops.h"
#include "mlx/utils.h"
#include "python/src/load.h"
@@ -40,15 +41,6 @@ double scalar_to_double(Scalar s) {
}
void init_ops(nb::module_& m) {
// TODO, remove deprecation errors in a future release
m.def("block_sparse_mm", [](nb::args, nb::kwargs) {
throw std::invalid_argument(
"block_sparse_mm is deprecated. Please use gather_mm which has the same signature");
});
m.def("block_sparse_qmm", [](nb::args, nb::kwargs) {
throw std::invalid_argument(
"block_sparse_qmm is deprecated. Please use gather_qmm which has the same signature");
});
m.def(
"reshape",
&reshape,
@@ -1238,7 +1230,8 @@ void init_ops(nb::module_& m) {
a (array): Input array.
Returns:
array: The unchanged input ``a`` but without gradient flowing
array:
The unchanged input ``a`` but without gradient flowing
through it.
)pbdoc");
m.def(
@@ -2936,6 +2929,9 @@ void init_ops(nb::module_& m) {
reverse (bool): Perform the cumulative sum in reverse.
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
Returns:
array: The output array.
)pbdoc");
m.def(
"cumprod",
@@ -2969,6 +2965,9 @@ void init_ops(nb::module_& m) {
reverse (bool): Perform the cumulative product in reverse.
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
Returns:
array: The output array.
)pbdoc");
m.def(
"cummax",
@@ -3002,6 +3001,9 @@ void init_ops(nb::module_& m) {
reverse (bool): Perform the cumulative maximum in reverse.
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
Returns:
array: The output array.
)pbdoc");
m.def(
"cummin",
@@ -3035,6 +3037,9 @@ void init_ops(nb::module_& m) {
reverse (bool): Perform the cumulative minimum in reverse.
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
Returns:
array: The output array.
)pbdoc");
m.def(
"conj",
@@ -3052,6 +3057,9 @@ void init_ops(nb::module_& m) {
Args:
a (array): Input array
Returns:
array: The output array.
)pbdoc");
m.def(
"conjugate",
@@ -3069,6 +3077,9 @@ void init_ops(nb::module_& m) {
Args:
a (array): Input array
Returns:
array: The output array.
)pbdoc");
m.def(
"convolve",
@@ -3492,14 +3503,11 @@ void init_ops(nb::module_& m) {
Args:
file (file, str): File in which the array is saved.
format (str, optional): Format of the file. If ``None``, the
format
is inferred from the file extension. Supported formats:
``npy``,
``npz``, and ``safetensors``. Default: ``None``.
format is inferred from the file extension. Supported formats:
``npy``, ``npz``, and ``safetensors``. Default: ``None``.
return_metadata (bool, optional): Load the metadata for formats
which
support metadata. The metadata will be returned as an
additional dictionary.
which support metadata. The metadata will be returned as an
additional dictionary. Default: ``False``.
Returns:
array or dict:
A single array if loading from a ``.npy`` file or a dict
@@ -3551,9 +3559,9 @@ void init_ops(nb::module_& m) {
Args:
file (file, str): File in which the array is saved.
arrays (dict(str, array)): The dictionary of names to arrays to
be saved. metadata (dict(str, Union[array, str, list(str)])):
The dictionary of
metadata to be saved. The values can be a scalar or 1D
be saved.
metadata (dict(str, Union[array, str, list(str)])): The dictionary
of metadata to be saved. The values can be a scalar or 1D
obj:`array`, a :obj:`str`, or a :obj:`list` of :obj:`str`.
)pbdoc");
m.def(
@@ -3643,11 +3651,11 @@ void init_ops(nb::module_& m) {
biases (array): The biases to use per ``group_size`` elements of ``w``
transpose (bool, optional): Defines whether to multiply with the
transposed ``w`` or not, namely whether we are performing
``x @ w.T`` or ``x @ w``. (default: ``True``)
``x @ w.T`` or ``x @ w``. Default: ``True``.
group_size (int, optional): The size of the group in ``w`` that
shares a scale and bias. (default: ``64``)
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
``w``. Default: ``4``.
Returns:
array: The result of the multiplication of ``x`` with ``w``.
@@ -3700,9 +3708,9 @@ void init_ops(nb::module_& m) {
Args:
w (array): Matrix to be quantized
group_size (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: ``64``)
scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. (default: ``4``)
``w`` in the returned quantized matrix. Default: ``4``.
Returns:
tuple: A tuple containing
@@ -3740,9 +3748,9 @@ void init_ops(nb::module_& m) {
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
group_size (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: ``64``)
scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
``w``. Default: ``4``.
Returns:
array: The dequantized version of ``w``
@@ -3779,15 +3787,15 @@ void init_ops(nb::module_& m) {
w (array): Quantized matrix packed in unsigned integers
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
lhs_indices (array, optional): Integer indices for ``x`` (default: ``None``)
rhs_indices (array, optional): Integer indices for ``w`` (default: ``None``)
lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
transpose (bool, optional): Defines whether to multiply with the
transposed ``w`` or not, namely whether we are performing
``x @ w.T`` or ``x @ w``. (default: ``True``)
``x @ w.T`` or ``x @ w``. Default: ``True``.
group_size (int, optional): The size of the group in ``w`` that
shares a scale and bias. (default: ``64``)
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
``w``. Default: ``4``.
Returns:
array: The result of the multiplication of ``x`` with ``w``
@@ -3827,7 +3835,7 @@ void init_ops(nb::module_& m) {
sum over. If an integer is provided, then sum over the last
``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
``b``. If a list of lists is provided, then sum over the
corresponding dimensions of ``a`` and ``b``. (default: 2)
corresponding dimensions of ``a`` and ``b``. Default: 2.
Returns:
array: The tensor dot product.
@@ -3958,11 +3966,13 @@ void init_ops(nb::module_& m) {
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
block_size (int): Size of blocks to be masked. Must be ``32`` or ``64`` (default: ``64``)
mask_out (array, optional): Mask for output (default: ``None``)
mask_lhs (array, optional): Mask for a (default: ``None``)
mask_rhs (array, optional): Mask for b (default: ``None``)
block_size (int): Size of blocks to be masked. Must be ``32`` or ``64``. Default: ``64``.
mask_out (array, optional): Mask for output. Default: ``None``.
mask_lhs (array, optional): Mask for ``a``. Default: ``None``.
mask_rhs (array, optional): Mask for ``b``. Default: ``None``.
Returns:
array: The output array.
)pbdoc");
m.def(
"gather_mm",
@@ -3996,9 +4006,11 @@ void init_ops(nb::module_& m) {
Args:
a (array): Input array.
b (array): Input array.
lhs_indices (array, optional): Integer indices for ``a`` (default: ``None``)
rhs_indices (array, optional): Integer indices for ``b`` (default: ``None``)
lhs_indices (array, optional): Integer indices for ``a``. Default: ``None``.
rhs_indices (array, optional): Integer indices for ``b``. Default: ``None``.
Returns:
array: The output array.
)pbdoc");
m.def(
"diagonal",
@@ -4406,4 +4418,57 @@ void init_ops(nb::module_& m) {
Returns:
array: The transformed array.
)pbdoc");
m.def(
"einsum_path",
[](const std::string& equation, const nb::args& operands) {
auto arrays_list = nb::cast<std::vector<array>>(operands);
auto [path, str] = einsum_path(equation, arrays_list);
// Convert to list of tuples
std::vector<nb::tuple> tuple_path;
for (auto& p : path) {
tuple_path.push_back(nb::tuple(nb::cast(p)));
}
return std::make_pair(tuple_path, str);
},
"subscripts"_a,
"operands"_a,
nb::sig("def einsum_path(subscripts: str, *operands)"),
R"pbdoc(
Compute the contraction order for the given Einstein summation.
Args:
subscripts (str): The Einstein summation convention equation.
*operands (array): The input arrays.
Returns:
tuple(list(tuple(int, int)), str):
The einsum path and a string containing information about the
chosen path.
)pbdoc");
m.def(
"einsum",
[](const std::string& subscripts,
const nb::args& operands,
StreamOrDevice s) {
auto arrays_list = nb::cast<std::vector<array>>(operands);
return einsum(subscripts, arrays_list, s);
},
"subscripts"_a,
"operands"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def einsum(subscripts: str, *operands, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform the Einstein summation convention on the operands.
Args:
subscripts (str): The Einstein summation convention equation.
*operands (array): The input arrays.
Returns:
array: The output array.
)pbdoc");
}
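A brief usage sketch of the two bindings added above (illustrative only; the
path value shown matches what test_simple_path below asserts for a two-operand
contraction):

    import mlx.core as mx

    a = mx.ones((4, 8))
    b = mx.ones((8, 3))

    # Contract over the shared index k; equivalent to a @ b
    c = mx.einsum("ik,kj->ij", a, b)

    # Inspect the contraction order without running the contraction
    path, info = mx.einsum_path("ik,kj->ij", a, b)
    print(path)  # [(0, 1)]: a single pairwise contraction of operands 0 and 1
    print(info)  # human-readable description of the chosen path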

@@ -99,7 +99,7 @@ void init_random(nb::module_& parent_module) {
Args:
key (array): Input key to split.
num (int, optional): Number of sub keys. Default is 2.
num (int, optional): Number of sub keys. Default: ``2``.
Returns:
array: The array of sub keys with ``num`` as its first dimension.
@@ -137,11 +137,13 @@ void init_random(nb::module_& parent_module) {
broadcastable to ``shape``.
Args:
low (scalar or array, optional): Lower bound of the distribution. Default is ``0``.
high (scalar or array, optional): Upper bound of the distribution. Default is ``1``.
shape (list(int), optional): Shape of the output. Default is ``()``.
low (scalar or array, optional): Lower bound of the distribution.
Default: ``0``.
high (scalar or array, optional): Upper bound of the distribution.
Default: ``1``.
shape (list(int), optional): Shape of the output. Default: ``()``.
dtype (Dtype, optional): Type of the output. Default: ``float32``.
key (array, optional): A PRNG key. Default: ``None``.
dtype (Dtype, optional): Type of the output. Default is ``float32``.
Returns:
array: The output array random values.
@@ -250,9 +252,9 @@ void init_random(nb::module_& parent_module) {
Args:
low (scalar or array): Lower bound of the interval.
high (scalar or array): Upper bound of the interval.
shape (list(int), optional): Shape of the output. Defaults to ``()``.
dtype (Dtype, optional): Type of the output. Defaults to ``int32``.
key (array, optional): A PRNG key. Default: None.
shape (list(int), optional): Shape of the output. Default: ``()``.
dtype (Dtype, optional): Type of the output. Default: ``int32``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The array of random integers.
@@ -286,10 +288,10 @@ void init_random(nb::module_& parent_module) {
Args:
p (float or array, optional): Parameter of the Bernoulli
distribution. Default is 0.5.
shape (list(int), optional): Shape of the output. The default
shape is ``p.shape``.
key (array, optional): A PRNG key. Default: None.
distribution. Default: ``0.5``.
shape (list(int), optional): Shape of the output.
Default: ``p.shape``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The array of random integers.
@@ -331,10 +333,10 @@ void init_random(nb::module_& parent_module) {
lower (scalar or array): Lower bound of the domain.
upper (scalar or array): Upper bound of the domain.
shape (list(int), optional): The shape of the output.
Default is ``()``.
Default: ``()``.
dtype (Dtype, optional): The data type of the output.
Default is ``float32``.
key (array, optional): A PRNG key. Default: None.
Default: ``float32``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The output array of random values.
@@ -362,7 +364,7 @@ void init_random(nb::module_& parent_module) {
Args:
shape (list(int)): The shape of the output.
key (array, optional): A PRNG key. Default: None.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The :class:`array` with shape ``shape`` and
@@ -407,14 +409,14 @@ void init_random(nb::module_& parent_module) {
Args:
logits (array): The *unnormalized* categorical distribution(s).
axis (int, optional): The axis which specifies the distribution.
Default is ``-1``.
Default: ``-1``.
shape (list(int), optional): The shape of the output. This must
be broadcast compatible with ``logits.shape`` with the ``axis``
dimension removed. Default: ``None``.
num_samples (int, optional): The number of samples to draw from each
of the categorical distributions in ``logits``. The output will have
``num_samples`` in the last dimension. Default: ``None``.
key (array, optional): A PRNG key. Default: None.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The ``shape``-sized output array with type ``uint32``.
@@ -442,11 +444,12 @@ void init_random(nb::module_& parent_module) {
Sample numbers from a Laplace distribution.
Args:
shape (list(int), optional): Shape of the output. Default is ``()``.
dtype (Dtype, optional): Type of the output. Default is ``float32``.
loc (float, optional): Mean of the distribution. Default is ``0.0``.
scale (float, optional): The scale "b" of the Laplace distribution. Default is ``1.0``.
key (array, optional): A PRNG key. Default: None.
shape (list(int), optional): Shape of the output. Default: ``()``.
dtype (Dtype, optional): Type of the output. Default: ``float32``.
loc (float, optional): Mean of the distribution. Default: ``0.0``.
scale (float, optional): The scale "b" of the Laplace distribution.
Default: ``1.0``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The output array of random values.

python/tests/test_einsum.py (new file, 318 lines)

@@ -0,0 +1,318 @@
# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
import numpy as np
class TestEinsum(mlx_tests.MLXTestCase):
def test_simple_path(self):
a = mx.zeros((5, 5))
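# einsum_path returns (path, info); path lists the operand indices contracted at each step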
path = mx.einsum_path("ii", a)
self.assertEqual(path[0], [(0,)])
path = mx.einsum_path("ij->i", a)
self.assertEqual(path[0], [(0,)])
path = mx.einsum_path("ii->i", a)
self.assertEqual(path[0], [(0,)])
a = mx.zeros((5, 8))
b = mx.zeros((8, 3))
path = mx.einsum_path("ij,jk", a, b)
self.assertEqual(path[0], [(0, 1)])
path = mx.einsum_path("ij,jk -> ijk", a, b)
self.assertEqual(path[0], [(0, 1)])
a = mx.zeros((5, 8))
b = mx.zeros((8, 3))
c = mx.zeros((3, 7))
path = mx.einsum_path("ij,jk,kl", a, b, c)
self.assertEqual(path[0], [(0, 1), (0, 1)])
a = mx.zeros((5, 8))
b = mx.zeros((8, 10))
c = mx.zeros((10, 7))
path = mx.einsum_path("ij,jk,kl", a, b, c)
self.assertEqual(path[0], [(1, 2), (0, 1)])
def test_longer_paths(self):
chars = "abcdefghijklmopqABC"
sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4]
dim_dict = {c: s for c, s in zip(chars, sizes)}
cases = [
"eb,cb,fb->cef",
"dd,fb,be,cdb->cef",
"dd,fb,be,cdb->cef",
"bca,cdb,dbf,afc->",
"dcc,fce,ea,dbf->ab",
"dcc,fce,ea,dbf->ab",
]
for case in cases:
subscripts = case[: case.find("->")].split(",")
inputs = []
for s in subscripts:
shape = [dim_dict[c] for c in s]
inputs.append(np.ones(shape))
np_path = np.einsum_path(case, *inputs)
inputs = [mx.array(i) for i in inputs]
mx_path = mx.einsum_path(case, *inputs)
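# np.einsum_path()[0] begins with the literal string "einsum_path"; skip it before comparing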
self.assertEqual(np_path[0][1:], mx_path[0])
def test_simple_einsum(self):
a = mx.arange(4 * 4).reshape(4, 4)
a_mx = mx.einsum("ii->i", a)
a_np = np.einsum("ii->i", a)
self.assertTrue(np.array_equal(a_mx, a_np))
a = mx.arange(2 * 2 * 2).reshape(2, 2, 2)
a_mx = mx.einsum("iii->i", a)
a_np = np.einsum("iii->i", a)
self.assertTrue(np.array_equal(a_mx, a_np))
a = mx.arange(2 * 2 * 3 * 3).reshape(2, 2, 3, 3)
a_mx = mx.einsum("iijj->ij", a)
a_np = np.einsum("iijj->ij", a)
self.assertTrue(np.array_equal(a_mx, a_np))
a = mx.arange(2 * 2 * 3 * 3).reshape(2, 3, 2, 3)
a_mx = mx.einsum("ijij->ij", a)
a_np = np.einsum("ijij->ij", a)
self.assertTrue(np.array_equal(a_mx, a_np))
# Test some simple reductions
a = mx.arange(2 * 2).reshape(2, 2)
a_mx = mx.einsum("ii", a)
a_np = np.einsum("ii", a)
self.assertTrue(np.array_equal(a_mx, a_np))
a = mx.arange(2 * 4).reshape(2, 4)
a_mx = mx.einsum("ij->", a)
a_np = np.einsum("ij->", a)
self.assertTrue(np.array_equal(a_mx, a_np))
a = mx.arange(2 * 4).reshape(2, 4)
a_mx = mx.einsum("ij->i", a)
a_np = np.einsum("ij->i", a)
self.assertTrue(np.array_equal(a_mx, a_np))
a = mx.arange(2 * 4).reshape(2, 4)
a_mx = mx.einsum("ij->j", a)
a_np = np.einsum("ij->j", a)
self.assertTrue(np.array_equal(a_mx, a_np))
a = mx.arange(2 * 2 * 2).reshape(2, 2, 2)
a_mx = mx.einsum("iii->", a)
a_np = np.einsum("iii->", a)
self.assertTrue(np.array_equal(a_mx, a_np))
a = mx.arange(2 * 2 * 3 * 3).reshape(2, 3, 2, 3)
a_mx = mx.einsum("ijij->j", a)
a_np = np.einsum("ijij->j", a)
self.assertTrue(np.array_equal(a_mx, a_np))
# Test some simple transposes
a = mx.arange(2 * 4).reshape(2, 4)
a_mx = mx.einsum("ij", a)
a_np = np.einsum("ij", a)
self.assertTrue(np.array_equal(a_mx, a_np))
a = mx.arange(2 * 4).reshape(2, 4)
a_mx = mx.einsum("ij->ji", a)
a_np = np.einsum("ij->ji", a)
self.assertTrue(np.array_equal(a_mx, a_np))
a = mx.arange(2 * 3 * 4).reshape(2, 3, 4)
a_mx = mx.einsum("ijk->jki", a)
a_np = np.einsum("ijk->jki", a)
self.assertTrue(np.array_equal(a_mx, a_np))
def test_two_input_einsum(self):
# Matmul
a = mx.full((2, 8), 1.0)
b = mx.full((8, 2), 1.0)
a_mx = mx.einsum("ik,kj", a, b)
a_np = np.einsum("ik,kj", a, b)
self.assertTrue(np.array_equal(a_mx, a_np))
# Matmul + transpose
a = mx.full((2, 8), 1.0)
b = mx.full((8, 3), 1.0)
a_mx = mx.einsum("ik,kj->ji", a, b)
a_np = np.einsum("ik,kj->ji", a, b)
self.assertTrue(np.array_equal(a_mx, a_np))
# Inner product
a = mx.full((4,), 1.0)
b = mx.full((4,), 1.0)
a_mx = mx.einsum("i,i", a, b)
a_np = np.einsum("i,i", a, b)
self.assertTrue(np.array_equal(a_mx, a_np))
# Outer product
a = mx.full((4,), 0.5)
b = mx.full((6,), 2.0)
a_mx = mx.einsum("i,j->ij", a, b)
a_np = np.einsum("i,j->ij", a, b)
self.assertTrue(np.array_equal(a_mx, a_np))
# Elementwise multiply
a = mx.full((2, 8), 1.0)
b = mx.full((2, 8), 1.0)
a_mx = mx.einsum("ij,ij->ij", a, b)
a_np = np.einsum("ij,ij->ij", a, b)
self.assertTrue(np.array_equal(a_mx, a_np))
# Medley
a = mx.full((2, 8, 3, 5), 1.0)
b = mx.full((3, 7, 5, 2), 1.0)
a_mx = mx.einsum("abcd,fgda->bfca", a, b)
a_np = np.einsum("abcd,fgda->bfca", a, b)
self.assertTrue(np.array_equal(a_mx, a_np))
def test_sum_first(self):
a = mx.full((5, 8), 1.0)
b = mx.full((8, 2), 1.0)
a_mx = mx.einsum("ab,bc->c", a, b)
a_np = np.einsum("ab,bc->c", a, b)
self.assertTrue(np.array_equal(a_mx, a_np))
def test_broadcasting(self):
a = mx.full((5, 1), 1.0)
b = mx.full((8, 2), 1.0)
a_mx = mx.einsum("ab,bc->c", a, b)
return
a_np = np.einsum("ab,bc->c", a, b)
self.assertTrue(np.array_equal(a_mx, a_np))
a = mx.random.uniform(shape=(5, 1, 3, 1))
b = mx.random.uniform(shape=(1, 7, 1, 2))
a_mx = mx.einsum("abcd,cdab->abcd", a, b)
a_np = np.einsum("abcd,cdab->abcd", a, b)
self.assertTrue(np.allclose(a_mx, a_np))
def test_attention(self):
q = mx.random.uniform(shape=(2, 3, 4, 5))
k = mx.random.uniform(shape=(2, 3, 4, 5))
v = mx.random.uniform(shape=(2, 3, 4, 5))
s = mx.einsum("itjk,iujk->ijtu", q, k)
out_mx = mx.einsum("ijtu,iujk->itjk", s, v)
s = np.einsum("itjk,iujk->ijtu", q, k)
out_np = np.einsum("ijtu,iujk->itjk", s, v)
self.assertTrue(np.allclose(out_mx, out_np))
def test_multi_input_einsum(self):
a = mx.ones((3, 4, 5))
out_mx = mx.einsum("ijk,lmk,ijf->lf", a, a, a)
out_np = np.einsum("ijk,lmk,ijf->lf", a, a, a)
self.assertTrue(np.allclose(out_mx, out_np))
def test_opt_einsum_test_cases(self):
# Test cases from
# https://github.com/dgasmith/opt_einsum/blob/c826bb7df16f470a69f7bf90598fc27586209d11/opt_einsum/tests/test_contract.py#L11
tests = [
# Test hadamard-like products
"a,ab,abc->abc",
"a,b,ab->ab",
# Test index-transformations
"ea,fb,gc,hd,abcd->efgh",
"ea,fb,abcd,gc,hd->efgh",
"abcd,ea,fb,gc,hd->efgh",
# Test complex contractions
"acdf,jbje,gihb,hfac,gfac,gifabc,hfac",
"cd,bdhe,aidb,hgca,gc,hgibcd,hgac",
"abhe,hidj,jgba,hiab,gab",
"bde,cdh,agdb,hica,ibd,hgicd,hiac",
"chd,bde,agbc,hiad,hgc,hgi,hiad",
"chd,bde,agbc,hiad,bdi,cgh,agdb",
"bdhe,acad,hiab,agac,hibd",
# Test collapse
"ab,ab,c->",
"ab,ab,c->c",
"ab,ab,cd,cd->",
"ab,ab,cd,cd->ac",
"ab,ab,cd,cd->cd",
"ab,ab,cd,cd,ef,ef->",
# Test outer products
"ab,cd,ef->abcdef",
"ab,cd,ef->acdf",
"ab,cd,de->abcde",
"ab,cd,de->be",
"ab,bcd,cd->abcd",
"ab,bcd,cd->abd",
# Random test cases that have previously failed
"eb,cb,fb->cef",
"dd,fb,be,cdb->cef",
"bca,cdb,dbf,afc->",
"dcc,fce,ea,dbf->ab",
"fdf,cdd,ccd,afe->ae",
"abcd,ad",
"ed,fcd,ff,bcf->be",
"baa,dcf,af,cde->be",
"bd,db,eac->ace",
"fff,fae,bef,def->abd",
"efc,dbc,acf,fd->abe",
# Inner products
"ab,ab",
"ab,ba",
"abc,abc",
"abc,bac",
"abc,cba",
# GEMM test cases
"ab,bc",
"ab,cb",
"ba,bc",
"ba,cb",
"abcd,cd",
"abcd,ab",
"abcd,cdef",
"abcd,cdef->feba",
"abcd,efdc",
# Inner then dot
"aab,bc->ac",
"ab,bcc->ac",
"aab,bcc->ac",
"baa,bcc->ac",
"aab,ccb->ac",
# Randomly built test cases
"aab,fa,df,ecc->bde",
"ecb,fef,bad,ed->ac",
"bcf,bbb,fbf,fc->",
"bb,ff,be->e",
"bcb,bb,fc,fff->",
"fbb,dfd,fc,fc->",
"afd,ba,cc,dc->bf",
"adb,bc,fa,cfc->d",
"bbd,bda,fc,db->acf",
"dba,ead,cad->bce",
"aef,fbc,dca->bde",
]
size_dict = dict(zip("abcdefghij", [2, 3, 4, 5, 2, 3, 4, 5, 2, 3]))
def inputs_for_case(test_case):
inputs = test_case.split("->")[0].split(",")
return [
mx.random.uniform(shape=tuple(size_dict[c] for c in inp))
for inp in inputs
]
for test_case in tests:
inputs = inputs_for_case(test_case)
np_out = np.einsum(test_case, *inputs)
mx_out = mx.einsum(test_case, *inputs)
self.assertTrue(np.allclose(mx_out, np_out, rtol=1e-4, atol=1e-4))
if __name__ == "__main__":
unittest.main()