mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			26 lines
		
	
	
		
			471 B
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			26 lines
		
	
	
		
			471 B
		
	
	
	
		
			C++
		
	
	
	
	
	
// Copyright © 2024 Apple Inc.
 | 
						|
 | 
						|
#include <mlx/mlx.h>
 | 
						|
#include <iostream>
 | 
						|
 | 
						|
namespace mx = mlx::core;
 | 
						|
 | 
						|
int main() {
  // Dimensions of the example input: rows x features.
  const int num_rows = 8;
  const int num_features = 32;

  // Seed MLX's random module, then draw the example input with
  // mx::random::uniform using the shape {num_rows, num_features}.
  mx::random::seed(42);
  auto input = mx::random::uniform({num_rows, num_features});

  // Load the callable stored in "eval_mlp.mlxfn".
  auto mlp = mx::import_function("eval_mlp.mlxfn");

  // Invoke it with the single input and take the first returned element.
  auto results = mlp({input});
  auto first_output = results[0];

  // Print the result; std::endl also flushes the stream.
  std::cout << first_output << std::endl;

  return 0;
}
 |