// Mirror of https://github.com/ml-explore/mlx.git
// (synced 2025-06-25 01:41:17 +08:00; original file: 26 lines, 460 B, C++)
// Copyright © 2024 Apple Inc.

#include <exception>
#include <iostream>
#include <string>

#include <mlx/mlx.h>

using namespace mlx::core;

int main() {
|
||
|
int batch_size = 8;
|
||
|
int input_dim = 32;
|
||
|
|
||
|
// Make the input
|
||
|
random::seed(42);
|
||
|
auto example_x = random::uniform({batch_size, input_dim});
|
||
|
|
||
|
// Import the function
|
||
|
auto forward = import_function("eval_mlp.mlxfn");
|
||
|
|
||
|
// Call the imported function
|
||
|
auto out = forward({example_x})[0];
|
||
|
|
||
|
std::cout << out << std::endl;
|
||
|
|
||
|
return 0;
|
||
|
}
|