mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
jagrit's commit files
This commit is contained in:
10
examples/cpp/CMakeLists.txt
Normal file
10
examples/cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
function(build_example SRCFILE)
|
||||
get_filename_component(src_name ${SRCFILE} NAME_WE)
|
||||
set(target "${src_name}")
|
||||
add_executable(${target} ${SRCFILE})
|
||||
target_link_libraries(${target} PRIVATE mlx)
|
||||
endfunction(build_example)
|
||||
|
||||
build_example(tutorial.cpp)
|
||||
build_example(linear_regression.cpp)
|
||||
build_example(logistic_regression.cpp)
|
52
examples/cpp/linear_regression.cpp
Normal file
52
examples/cpp/linear_regression.cpp
Normal file
@@ -0,0 +1,52 @@
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
#include "timer.h"
|
||||
|
||||
/**
|
||||
* An example of linear regression with MLX.
|
||||
*/
|
||||
using namespace mlx::core;
|
||||
|
||||
int main() {
|
||||
int num_features = 100;
|
||||
int num_examples = 1'000;
|
||||
int num_iters = 10'000;
|
||||
float learning_rate = 0.01;
|
||||
|
||||
// True parameters
|
||||
auto w_star = random::normal({num_features});
|
||||
|
||||
// The input examples (design matrix)
|
||||
auto X = random::normal({num_examples, num_features});
|
||||
|
||||
// Noisy labels
|
||||
auto eps = 1e-2 * random::normal({num_examples});
|
||||
auto y = matmul(X, w_star) + eps;
|
||||
|
||||
// Initialize random parameters
|
||||
array w = 1e-2 * random::normal({num_features});
|
||||
|
||||
auto loss_fn = [&](array w) {
|
||||
auto yhat = matmul(X, w);
|
||||
return (0.5f / num_examples) * sum(square(yhat - y));
|
||||
};
|
||||
|
||||
auto grad_fn = grad(loss_fn);
|
||||
|
||||
auto tic = timer::time();
|
||||
for (int it = 0; it < num_iters; ++it) {
|
||||
auto grad = grad_fn(w);
|
||||
w = w - learning_rate * grad;
|
||||
eval(w);
|
||||
}
|
||||
auto toc = timer::time();
|
||||
|
||||
auto loss = loss_fn(w);
|
||||
auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
|
||||
auto throughput = num_iters / timer::seconds(toc - tic);
|
||||
std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
|
||||
<< ", Throughput " << throughput << " (it/s)." << std::endl;
|
||||
}
|
Reference in New Issue
Block a user