MLX in C++ example (#1736)

* MLX in C++ example

* nits

* fix docs
This commit is contained in:
Awni Hannun
2025-01-02 19:09:04 -08:00
committed by GitHub
parent 8544b42007
commit c9d30aa6ac
13 changed files with 242 additions and 31 deletions

View File

@@ -2,20 +2,15 @@ cmake_minimum_required(VERSION 3.27)
project(import_mlx LANGUAGES CXX)
# ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
# ----------------------------- Dependencies -----------------------------
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m pip show mlx
COMMAND grep location
COMMAND awk "{print $4 \"/mlx\"}"
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)

View File

@@ -3,18 +3,18 @@
#include <mlx/mlx.h>
#include <iostream>
using namespace mlx::core;
namespace mx = mlx::core;
int main() {
int batch_size = 8;
int input_dim = 32;
// Make the input
random::seed(42);
auto example_x = random::uniform({batch_size, input_dim});
mx::random::seed(42);
auto example_x = mx::random::uniform({batch_size, input_dim});
// Import the function
auto forward = import_function("eval_mlp.mlxfn");
auto forward = mx::import_function("eval_mlp.mlxfn");
// Call the imported function
auto out = forward({example_x})[0];

View File

@@ -3,22 +3,22 @@
#include <mlx/mlx.h>
#include <iostream>
using namespace mlx::core;
namespace mx = mlx::core;
int main() {
int batch_size = 8;
int input_dim = 32;
int output_dim = 10;
auto state = import_function("init_mlp.mlxfn")({});
auto state = mx::import_function("init_mlp.mlxfn")({});
// Make the input
random::seed(42);
auto example_X = random::normal({batch_size, input_dim});
auto example_y = random::randint(0, output_dim, {batch_size});
mx::random::seed(42);
auto example_X = mx::random::normal({batch_size, input_dim});
auto example_y = mx::random::randint(0, output_dim, {batch_size});
// Import the function
auto step = import_function("train_mlp.mlxfn");
auto step = mx::import_function("train_mlp.mlxfn");
// Call the imported function
for (int it = 0; it < 100; ++it) {