mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 02:28:13 +08:00 
			
		
		
		
	* einsum initial * fix comma break * sum axis was wrong * small cleanups * python binding * changed bindings to resemble numpy * remove todo comment * comment changes * add count of operands/inputs * fail fast if operands list is empty * ignore comma if no output * einsum path matching numpy * getting somewhere with path * remove print * it passes the first test * moved einsum tests to seperate file * seperated einsum path * moved einsum naive * remove space from equation * fast fail if no operands passed * update tests and remove printf * small cleanup * some more cleanups * removed python helper file * ack * utilize std for finding min in vector * duplicate def * remove the tuple as it was unreadable * moved einsum_naive back to ops * remaining isn't needed * avoid creating another set * cleanup * greedy path, start of naive einsum * more einsum * fix some bugs * some more fixes, tests pass * benchmark * some simplify * fix einsum and test Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com> * add a bunch more tests and fix a bunch more bugs * some docs nits --------- Co-authored-by: dc-dc-dc <dgcruz983@gmail.com> Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
		
			
				
	
	
		
			77 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			77 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
// Copyright © 2024 Apple Inc.
 | 
						|
 | 
						|
#include "doctest/doctest.h"
 | 
						|
#include "mlx/mlx.h"
 | 
						|
 | 
						|
using namespace mlx::core;
 | 
						|
 | 
						|
// Exercises einsum_path on a series of subscript strings and compares the
// path it reports (the `.first` member of the returned pair) against a
// hard-coded list of index groups. Each scenario lives in its own scope so
// no state is shared between checks.
TEST_CASE("test einsum path") {
  using Path = std::vector<std::vector<int>>;

  // Three 2-D operands, no explicit output.
  {
    const Path want = {{1, 2}, {0, 1}};
    const auto got =
        einsum_path("ij,jk,kl", {ones({2, 2}), ones({2, 4}), ones({4, 2})})
            .first;
    CHECK_EQ(got, want);
  }

  // A single 3-D operand.
  {
    const Path want = {{0}};
    const auto got = einsum_path("jki", {ones({2, 3, 4})}).first;
    CHECK_EQ(got, want);
  }

  // Two 1-D operands sharing index `i`, with extents 2 and 1.
  {
    const Path want = {{0, 1}};
    const auto got = einsum_path("i,i", {ones({2}), ones({1})}).first;
    CHECK_EQ(got, want);
  }

  // Two 2x2 operands sharing index `j`.
  {
    const Path want = {{0, 1}};
    const auto got =
        einsum_path("ij,jk", {ones({2, 2}), ones({2, 2})}).first;
    CHECK_EQ(got, want);
  }

  // Two 3-D operands with an explicit `kl` output.
  {
    const Path want = {{0, 1}};
    const auto got =
        einsum_path("ijk,jil->kl", {ones({3, 4, 5}), ones({4, 3, 2})}).first;
    CHECK_EQ(got, want);
  }

  // Five operands with an empty output; the last operand (`abc`) shares no
  // index letters with the other four.
  {
    const Path want = {{0, 3}, {1, 3}, {0, 2}, {0, 1}};
    const auto got = einsum_path(
                         "ijk,ilm,njm,nlk,abc->",
                         {ones({2, 6, 8}),
                          ones({2, 4, 5}),
                          ones({3, 6, 5}),
                          ones({3, 4, 8}),
                          ones({9, 4, 7})})
                         .first;
    CHECK_EQ(got, want);
  }

  // Four 10x10 matrices around one 4-D operand, output `efgh`.
  {
    const Path want = {{0, 2}, {0, 3}, {0, 2}, {0, 1}};
    const auto got = einsum_path(
                         "ea,fb,abcd,gc,hd->efgh",
                         {ones({10, 10}),
                          ones({10, 10}),
                          ones({10, 10, 10, 10}),
                          ones({10, 10}),
                          ones({10, 10})})
                         .first;
    CHECK_EQ(got, want);
  }
}
 | 
						|
 | 
						|
// Covers einsum itself: first a list of calls that are each required to
// throw, then three small evaluations whose results are compared against
// hand-written arrays with allclose.
TEST_CASE("test einsum") {
  // Every call below is expected to raise an exception.
  CHECK_THROWS(einsum("i,j", {array({1.0})}));
  CHECK_THROWS(einsum("ijk", {full({2, 2}, 2.0f)}));
  CHECK_THROWS(einsum("", {}));
  CHECK_THROWS(einsum("ij", {array({1, 2})}));
  CHECK_THROWS(einsum("", {array({1, 2})}));
  CHECK_THROWS(einsum("i,ij", {array({1, 2}), array({2, 3})}));
  CHECK_THROWS(einsum("i,i", {array({1, 2}), array({2, 3, 4})}));
  CHECK_THROWS(einsum("i->ii", {array({1, 2})}));
  CHECK_THROWS(einsum("12", {zeros({4, 4})}));
  CHECK_THROWS(einsum("ii->i", {zeros({3, 2})}));

  // Single operand "jki" of shape {2, 3, 4} filled with 3: the result is
  // compared against a {4, 2, 3} array filled with 3.
  {
    const auto got = einsum("jki", {full({2, 3, 4}, 3.0f)});
    const auto want = full({4, 2, 3}, 3.0f);
    CHECK_EQ(allclose(got, want).item<bool>(), true);
  }

  // 2x2 of twos times 2x2 of threes: every entry is 2 * 3 * 2 = 12.
  {
    const auto got =
        einsum("ij,jk->ik", {full({2, 2}, 2.0f), full({2, 2}, 3.0f)});
    const auto want = array({12.0f, 12.0f, 12.0f, 12.0f}, {2, 2});
    CHECK_EQ(allclose(got, want).item<bool>(), true);
  }

  // "i,j->ij" on constant vectors: every entry is 15 * 20 = 300.
  {
    const auto got =
        einsum("i,j->ij", {full({2}, 15.0f), full({4}, 20.0f)});
    const auto want = full({2, 4}, 300.0f);
    CHECK_EQ(allclose(got, want).item<bool>(), true);
  }
}
 |