Einsum ellipsis (#1788)

This commit is contained in:
Angelos Katharopoulos
2025-01-25 01:28:03 -08:00
committed by GitHub
parent e6a7ab9675
commit 72146fc4cd
2 changed files with 124 additions and 10 deletions

View File

@@ -313,6 +313,51 @@ class TestEinsum(mlx_tests.MLXTestCase):
mx_out = mx.einsum(test_case, *inputs)
self.assertTrue(np.allclose(mx_out, np_out, rtol=1e-4, atol=1e-4))
def test_ellipses(self):
size_dict = dict(zip("abcdefghij", [2, 3, 4, 5, 2, 3, 4, 5, 2, 3]))
def inputs_for_case(test_case):
inputs = test_case.split("->")[0].split(",")
return [
mx.random.uniform(shape=tuple(size_dict[c] for c in inp))
for inp in inputs
]
tests = [
("abc->ab", "...c->..."),
("abcd->ad", "a...d->..."),
("abij,abgj->abig", "...ij,...gj->...ig"),
("abij,abgj->abig", "...ij,...gj->..."),
("abhh->abh", "...hh->...h"),
("abhh->abh", "...hh->...h"),
("bch,abcj->abchj", "...h,...j->...hj"),
("bc,cd->bd", "...c,cd"),
("abc,acd->bd", "...bc,...cd"),
("abcd,c->abd", "...cd,c"),
("abcd,c->abd", "...cd,c..."),
("abcd,c->abd", "...cd,c...->d..."),
("abc,b->abc", "ab...,b...->ab..."),
("abc,b->abc", "ab...,...b->ab..."),
("abc,b->abc", "ab...,b->ab..."),
("ab,bc->ac", "ab...,b...->a..."),
("ab,bc->ac", "ab...,...bc->a...c"),
("ab,bc->ac", "ab,b...->a..."),
("abcdef,defg->abcg", "...def,defg->...g"),
]
for test_case in tests:
inputs = inputs_for_case(test_case[0])
np_out = np.einsum(test_case[1], *inputs)
mx_out = mx.einsum(test_case[1], *inputs)
self.assertTrue(np.allclose(mx_out, np_out, rtol=1e-4, atol=1e-4))
error_tests = [
("abc,abc->ab", "a...b...c,a...b...c->abc"),
]
for test_case in error_tests:
inputs = inputs_for_case(test_case[0])
with self.assertRaises(ValueError):
mx.einsum(test_case[1], *inputs)
if __name__ == "__main__":
unittest.main()