mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			50 lines
		
	
	
		
			849 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			50 lines
		
	
	
		
			849 B
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2023 Apple Inc.
 | |
| 
 | |
| import time
 | |
| 
 | |
| import mlx.core as mx
 | |
| 
 | |
| num_features = 100
 | |
| num_examples = 1_000
 | |
| num_iters = 10_000
 | |
| lr = 0.1
 | |
| 
 | |
| # True parameters
 | |
| w_star = mx.random.normal((num_features,))
 | |
| 
 | |
| # Input examples
 | |
| X = mx.random.normal((num_examples, num_features))
 | |
| 
 | |
| # Labels
 | |
| y = (X @ w_star) > 0
 | |
| 
 | |
| 
 | |
| # Initialize random parameters
 | |
| w = 1e-2 * mx.random.normal((num_features,))
 | |
| 
 | |
| 
 | |
| def loss_fn(w):
 | |
|     logits = X @ w
 | |
|     return mx.mean(mx.logaddexp(0.0, logits) - y * logits)
 | |
| 
 | |
| 
 | |
| grad_fn = mx.grad(loss_fn)
 | |
| 
 | |
| tic = time.time()
 | |
| for _ in range(num_iters):
 | |
|     grad = grad_fn(w)
 | |
|     w = w - lr * grad
 | |
|     mx.eval(w)
 | |
| 
 | |
| toc = time.time()
 | |
| 
 | |
| loss = loss_fn(w)
 | |
| final_preds = (X @ w) > 0
 | |
| acc = mx.mean(final_preds == y)
 | |
| 
 | |
| throughput = num_iters / (toc - tic)
 | |
| print(
 | |
|     f"Loss {loss.item():.5f}, Accuracy {acc.item():.5f} "
 | |
|     f"Throughput {throughput:.5f} (it/s)"
 | |
| )
 | 
