2023-12-01 03:12:53 +08:00
|
|
|
// Copyright © 2023 Apple Inc.
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
#include <chrono>
|
|
|
|
#include <cmath>
|
|
|
|
#include <iostream>
|
|
|
|
|
|
|
|
#include "mlx/mlx.h"
|
|
|
|
#include "timer.h"
|
|
|
|
|
|
|
|
/**
 * An example of binary logistic regression with MLX, trained by full-batch
 * gradient descent on synthetic, linearly separable data.
 */
|
2024-12-11 23:08:29 +08:00
|
|
|
namespace mx = mlx::core;
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
int main() {
|
|
|
|
int num_features = 100;
|
|
|
|
int num_examples = 1'000;
|
|
|
|
int num_iters = 10'000;
|
|
|
|
float learning_rate = 0.1;
|
|
|
|
|
|
|
|
// True parameters
|
2024-12-11 23:08:29 +08:00
|
|
|
auto w_star = mx::random::normal({num_features});
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
// The input examples
|
2024-12-11 23:08:29 +08:00
|
|
|
auto X = mx::random::normal({num_examples, num_features});
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
// Labels
|
2024-12-11 23:08:29 +08:00
|
|
|
auto y = mx::matmul(X, w_star) > 0;
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
// Initialize random parameters
|
2024-12-11 23:08:29 +08:00
|
|
|
mx::array w = 1e-2 * mx::random::normal({num_features});
|
2023-11-30 02:30:41 +08:00
|
|
|
|
2024-12-11 23:08:29 +08:00
|
|
|
auto loss_fn = [&](mx::array w) {
|
|
|
|
auto logits = mx::matmul(X, w);
|
2023-11-30 02:30:41 +08:00
|
|
|
auto scale = (1.0f / num_examples);
|
2024-12-11 23:08:29 +08:00
|
|
|
return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
|
2023-11-30 02:30:41 +08:00
|
|
|
};
|
|
|
|
|
2024-12-11 23:08:29 +08:00
|
|
|
auto grad_fn = mx::grad(loss_fn);
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
auto tic = timer::time();
|
|
|
|
for (int it = 0; it < num_iters; ++it) {
|
2024-12-11 23:08:29 +08:00
|
|
|
auto grads = grad_fn(w);
|
|
|
|
w = w - learning_rate * grads;
|
|
|
|
mx::eval(w);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
auto toc = timer::time();
|
|
|
|
|
|
|
|
auto loss = loss_fn(w);
|
2024-12-11 23:08:29 +08:00
|
|
|
auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
|
2023-11-30 02:30:41 +08:00
|
|
|
auto throughput = num_iters / timer::seconds(toc - tic);
|
|
|
|
std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
|
|
|
|
<< throughput << " (it/s)." << std::endl;
|
|
|
|
}
|