# Mirror of https://github.com/ml-explore/mlx.git
# (synced 2025-06-24 01:17:26 +08:00; 50 lines, 849 B, Python)
# Copyright © 2023 Apple Inc.
import time

import mlx.core as mx
# Problem size.
num_features = 100
num_examples = 1_000

# Optimization settings: plain gradient descent with a fixed step size.
num_iters = 10_000
lr = 0.1
|
|
|
|
# True parameters: the ground-truth weight vector the fit should recover.
w_star = mx.random.normal((num_features,))

# Input examples: num_examples rows of num_features-dimensional Gaussian features.
X = mx.random.normal((num_examples, num_features))

# Labels: boolean class given by the sign of the true linear score X @ w_star.
y = (X @ w_star) > 0

# Initialize random parameters at a small scale (1e-2) so training starts
# near the origin rather than at the true solution.
w = 1e-2 * mx.random.normal((num_features,))
|
|
|
|
|
|
def loss_fn(w):
    """Mean logistic (binary cross-entropy with logits) loss of ``w`` on (X, y)."""
    scores = X @ w
    # logaddexp(0, s) = log(1 + e^s), the softplus; subtracting y * s gives
    # the per-example negative log-likelihood of the logistic model.
    return mx.mean(mx.logaddexp(0.0, scores) - y * scores)
|
|
|
|
|
|
# Function computing the gradient of the loss with respect to the weights.
grad_fn = mx.grad(loss_fn)

# Full-batch gradient descent, timed end to end.
tic = time.time()
for _ in range(num_iters):
    w = w - lr * grad_fn(w)
    # Force evaluation each step so the work is not deferred outside the loop.
    mx.eval(w)
|
|
|
|
toc = time.time()

# Final training-set metrics for the fitted weights.
loss = loss_fn(w)
preds = (X @ w) > 0
acc = mx.mean(preds == y)

throughput = num_iters / (toc - tic)
summary = (
    f"Loss {loss.item():.5f}, Accuracy {acc.item():.5f} "
    f"Throughput {throughput:.5f} (it/s)"
)
print(summary)
|