| 
									
										
										
										
											2023-11-30 11:12:53 -08:00
										 |  |  | // Copyright © 2023 Apple Inc.
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:30:41 -08:00
										 |  |  | #include <chrono>
 | 
					
						
							|  |  |  | #include <cmath>
 | 
					
						
							|  |  |  | #include <iostream>
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "mlx/mlx.h"
 | 
					
						
							|  |  |  | #include "timer.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /**
 | 
					
						
							|  |  |  |  * An example of logistic regression with MLX. | 
					
						
							|  |  |  |  */ | 
					
						
							| 
									
										
										
										
											2024-12-12 00:08:29 +09:00
										 |  |  | namespace mx = mlx::core; | 
					
						
							| 
									
										
										
										
											2023-11-29 10:30:41 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | int main() { | 
					
						
							|  |  |  |   int num_features = 100; | 
					
						
							|  |  |  |   int num_examples = 1'000; | 
					
						
							|  |  |  |   int num_iters = 10'000; | 
					
						
							|  |  |  |   float learning_rate = 0.1; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // True parameters
 | 
					
						
							| 
									
										
										
										
											2024-12-12 00:08:29 +09:00
										 |  |  |   auto w_star = mx::random::normal({num_features}); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:30:41 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   // The input examples
 | 
					
						
							| 
									
										
										
										
											2024-12-12 00:08:29 +09:00
										 |  |  |   auto X = mx::random::normal({num_examples, num_features}); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:30:41 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   // Labels
 | 
					
						
							| 
									
										
										
										
											2024-12-12 00:08:29 +09:00
										 |  |  |   auto y = mx::matmul(X, w_star) > 0; | 
					
						
							| 
									
										
										
										
											2023-11-29 10:30:41 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   // Initialize random parameters
 | 
					
						
							| 
									
										
										
										
											2024-12-12 00:08:29 +09:00
										 |  |  |   mx::array w = 1e-2 * mx::random::normal({num_features}); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:30:41 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-12 00:08:29 +09:00
										 |  |  |   auto loss_fn = [&](mx::array w) { | 
					
						
							|  |  |  |     auto logits = mx::matmul(X, w); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:30:41 -08:00
										 |  |  |     auto scale = (1.0f / num_examples); | 
					
						
							| 
									
										
										
										
											2024-12-12 00:08:29 +09:00
										 |  |  |     return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:30:41 -08:00
										 |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-12 00:08:29 +09:00
										 |  |  |   auto grad_fn = mx::grad(loss_fn); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:30:41 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   auto tic = timer::time(); | 
					
						
							|  |  |  |   for (int it = 0; it < num_iters; ++it) { | 
					
						
							| 
									
										
										
										
											2024-12-12 00:08:29 +09:00
										 |  |  |     auto grads = grad_fn(w); | 
					
						
							|  |  |  |     w = w - learning_rate * grads; | 
					
						
							|  |  |  |     mx::eval(w); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:30:41 -08:00
										 |  |  |   } | 
					
						
							|  |  |  |   auto toc = timer::time(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   auto loss = loss_fn(w); | 
					
						
							| 
									
										
										
										
											2024-12-12 00:08:29 +09:00
										 |  |  |   auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples; | 
					
						
							| 
									
										
										
										
											2023-11-29 10:30:41 -08:00
										 |  |  |   auto throughput = num_iters / timer::seconds(toc - tic); | 
					
						
							|  |  |  |   std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput " | 
					
						
							|  |  |  |             << throughput << " (it/s)." << std::endl; | 
					
						
							|  |  |  | } |