mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			40 lines
		
	
	
		
			849 B
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			40 lines
		
	
	
		
			849 B
		
	
	
	
		
			C++
		
	
	
	
	
	
// Copyright © 2023 Apple Inc.
 | 
						|
 | 
						|
#include <iostream>
 | 
						|
 | 
						|
#include "mlx/mlx.h"
 | 
						|
#include "time_utils.h"
 | 
						|
 | 
						|
using namespace mlx::core;
 | 
						|
 | 
						|
void time_value_and_grad() {
 | 
						|
  auto x = ones({200, 1000});
 | 
						|
  eval(x);
 | 
						|
  auto fn = [](array x) {
 | 
						|
    for (int i = 0; i < 20; ++i) {
 | 
						|
      x = log(exp(x));
 | 
						|
    }
 | 
						|
    return sum(x);
 | 
						|
  };
 | 
						|
 | 
						|
  auto grad_fn = grad(fn);
 | 
						|
  auto independent_value_and_grad = [&]() {
 | 
						|
    auto value = fn(x);
 | 
						|
    auto dfdx = grad_fn(x);
 | 
						|
    return std::vector<array>{value, dfdx};
 | 
						|
  };
 | 
						|
  TIME(independent_value_and_grad);
 | 
						|
 | 
						|
  auto value_and_grad_fn = value_and_grad(fn);
 | 
						|
  auto combined_value_and_grad = [&]() {
 | 
						|
    auto [value, dfdx] = value_and_grad_fn(x);
 | 
						|
    return std::vector<array>{value, dfdx};
 | 
						|
  };
 | 
						|
  TIME(combined_value_and_grad);
 | 
						|
}
 | 
						|
 | 
						|
int main() {
 | 
						|
  std::cout << "Benchmarks for " << default_device() << std::endl;
 | 
						|
  time_value_and_grad();
 | 
						|
}
 |