mlx/examples/python/logistic_regression.py
2023-11-30 11:12:53 -08:00

49 lines
848 B
Python

# Copyright © 2023 Apple Inc.
import mlx.core as mx
import time
# Problem size and optimisation hyper-parameters.
num_features = 100
num_examples = 1000
num_iters = 10000
lr = 0.1

# Synthetic data: draw ground-truth weights, draw the inputs, and label
# each example by whether its projection onto the ground-truth weights is
# positive (so the labels are linearly separable by construction).
w_star = mx.random.normal((num_features,))
X = mx.random.normal((num_examples, num_features))
y = mx.matmul(X, w_star) > 0

# Start the optimisation from small random weights.
w = mx.random.normal((num_features,)) * 1e-2
def loss_fn(w):
    """Return the mean logistic loss of weights ``w`` on the module-level data.

    Reads the global inputs ``X`` and boolean labels ``y``. With logits
    ``z = X @ w`` each example contributes ``log(1 + exp(z)) - y * z``, the
    binary cross-entropy written in terms of logits. ``logaddexp(0, z)`` is
    used as a numerically stable way to compute ``log(1 + exp(z))``.
    """
    z = mx.matmul(X, w)
    per_example = mx.logaddexp(0.0, z) - y * z
    return mx.mean(per_example)
# Gradient of the loss with respect to its first argument, the weights ``w``.
grad_fn = mx.grad(loss_fn)

# Time plain gradient descent. perf_counter() is a monotonic, high-resolution
# clock intended for measuring intervals; unlike time.time() it cannot be
# skewed by system wall-clock adjustments during the benchmark.
tic = time.perf_counter()
for _ in range(num_iters):
    grad = grad_fn(w)
    w = w - lr * grad
    # MLX computes lazily: force the updated weights to be evaluated every
    # step so the timed loop includes the real work and the next step starts
    # from a concrete array rather than an accumulated lazy graph.
    mx.eval(w)
toc = time.perf_counter()

# Report final training loss, training-set accuracy and loop throughput.
loss = loss_fn(w)
final_preds = (X @ w) > 0
acc = mx.mean(final_preds == y)
throughput = num_iters / (toc - tic)
print(
    f"Loss {loss.item():.5f}, Accuracy {acc.item():.5f} "
    f"Throughput {throughput:.5f} (it/s)"
)