mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Einsum (#1269)
* einsum initial * fix comma break * sum axis was wrong * small cleanups * python binding * changed bindings to resemble numpy * remove todo comment * comment changes * add count of operands/inputs * fail fast if operands list is empty * ignore comma if no output * einsum path matching numpy * getting somewhere with path * remove print * it passes the first test * moved einsum tests to seperate file * seperated einsum path * moved einsum naive * remove space from equation * fast fail if no operands passed * update tests and remove printf * small cleanup * some more cleanups * removed python helper file * ack * utilize std for finding min in vector * duplicate def * remove the tuple as it was unreadable * moved einsum_naive back to ops * remaining isn't needed * avoid creating another set * cleanup * greedy path, start of naive einsum * more einsum * fix some bugs * some more fixes, tests pass * benchmark * some simplify * fix einsum and test Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com> * add a bunch more tests and fix a bunch more bugs * some docs nits --------- Co-authored-by: dc-dc-dc <dgcruz983@gmail.com> Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include "mlx/einsum.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/load.h"
|
||||
@@ -40,15 +41,6 @@ double scalar_to_double(Scalar s) {
|
||||
}
|
||||
|
||||
void init_ops(nb::module_& m) {
|
||||
// TODO, remove deprecation errors in a future release
|
||||
m.def("block_sparse_mm", [](nb::args, nb::kwargs) {
|
||||
throw std::invalid_argument(
|
||||
"block_sparse_mm is deprecated. Please use gather_mm which has the same signature");
|
||||
});
|
||||
m.def("block_sparse_qmm", [](nb::args, nb::kwargs) {
|
||||
throw std::invalid_argument(
|
||||
"block_sparse_qmm is deprecated. Please use gather_qmm which has the same signature");
|
||||
});
|
||||
m.def(
|
||||
"reshape",
|
||||
&reshape,
|
||||
@@ -1238,7 +1230,8 @@ void init_ops(nb::module_& m) {
|
||||
a (array): Input array.
|
||||
|
||||
Returns:
|
||||
array: The unchanged input ``a`` but without gradient flowing
|
||||
array:
|
||||
The unchanged input ``a`` but without gradient flowing
|
||||
through it.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
@@ -2936,6 +2929,9 @@ void init_ops(nb::module_& m) {
|
||||
reverse (bool): Perform the cumulative sum in reverse.
|
||||
inclusive (bool): The i-th element of the output includes the i-th
|
||||
element of the input.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"cumprod",
|
||||
@@ -2969,6 +2965,9 @@ void init_ops(nb::module_& m) {
|
||||
reverse (bool): Perform the cumulative product in reverse.
|
||||
inclusive (bool): The i-th element of the output includes the i-th
|
||||
element of the input.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"cummax",
|
||||
@@ -3002,6 +3001,9 @@ void init_ops(nb::module_& m) {
|
||||
reverse (bool): Perform the cumulative maximum in reverse.
|
||||
inclusive (bool): The i-th element of the output includes the i-th
|
||||
element of the input.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"cummin",
|
||||
@@ -3035,6 +3037,9 @@ void init_ops(nb::module_& m) {
|
||||
reverse (bool): Perform the cumulative minimum in reverse.
|
||||
inclusive (bool): The i-th element of the output includes the i-th
|
||||
element of the input.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"conj",
|
||||
@@ -3052,6 +3057,9 @@ void init_ops(nb::module_& m) {
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"conjugate",
|
||||
@@ -3069,6 +3077,9 @@ void init_ops(nb::module_& m) {
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"convolve",
|
||||
@@ -3492,14 +3503,11 @@ void init_ops(nb::module_& m) {
|
||||
Args:
|
||||
file (file, str): File in which the array is saved.
|
||||
format (str, optional): Format of the file. If ``None``, the
|
||||
format
|
||||
is inferred from the file extension. Supported formats:
|
||||
``npy``,
|
||||
``npz``, and ``safetensors``. Default: ``None``.
|
||||
format is inferred from the file extension. Supported formats:
|
||||
``npy``, ``npz``, and ``safetensors``. Default: ``None``.
|
||||
return_metadata (bool, optional): Load the metadata for formats
|
||||
which
|
||||
support matadata. The metadata will be returned as an
|
||||
additional dictionary.
|
||||
which support matadata. The metadata will be returned as an
|
||||
additional dictionary. Default: ``False``.
|
||||
Returns:
|
||||
array or dict:
|
||||
A single array if loading from a ``.npy`` file or a dict
|
||||
@@ -3551,9 +3559,9 @@ void init_ops(nb::module_& m) {
|
||||
Args:
|
||||
file (file, str): File in which the array is saved.
|
||||
arrays (dict(str, array)): The dictionary of names to arrays to
|
||||
be saved. metadata (dict(str, Union[array, str, list(str)])):
|
||||
The dictionary of
|
||||
metadata to be saved. The values can be a scalar or 1D
|
||||
be saved.
|
||||
metadata (dict(str, Union[array, str, list(str)])): The dictionary
|
||||
of metadata to be saved. The values can be a scalar or 1D
|
||||
obj:`array`, a :obj:`str`, or a :obj:`list` of :obj:`str`.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
@@ -3643,11 +3651,11 @@ void init_ops(nb::module_& m) {
|
||||
biases (array): The biases to use per ``group_size`` elements of ``w``
|
||||
transpose (bool, optional): Defines whether to multiply with the
|
||||
transposed ``w`` or not, namely whether we are performing
|
||||
``x @ w.T`` or ``x @ w``. (default: ``True``)
|
||||
``x @ w.T`` or ``x @ w``. Default: ``True``.
|
||||
group_size (int, optional): The size of the group in ``w`` that
|
||||
shares a scale and bias. (default: ``64``)
|
||||
shares a scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. (default: ``4``)
|
||||
``w``. Default: ``4``.
|
||||
|
||||
Returns:
|
||||
array: The result of the multiplication of ``x`` with ``w``.
|
||||
@@ -3700,9 +3708,9 @@ void init_ops(nb::module_& m) {
|
||||
Args:
|
||||
w (array): Matrix to be quantized
|
||||
group_size (int, optional): The size of the group in ``w`` that shares a
|
||||
scale and bias. (default: ``64``)
|
||||
scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element of
|
||||
``w`` in the returned quantized matrix. (default: ``4``)
|
||||
``w`` in the returned quantized matrix. Default: ``4``.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing
|
||||
@@ -3740,9 +3748,9 @@ void init_ops(nb::module_& m) {
|
||||
scales (array): The scales to use per ``group_size`` elements of ``w``
|
||||
biases (array): The biases to use per ``group_size`` elements of ``w``
|
||||
group_size (int, optional): The size of the group in ``w`` that shares a
|
||||
scale and bias. (default: ``64``)
|
||||
scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. (default: ``4``)
|
||||
``w``. Default: ``4``.
|
||||
|
||||
Returns:
|
||||
array: The dequantized version of ``w``
|
||||
@@ -3779,15 +3787,15 @@ void init_ops(nb::module_& m) {
|
||||
w (array): Quantized matrix packed in unsigned integers
|
||||
scales (array): The scales to use per ``group_size`` elements of ``w``
|
||||
biases (array): The biases to use per ``group_size`` elements of ``w``
|
||||
lhs_indices (array, optional): Integer indices for ``x`` (default: ``None``)
|
||||
rhs_indices (array, optional): Integer indices for ``w`` (default: ``None``)
|
||||
lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
|
||||
rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
|
||||
transpose (bool, optional): Defines whether to multiply with the
|
||||
transposed ``w`` or not, namely whether we are performing
|
||||
``x @ w.T`` or ``x @ w``. (default: ``True``)
|
||||
``x @ w.T`` or ``x @ w``. Default: ``True``.
|
||||
group_size (int, optional): The size of the group in ``w`` that
|
||||
shares a scale and bias. (default: ``64``)
|
||||
shares a scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. (default: ``4``)
|
||||
``w``. Default: ``4``.
|
||||
|
||||
Returns:
|
||||
array: The result of the multiplication of ``x`` with ``w``
|
||||
@@ -3827,7 +3835,7 @@ void init_ops(nb::module_& m) {
|
||||
sum over. If an integer is provided, then sum over the last
|
||||
``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
|
||||
``b``. If a list of lists is provided, then sum over the
|
||||
corresponding dimensions of ``a`` and ``b``. (default: 2)
|
||||
corresponding dimensions of ``a`` and ``b``. Default: 2.
|
||||
|
||||
Returns:
|
||||
array: The tensor dot product.
|
||||
@@ -3958,11 +3966,13 @@ void init_ops(nb::module_& m) {
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
block_size (int): Size of blocks to be masked. Must be ``32`` or ``64`` (default: ``64``)
|
||||
mask_out (array, optional): Mask for output (default: ``None``)
|
||||
mask_lhs (array, optional): Mask for a (default: ``None``)
|
||||
mask_rhs (array, optional): Mask for b (default: ``None``)
|
||||
block_size (int): Size of blocks to be masked. Must be ``32`` or ``64``. Default: ``64``.
|
||||
mask_out (array, optional): Mask for output. Default: ``None``.
|
||||
mask_lhs (array, optional): Mask for ``a``. Default: ``None``.
|
||||
mask_rhs (array, optional): Mask for ``b``. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"gather_mm",
|
||||
@@ -3996,9 +4006,11 @@ void init_ops(nb::module_& m) {
|
||||
Args:
|
||||
a (array): Input array.
|
||||
b (array): Input array.
|
||||
lhs_indices (array, optional): Integer indices for ``a`` (default: ``None``)
|
||||
rhs_indices (array, optional): Integer indices for ``b`` (default: ``None``)
|
||||
lhs_indices (array, optional): Integer indices for ``a``. Default: ``None``
|
||||
rhs_indices (array, optional): Integer indices for ``b``. Default: ``None``
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"diagonal",
|
||||
@@ -4406,4 +4418,57 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The transformed array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"einsum_path",
|
||||
[](const std::string& equation, const nb::args& operands) {
|
||||
auto arrays_list = nb::cast<std::vector<array>>(operands);
|
||||
auto [path, str] = einsum_path(equation, arrays_list);
|
||||
// Convert to list of tuples
|
||||
std::vector<nb::tuple> tuple_path;
|
||||
for (auto& p : path) {
|
||||
tuple_path.push_back(nb::tuple(nb::cast(p)));
|
||||
}
|
||||
return std::make_pair(tuple_path, str);
|
||||
},
|
||||
"subscripts"_a,
|
||||
"operands"_a,
|
||||
nb::sig("def einsum_path(subscripts: str, *operands)"),
|
||||
R"pbdoc(
|
||||
|
||||
Compute the contraction order for the given Einstein summation.
|
||||
|
||||
Args:
|
||||
subscripts (str): The Einstein summation convention equation.
|
||||
*operands (array): The input arrays.
|
||||
|
||||
Returns:
|
||||
tuple(list(tuple(int, int)), str):
|
||||
The einsum path and a string containing information about the
|
||||
chosen path.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"einsum",
|
||||
[](const std::string& subscripts,
|
||||
const nb::args& operands,
|
||||
StreamOrDevice s) {
|
||||
auto arrays_list = nb::cast<std::vector<array>>(operands);
|
||||
return einsum(subscripts, arrays_list, s);
|
||||
},
|
||||
"subscripts"_a,
|
||||
"operands"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def einsum(subscripts: str, *operands, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
|
||||
Perform the Einstein summation convention on the operands.
|
||||
|
||||
Args:
|
||||
subscripts (str): The Einstein summation convention equation.
|
||||
*operands (array): The input arrays.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
}
|
||||
|
@@ -99,7 +99,7 @@ void init_random(nb::module_& parent_module) {
|
||||
|
||||
Args:
|
||||
key (array): Input key to split.
|
||||
num (int, optional): Number of sub keys. Default is 2.
|
||||
num (int, optional): Number of sub keys. Default: ``2``.
|
||||
|
||||
Returns:
|
||||
array: The array of sub keys with ``num`` as its first dimension.
|
||||
@@ -137,11 +137,13 @@ void init_random(nb::module_& parent_module) {
|
||||
broadcastable to ``shape``.
|
||||
|
||||
Args:
|
||||
low (scalar or array, optional): Lower bound of the distribution. Default is ``0``.
|
||||
high (scalar or array, optional): Upper bound of the distribution. Default is ``1``.
|
||||
shape (list(int), optional): Shape of the output. Default is ``()``.
|
||||
low (scalar or array, optional): Lower bound of the distribution.
|
||||
Default: ``0``.
|
||||
high (scalar or array, optional): Upper bound of the distribution.
|
||||
Default: ``1``.
|
||||
shape (list(int), optional): Shape of the output. Default:``()``.
|
||||
dtype (Dtype, optional): Type of the output. Default: ``float32``.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
dtype (Dtype, optional): Type of the output. Default is ``float32``.
|
||||
|
||||
Returns:
|
||||
array: The output array random values.
|
||||
@@ -250,9 +252,9 @@ void init_random(nb::module_& parent_module) {
|
||||
Args:
|
||||
low (scalar or array): Lower bound of the interval.
|
||||
high (scalar or array): Upper bound of the interval.
|
||||
shape (list(int), optional): Shape of the output. Defaults to ``()``.
|
||||
dtype (Dtype, optional): Type of the output. Defaults to ``int32``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
shape (list(int), optional): Shape of the output. Default: ``()``.
|
||||
dtype (Dtype, optional): Type of the output. Default: ``int32``.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The array of random integers.
|
||||
@@ -286,10 +288,10 @@ void init_random(nb::module_& parent_module) {
|
||||
|
||||
Args:
|
||||
p (float or array, optional): Parameter of the Bernoulli
|
||||
distribution. Default is 0.5.
|
||||
shape (list(int), optional): Shape of the output. The default
|
||||
shape is ``p.shape``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
distribution. Default: ``0.5``.
|
||||
shape (list(int), optional): Shape of the output.
|
||||
Default: ``p.shape``.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The array of random integers.
|
||||
@@ -331,10 +333,10 @@ void init_random(nb::module_& parent_module) {
|
||||
lower (scalar or array): Lower bound of the domain.
|
||||
upper (scalar or array): Upper bound of the domain.
|
||||
shape (list(int), optional): The shape of the output.
|
||||
Default is ``()``.
|
||||
Default:``()``.
|
||||
dtype (Dtype, optional): The data type of the output.
|
||||
Default is ``float32``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
Default: ``float32``.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The output array of random values.
|
||||
@@ -362,7 +364,7 @@ void init_random(nb::module_& parent_module) {
|
||||
|
||||
Args:
|
||||
shape (list(int)): The shape of the output.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The :class:`array` with shape ``shape`` and
|
||||
@@ -407,14 +409,14 @@ void init_random(nb::module_& parent_module) {
|
||||
Args:
|
||||
logits (array): The *unnormalized* categorical distribution(s).
|
||||
axis (int, optional): The axis which specifies the distribution.
|
||||
Default is ``-1``.
|
||||
Default: ``-1``.
|
||||
shape (list(int), optional): The shape of the output. This must
|
||||
be broadcast compatable with ``logits.shape`` with the ``axis``
|
||||
dimension removed. Default: ``None``
|
||||
num_samples (int, optional): The number of samples to draw from each
|
||||
of the categorical distributions in ``logits``. The output will have
|
||||
``num_samples`` in the last dimension. Default: ``None``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The ``shape``-sized output array with type ``uint32``.
|
||||
@@ -442,11 +444,12 @@ void init_random(nb::module_& parent_module) {
|
||||
Sample numbers from a Laplace distribution.
|
||||
|
||||
Args:
|
||||
shape (list(int), optional): Shape of the output. Default is ``()``.
|
||||
dtype (Dtype, optional): Type of the output. Default is ``float32``.
|
||||
loc (float, optional): Mean of the distribution. Default is ``0.0``.
|
||||
scale (float, optional): The scale "b" of the Laplace distribution. Default is ``1.0``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
shape (list(int), optional): Shape of the output. Default: ``()``.
|
||||
dtype (Dtype, optional): Type of the output. Default: ``float32``.
|
||||
loc (float, optional): Mean of the distribution. Default: ``0.0``.
|
||||
scale (float, optional): The scale "b" of the Laplace distribution.
|
||||
Default:``1.0``.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The output array of random values.
|
||||
|
318
python/tests/test_einsum.py
Normal file
318
python/tests/test_einsum.py
Normal file
@@ -0,0 +1,318 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestEinsum(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_simple_path(self):
|
||||
a = mx.zeros((5, 5))
|
||||
path = mx.einsum_path("ii", a)
|
||||
self.assertEqual(path[0], [(0,)])
|
||||
|
||||
path = mx.einsum_path("ij->i", a)
|
||||
self.assertEqual(path[0], [(0,)])
|
||||
|
||||
path = mx.einsum_path("ii->i", a)
|
||||
self.assertEqual(path[0], [(0,)])
|
||||
|
||||
a = mx.zeros((5, 8))
|
||||
b = mx.zeros((8, 3))
|
||||
path = mx.einsum_path("ij,jk", a, b)
|
||||
self.assertEqual(path[0], [(0, 1)])
|
||||
path = mx.einsum_path("ij,jk -> ijk", a, b)
|
||||
self.assertEqual(path[0], [(0, 1)])
|
||||
|
||||
a = mx.zeros((5, 8))
|
||||
b = mx.zeros((8, 3))
|
||||
c = mx.zeros((3, 7))
|
||||
path = mx.einsum_path("ij,jk,kl", a, b, c)
|
||||
|
||||
self.assertEqual(path[0], [(0, 1), (0, 1)])
|
||||
|
||||
a = mx.zeros((5, 8))
|
||||
b = mx.zeros((8, 10))
|
||||
c = mx.zeros((10, 7))
|
||||
path = mx.einsum_path("ij,jk,kl", a, b, c)
|
||||
self.assertEqual(path[0], [(1, 2), (0, 1)])
|
||||
|
||||
def test_longer_paths(self):
|
||||
chars = "abcdefghijklmopqABC"
|
||||
sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4]
|
||||
dim_dict = {c: s for c, s in zip(chars, sizes)}
|
||||
cases = [
|
||||
"eb,cb,fb->cef",
|
||||
"dd,fb,be,cdb->cef",
|
||||
"dd,fb,be,cdb->cef",
|
||||
"bca,cdb,dbf,afc->",
|
||||
"dcc,fce,ea,dbf->ab",
|
||||
"dcc,fce,ea,dbf->ab",
|
||||
]
|
||||
|
||||
for case in cases:
|
||||
subscripts = case[: case.find("->")].split(",")
|
||||
inputs = []
|
||||
for s in subscripts:
|
||||
shape = [dim_dict[c] for c in s]
|
||||
inputs.append(np.ones(shape))
|
||||
np_path = np.einsum_path(case, *inputs)
|
||||
|
||||
inputs = [mx.array(i) for i in inputs]
|
||||
mx_path = mx.einsum_path(case, *inputs)
|
||||
self.assertEqual(np_path[0][1:], mx_path[0])
|
||||
|
||||
def test_simple_einsum(self):
|
||||
a = mx.arange(4 * 4).reshape(4, 4)
|
||||
a_mx = mx.einsum("ii->i", a)
|
||||
a_np = np.einsum("ii->i", a)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
a = mx.arange(2 * 2 * 2).reshape(2, 2, 2)
|
||||
a_mx = mx.einsum("iii->i", a)
|
||||
a_np = np.einsum("iii->i", a)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
a = mx.arange(2 * 2 * 3 * 3).reshape(2, 2, 3, 3)
|
||||
a_mx = mx.einsum("iijj->ij", a)
|
||||
a_np = np.einsum("iijj->ij", a)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
a = mx.arange(2 * 2 * 3 * 3).reshape(2, 3, 2, 3)
|
||||
a_mx = mx.einsum("ijij->ij", a)
|
||||
a_np = np.einsum("ijij->ij", a)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
# Test some simple reductions
|
||||
a = mx.arange(2 * 2).reshape(2, 2)
|
||||
a_mx = mx.einsum("ii", a)
|
||||
a_np = np.einsum("ii", a)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
a = mx.arange(2 * 4).reshape(2, 4)
|
||||
a_mx = mx.einsum("ij->", a)
|
||||
a_np = np.einsum("ij->", a)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
a = mx.arange(2 * 4).reshape(2, 4)
|
||||
a_mx = mx.einsum("ij->i", a)
|
||||
a_np = np.einsum("ij->i", a)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
a = mx.arange(2 * 4).reshape(2, 4)
|
||||
a_mx = mx.einsum("ij->j", a)
|
||||
a_np = np.einsum("ij->j", a)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
a = mx.arange(2 * 2 * 2).reshape(2, 2, 2)
|
||||
a_mx = mx.einsum("iii->", a)
|
||||
a_np = np.einsum("iii->", a)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
a = mx.arange(2 * 2 * 3 * 3).reshape(2, 3, 2, 3)
|
||||
a_mx = mx.einsum("ijij->j", a)
|
||||
a_np = np.einsum("ijij->j", a)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
# Test some simple transposes
|
||||
a = mx.arange(2 * 4).reshape(2, 4)
|
||||
a_mx = mx.einsum("ij", a)
|
||||
a_np = np.einsum("ij", a)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
a = mx.arange(2 * 4).reshape(2, 4)
|
||||
a_mx = mx.einsum("ij->ji", a)
|
||||
a_np = np.einsum("ij->ji", a)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
a = mx.arange(2 * 3 * 4).reshape(2, 3, 4)
|
||||
a_mx = mx.einsum("ijk->jki", a)
|
||||
a_np = np.einsum("ijk->jki", a)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
def test_two_input_einsum(self):
|
||||
|
||||
# Matmul
|
||||
a = mx.full((2, 8), 1.0)
|
||||
b = mx.full((8, 2), 1.0)
|
||||
a_mx = mx.einsum("ik,kj", a, b)
|
||||
a_np = np.einsum("ik,kj", a, b)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
# Matmul + transpose
|
||||
a = mx.full((2, 8), 1.0)
|
||||
b = mx.full((8, 3), 1.0)
|
||||
a_mx = mx.einsum("ik,kj->ji", a, b)
|
||||
a_np = np.einsum("ik,kj->ji", a, b)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
# Inner product
|
||||
a = mx.full((4,), 1.0)
|
||||
b = mx.full((4,), 1.0)
|
||||
a_mx = mx.einsum("i,i", a, b)
|
||||
a_np = np.einsum("i,i", a, b)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
# Outer product
|
||||
a = mx.full((4,), 0.5)
|
||||
b = mx.full((6,), 2.0)
|
||||
a_mx = mx.einsum("i,j->ij", a, b)
|
||||
a_np = np.einsum("i,j->ij", a, b)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
# Elementwise multiply
|
||||
a = mx.full((2, 8), 1.0)
|
||||
b = mx.full((2, 8), 1.0)
|
||||
a_mx = mx.einsum("ij,ij->ij", a, b)
|
||||
a_np = np.einsum("ij,ij->ij", a, b)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
# Medley
|
||||
a = mx.full((2, 8, 3, 5), 1.0)
|
||||
b = mx.full((3, 7, 5, 2), 1.0)
|
||||
a_mx = mx.einsum("abcd,fgda->bfca", a, b)
|
||||
a_np = np.einsum("abcd,fgda->bfca", a, b)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
def test_sum_first(self):
|
||||
a = mx.full((5, 8), 1.0)
|
||||
b = mx.full((8, 2), 1.0)
|
||||
a_mx = mx.einsum("ab,bc->c", a, b)
|
||||
a_np = np.einsum("ab,bc->c", a, b)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
def test_broadcasting(self):
|
||||
a = mx.full((5, 1), 1.0)
|
||||
b = mx.full((8, 2), 1.0)
|
||||
a_mx = mx.einsum("ab,bc->c", a, b)
|
||||
return
|
||||
a_np = np.einsum("ab,bc->c", a, b)
|
||||
self.assertTrue(np.array_equal(a_mx, a_np))
|
||||
|
||||
a = mx.random.uniform(shape=(5, 1, 3, 1))
|
||||
b = mx.random.uniform(shape=(1, 7, 1, 2))
|
||||
a_mx = mx.einsum("abcd,cdab->abcd", a, b)
|
||||
a_np = np.einsum("abcd,cdab->abcd", a, b)
|
||||
self.assertTrue(np.allclose(a_mx, a_np))
|
||||
|
||||
def test_attention(self):
|
||||
q = mx.random.uniform(shape=(2, 3, 4, 5))
|
||||
k = mx.random.uniform(shape=(2, 3, 4, 5))
|
||||
v = mx.random.uniform(shape=(2, 3, 4, 5))
|
||||
|
||||
s = mx.einsum("itjk,iujk->ijtu", q, k)
|
||||
out_mx = mx.einsum("ijtu,iujk->itjk", s, v)
|
||||
|
||||
s = np.einsum("itjk,iujk->ijtu", q, k)
|
||||
out_np = np.einsum("ijtu,iujk->itjk", s, v)
|
||||
|
||||
self.assertTrue(np.allclose(out_mx, out_np))
|
||||
|
||||
def test_multi_input_einsum(self):
|
||||
a = mx.ones((3, 4, 5))
|
||||
out_mx = mx.einsum("ijk,lmk,ijf->lf", a, a, a)
|
||||
out_np = np.einsum("ijk,lmk,ijf->lf", a, a, a)
|
||||
self.assertTrue(np.allclose(out_mx, out_np))
|
||||
|
||||
def test_opt_einsum_test_cases(self):
|
||||
# Test cases from
|
||||
# https://github.com/dgasmith/opt_einsum/blob/c826bb7df16f470a69f7bf90598fc27586209d11/opt_einsum/tests/test_contract.py#L11
|
||||
tests = [
|
||||
# Test hadamard-like products
|
||||
"a,ab,abc->abc",
|
||||
"a,b,ab->ab",
|
||||
# Test index-transformations
|
||||
"ea,fb,gc,hd,abcd->efgh",
|
||||
"ea,fb,abcd,gc,hd->efgh",
|
||||
"abcd,ea,fb,gc,hd->efgh",
|
||||
# Test complex contractions
|
||||
"acdf,jbje,gihb,hfac,gfac,gifabc,hfac",
|
||||
"cd,bdhe,aidb,hgca,gc,hgibcd,hgac",
|
||||
"abhe,hidj,jgba,hiab,gab",
|
||||
"bde,cdh,agdb,hica,ibd,hgicd,hiac",
|
||||
"chd,bde,agbc,hiad,hgc,hgi,hiad",
|
||||
"chd,bde,agbc,hiad,bdi,cgh,agdb",
|
||||
"bdhe,acad,hiab,agac,hibd",
|
||||
# Test collapse
|
||||
"ab,ab,c->",
|
||||
"ab,ab,c->c",
|
||||
"ab,ab,cd,cd->",
|
||||
"ab,ab,cd,cd->ac",
|
||||
"ab,ab,cd,cd->cd",
|
||||
"ab,ab,cd,cd,ef,ef->",
|
||||
# Test outer prodcuts
|
||||
"ab,cd,ef->abcdef",
|
||||
"ab,cd,ef->acdf",
|
||||
"ab,cd,de->abcde",
|
||||
"ab,cd,de->be",
|
||||
"ab,bcd,cd->abcd",
|
||||
"ab,bcd,cd->abd",
|
||||
# Random test cases that have previously failed
|
||||
"eb,cb,fb->cef",
|
||||
"dd,fb,be,cdb->cef",
|
||||
"bca,cdb,dbf,afc->",
|
||||
"dcc,fce,ea,dbf->ab",
|
||||
"fdf,cdd,ccd,afe->ae",
|
||||
"abcd,ad",
|
||||
"ed,fcd,ff,bcf->be",
|
||||
"baa,dcf,af,cde->be",
|
||||
"bd,db,eac->ace",
|
||||
"fff,fae,bef,def->abd",
|
||||
"efc,dbc,acf,fd->abe",
|
||||
# Inner products
|
||||
"ab,ab",
|
||||
"ab,ba",
|
||||
"abc,abc",
|
||||
"abc,bac",
|
||||
"abc,cba",
|
||||
# GEMM test cases
|
||||
"ab,bc",
|
||||
"ab,cb",
|
||||
"ba,bc",
|
||||
"ba,cb",
|
||||
"abcd,cd",
|
||||
"abcd,ab",
|
||||
"abcd,cdef",
|
||||
"abcd,cdef->feba",
|
||||
"abcd,efdc",
|
||||
# Inner then dot
|
||||
"aab,bc->ac",
|
||||
"ab,bcc->ac",
|
||||
"aab,bcc->ac",
|
||||
"baa,bcc->ac",
|
||||
"aab,ccb->ac",
|
||||
# Randomly build test caes
|
||||
"aab,fa,df,ecc->bde",
|
||||
"ecb,fef,bad,ed->ac",
|
||||
"bcf,bbb,fbf,fc->",
|
||||
"bb,ff,be->e",
|
||||
"bcb,bb,fc,fff->",
|
||||
"fbb,dfd,fc,fc->",
|
||||
"afd,ba,cc,dc->bf",
|
||||
"adb,bc,fa,cfc->d",
|
||||
"bbd,bda,fc,db->acf",
|
||||
"dba,ead,cad->bce",
|
||||
"aef,fbc,dca->bde",
|
||||
]
|
||||
|
||||
size_dict = dict(zip("abcdefghij", [2, 3, 4, 5, 2, 3, 4, 5, 2, 3]))
|
||||
|
||||
def inputs_for_case(test_case):
|
||||
inputs = test_case.split("->")[0].split(",")
|
||||
return [
|
||||
mx.random.uniform(shape=tuple(size_dict[c] for c in inp))
|
||||
for inp in inputs
|
||||
]
|
||||
|
||||
for test_case in tests:
|
||||
inputs = inputs_for_case(test_case)
|
||||
np_out = np.einsum(test_case, *inputs)
|
||||
mx_out = mx.einsum(test_case, *inputs)
|
||||
self.assertTrue(np.allclose(mx_out, np_out, rtol=1e-4, atol=1e-4))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Reference in New Issue
Block a user