mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Einsum ellipsis (#1788)
This commit is contained in:

committed by
GitHub

parent
e6a7ab9675
commit
72146fc4cd
@@ -313,6 +313,51 @@ class TestEinsum(mlx_tests.MLXTestCase):
|
||||
mx_out = mx.einsum(test_case, *inputs)
|
||||
self.assertTrue(np.allclose(mx_out, np_out, rtol=1e-4, atol=1e-4))
|
||||
|
||||
def test_ellipses(self):
|
||||
size_dict = dict(zip("abcdefghij", [2, 3, 4, 5, 2, 3, 4, 5, 2, 3]))
|
||||
|
||||
def inputs_for_case(test_case):
|
||||
inputs = test_case.split("->")[0].split(",")
|
||||
return [
|
||||
mx.random.uniform(shape=tuple(size_dict[c] for c in inp))
|
||||
for inp in inputs
|
||||
]
|
||||
|
||||
tests = [
|
||||
("abc->ab", "...c->..."),
|
||||
("abcd->ad", "a...d->..."),
|
||||
("abij,abgj->abig", "...ij,...gj->...ig"),
|
||||
("abij,abgj->abig", "...ij,...gj->..."),
|
||||
("abhh->abh", "...hh->...h"),
|
||||
("abhh->abh", "...hh->...h"),
|
||||
("bch,abcj->abchj", "...h,...j->...hj"),
|
||||
("bc,cd->bd", "...c,cd"),
|
||||
("abc,acd->bd", "...bc,...cd"),
|
||||
("abcd,c->abd", "...cd,c"),
|
||||
("abcd,c->abd", "...cd,c..."),
|
||||
("abcd,c->abd", "...cd,c...->d..."),
|
||||
("abc,b->abc", "ab...,b...->ab..."),
|
||||
("abc,b->abc", "ab...,...b->ab..."),
|
||||
("abc,b->abc", "ab...,b->ab..."),
|
||||
("ab,bc->ac", "ab...,b...->a..."),
|
||||
("ab,bc->ac", "ab...,...bc->a...c"),
|
||||
("ab,bc->ac", "ab,b...->a..."),
|
||||
("abcdef,defg->abcg", "...def,defg->...g"),
|
||||
]
|
||||
for test_case in tests:
|
||||
inputs = inputs_for_case(test_case[0])
|
||||
np_out = np.einsum(test_case[1], *inputs)
|
||||
mx_out = mx.einsum(test_case[1], *inputs)
|
||||
self.assertTrue(np.allclose(mx_out, np_out, rtol=1e-4, atol=1e-4))
|
||||
|
||||
error_tests = [
|
||||
("abc,abc->ab", "a...b...c,a...b...c->abc"),
|
||||
]
|
||||
for test_case in error_tests:
|
||||
inputs = inputs_for_case(test_case[0])
|
||||
with self.assertRaises(ValueError):
|
||||
mx.einsum(test_case[1], *inputs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user