mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	more cpp tests
This commit is contained in:
		@@ -1,8 +1,6 @@
 | 
				
			|||||||
// Copyright © 2023 Apple Inc.
 | 
					// Copyright © 2023 Apple Inc.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include <iostream> // TODO
 | 
					 | 
				
			||||||
#include "doctest/doctest.h"
 | 
					#include "doctest/doctest.h"
 | 
				
			||||||
#include "mlx/utils.h" // TODO
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "mlx/mlx.h"
 | 
					#include "mlx/mlx.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -65,20 +63,36 @@ TEST_CASE("test compile inputs with primitive") {
 | 
				
			|||||||
  CHECK(array_equal(expected, out).item<bool>());
 | 
					  CHECK(array_equal(expected, out).item<bool>());
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/*std::vector<array> bigger_fun(const std::vector<array>& inputs) {
 | 
					std::vector<array> fun_creats_array(const std::vector<array>& inputs) {
 | 
				
			||||||
  auto x = inputs[1];
 | 
					  return {inputs[0] + array(1.0f)};
 | 
				
			||||||
  for (int i = 1; i < inputs.size(); ++i) {
 | 
					 | 
				
			||||||
    w = inputs[i]
 | 
					 | 
				
			||||||
    x = maximum(matmul(x, w), 0);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
  return take(x, array(3)) - logsumexp(x);
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST_CASE("test bigger graph") {
 | 
					TEST_CASE("test compile with created array") {
 | 
				
			||||||
  std::vector<array> inputs;
 | 
					  auto cfun = compile(fun_creats_array);
 | 
				
			||||||
  inputs.push_back(
 | 
					  auto out = cfun({array(2.0f)});
 | 
				
			||||||
  for (int
 | 
					  CHECK_EQ(out[0].item<float>(), 3.0f);
 | 
				
			||||||
  for
 | 
					 | 
				
			||||||
}*/
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST_CASE("test nested compile") {}
 | 
					  // Try again
 | 
				
			||||||
 | 
					  out = cfun({array(2.0f)});
 | 
				
			||||||
 | 
					  CHECK_EQ(out[0].item<float>(), 3.0f);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::vector<array> inner_fun(const std::vector<array>& inputs) {
 | 
				
			||||||
 | 
					  return {array(2) * inputs[0]};
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::vector<array> outer_fun(const std::vector<array>& inputs) {
 | 
				
			||||||
 | 
					  auto x = inputs[0] + inputs[1];
 | 
				
			||||||
 | 
					  auto y = compile(inner_fun)({x})[0];
 | 
				
			||||||
 | 
					  return {x + y};
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_CASE("test nested compile") {
 | 
				
			||||||
 | 
					  auto cfun = compile(outer_fun);
 | 
				
			||||||
 | 
					  auto out = cfun({array(1), array(2)})[0];
 | 
				
			||||||
 | 
					  CHECK_EQ(out.item<int>(), 9);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Try again
 | 
				
			||||||
 | 
					  out = cfun({array(1), array(2)})[0];
 | 
				
			||||||
 | 
					  CHECK_EQ(out.item<int>(), 9);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user