// Copyright © 2023 Apple Inc.

#include <iostream>
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
void time_value_and_grad() {
auto x = mx::ones({200, 1000});
mx::eval(x);
2025-04-23 21:49:04 +08:00
auto fn = [](mx::x) {
2023-11-30 02:52:08 +08:00
for (int i = 0; i < 20; ++i) {
x = mx::log(mx::exp(x));
2023-11-30 02:52:08 +08:00
}
return mx::sum(x);
2023-11-30 02:52:08 +08:00
};
auto grad_fn = mx::grad(fn);
2023-11-30 02:52:08 +08:00
auto independent_value_and_grad = [&]() {
auto value = fn(x);
auto dfdx = grad_fn(x);
return std::vector<mx::array>{value, dfdx};
2023-11-30 02:52:08 +08:00
};
TIME(independent_value_and_grad);
auto value_and_grad_fn = mx::value_and_grad(fn);
2023-11-30 02:52:08 +08:00
auto combined_value_and_grad = [&]() {
auto [value, dfdx] = value_and_grad_fn(x);
return std::vector<mx::array>{value, dfdx};
2023-11-30 02:52:08 +08:00
};
TIME(combined_value_and_grad);
}
int main() {
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
2023-11-30 02:52:08 +08:00
time_value_and_grad();
}