mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-28 22:28:11 +08:00
awni's commit files
This commit is contained in:
52
examples/cpp/logistic_regression.cpp
Normal file
52
examples/cpp/logistic_regression.cpp
Normal file
@@ -0,0 +1,52 @@
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
#include "timer.h"
|
||||
|
||||
/**
|
||||
* An example of logistic regression with MLX.
|
||||
*/
|
||||
using namespace mlx::core;
|
||||
|
||||
int main() {
|
||||
int num_features = 100;
|
||||
int num_examples = 1'000;
|
||||
int num_iters = 10'000;
|
||||
float learning_rate = 0.1;
|
||||
|
||||
// True parameters
|
||||
auto w_star = random::normal({num_features});
|
||||
|
||||
// The input examples
|
||||
auto X = random::normal({num_examples, num_features});
|
||||
|
||||
// Labels
|
||||
auto y = matmul(X, w_star) > 0;
|
||||
|
||||
// Initialize random parameters
|
||||
array w = 1e-2 * random::normal({num_features});
|
||||
|
||||
auto loss_fn = [&](array w) {
|
||||
auto logits = matmul(X, w);
|
||||
auto scale = (1.0f / num_examples);
|
||||
return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
|
||||
};
|
||||
|
||||
auto grad_fn = grad(loss_fn);
|
||||
|
||||
auto tic = timer::time();
|
||||
for (int it = 0; it < num_iters; ++it) {
|
||||
auto grad = grad_fn(w);
|
||||
w = w - learning_rate * grad;
|
||||
eval(w);
|
||||
}
|
||||
auto toc = timer::time();
|
||||
|
||||
auto loss = loss_fn(w);
|
||||
auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
|
||||
auto throughput = num_iters / timer::seconds(toc - tic);
|
||||
std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
|
||||
<< throughput << " (it/s)." << std::endl;
|
||||
}
|
||||
97
examples/cpp/tutorial.cpp
Normal file
97
examples/cpp/tutorial.cpp
Normal file
@@ -0,0 +1,97 @@
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
void array_basics() {
|
||||
// Make a scalar array:
|
||||
array x(1.0);
|
||||
|
||||
// Get the value out of it:
|
||||
auto s = x.item<float>();
|
||||
assert(s == 1.0);
|
||||
|
||||
// Scalars have a size of 1:
|
||||
size_t size = x.size();
|
||||
assert(size == 1);
|
||||
|
||||
// Scalars have 0 dimensions:
|
||||
int ndim = x.ndim();
|
||||
assert(ndim == 0);
|
||||
|
||||
// The shape should be an empty vector:
|
||||
auto shape = x.shape();
|
||||
assert(shape.empty());
|
||||
|
||||
// The datatype should be float32:
|
||||
auto dtype = x.dtype();
|
||||
assert(dtype == float32);
|
||||
|
||||
// Specify the dtype when constructing the array:
|
||||
x = array(1, int32);
|
||||
assert(x.dtype() == int32);
|
||||
x.item<int>(); // OK
|
||||
// x.item<float>(); // Undefined!
|
||||
|
||||
// Make a multidimensional array:
|
||||
x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
|
||||
// mlx is row-major by default so the first row of this array
|
||||
// is [1.0, 2.0] and the second row is [3.0, 4.0]
|
||||
|
||||
// Make an array of shape {2, 2} filled with ones:
|
||||
auto y = ones({2, 2});
|
||||
|
||||
// Pointwise add x and y:
|
||||
auto z = add(x, y);
|
||||
|
||||
// Same thing:
|
||||
z = x + y;
|
||||
|
||||
// mlx is lazy by default. At this point `z` only
|
||||
// has a shape and a type but no actual data:
|
||||
assert(z.dtype() == float32);
|
||||
assert(z.shape(0) == 2);
|
||||
assert(z.shape(1) == 2);
|
||||
|
||||
// To actually run the compuation you must evaluate `z`.
|
||||
// Under the hood, mlx records operations in a graph.
|
||||
// The variable `z` is a node in the graph which points to its operation
|
||||
// and inputs. When `eval` is called on an array (or arrays), the array and
|
||||
// all of its dependencies are recursively evaluated to produce the result.
|
||||
// Once an array is evaluated, it has data and is detached from its inputs.
|
||||
eval(z);
|
||||
|
||||
// Of course the array can still be an input to other operations. You can even
|
||||
// call eval on the array again, this will just be a no-op:
|
||||
eval(z); // no-op
|
||||
|
||||
// Some functions or methods on arrays implicitly evaluate them. For example
|
||||
// accessing a value in an array or printing the array implicitly evaluate it:
|
||||
z = ones({1});
|
||||
z.item<float>(); // implicit evaluation
|
||||
|
||||
z = ones({2, 2});
|
||||
std::cout << z << std::endl; // implicit evaluation
|
||||
}
|
||||
|
||||
void automatic_differentiation() {
|
||||
auto fn = [](array x) { return square(x); };
|
||||
|
||||
// Computing the derivative function of a function
|
||||
auto grad_fn = grad(fn);
|
||||
// Call grad_fn on the input to get the derivative
|
||||
auto x = array(1.5);
|
||||
auto dfdx = grad_fn(x);
|
||||
// dfdx is 2 * x
|
||||
|
||||
// Get the second derivative by composing grad with grad
|
||||
auto df2dx2 = grad(grad(fn))(x);
|
||||
// df2dx2 is 2
|
||||
}
|
||||
|
||||
int main() {
|
||||
array_basics();
|
||||
automatic_differentiation();
|
||||
}
|
||||
66
examples/extensions/CMakeLists.txt
Normal file
66
examples/extensions/CMakeLists.txt
Normal file
@@ -0,0 +1,66 @@
|
||||
cmake_minimum_required(VERSION 3.24)
|
||||
|
||||
project(mlx_sample_extensions LANGUAGES CXX)
|
||||
|
||||
# ----------------------------- Setup -----------------------------
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
|
||||
|
||||
# ----------------------------- Dependencies -----------------------------
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
find_package(Python COMPONENTS Interpreter Development)
|
||||
find_package(pybind11 CONFIG REQUIRED)
|
||||
|
||||
# ----------------------------- Extensions -----------------------------
|
||||
|
||||
# Add library
|
||||
add_library(mlx_ext)
|
||||
|
||||
# Add sources
|
||||
target_sources(
|
||||
mlx_ext
|
||||
PUBLIC
|
||||
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
|
||||
)
|
||||
|
||||
# Add include headers
|
||||
target_include_directories(
|
||||
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
|
||||
# Link to mlx
|
||||
target_link_libraries(mlx_ext PUBLIC mlx)
|
||||
|
||||
# ----------------------------- Metal -----------------------------
|
||||
|
||||
# Build metallib
|
||||
if(MLX_BUILD_METAL)
|
||||
|
||||
mlx_build_metallib(
|
||||
TARGET mlx_ext_metallib
|
||||
TITLE mlx_ext
|
||||
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
|
||||
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
|
||||
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
|
||||
)
|
||||
|
||||
add_dependencies(
|
||||
mlx_ext
|
||||
mlx_ext_metallib
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
# ----------------------------- Pybind -----------------------------
|
||||
pybind11_add_module(
|
||||
mlx_sample_extensions
|
||||
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
|
||||
)
|
||||
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
|
||||
|
||||
if(BUILD_SHARED_LIBS)
|
||||
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
|
||||
endif()
|
||||
84
examples/extensions/axpby/axpby.h
Normal file
84
examples/extensions/axpby/axpby.h
Normal file
@@ -0,0 +1,84 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Operation
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* Scale and sum two vectors elementwise
|
||||
* z = alpha * x + beta * y
|
||||
*
|
||||
* Follow numpy style broadcasting between x and y
|
||||
* Inputs are upcasted to floats if needed
|
||||
**/
|
||||
array axpby(
|
||||
const array& x, // Input array x
|
||||
const array& y, // Input array y
|
||||
const float alpha, // Scaling factor for x
|
||||
const float beta, // Scaling factor for y
|
||||
StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||
);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Primitive
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
class Axpby : public Primitive {
|
||||
public:
|
||||
explicit Axpby(Stream stream, float alpha, float beta)
|
||||
: Primitive(stream), alpha_(alpha), beta_(beta){};
|
||||
|
||||
/**
|
||||
* A primitive must know how to evaluate itself on the CPU/GPU
|
||||
* for the given inputs and populate the output array.
|
||||
*
|
||||
* To avoid unecessary allocations, the evaluation function
|
||||
* is responsible for allocating space for the array.
|
||||
*/
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
array jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) override;
|
||||
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums) override;
|
||||
|
||||
/**
|
||||
* The primitive must know how to vectorize itself accross
|
||||
* the given axes. The output is a pair containing the array
|
||||
* representing the vectorized computation and the axis which
|
||||
* corresponds to the output vectorized dimension.
|
||||
*/
|
||||
std::pair<array, int> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
/** Print the primitive. */
|
||||
void print(std::ostream& os) override {
|
||||
os << "Axpby";
|
||||
}
|
||||
|
||||
/** Equivalence check **/
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
float alpha_;
|
||||
float beta_;
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
||||
61
examples/extensions/axpby/axpby.metal
Normal file
61
examples/extensions/axpby/axpby.metal
Normal file
@@ -0,0 +1,61 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
template <typename T>
|
||||
[[kernel]] void axpby_general(
|
||||
device const T* x [[buffer(0)]],
|
||||
device const T* y [[buffer(1)]],
|
||||
device T* out [[buffer(2)]],
|
||||
constant const float& alpha [[buffer(3)]],
|
||||
constant const float& beta [[buffer(4)]],
|
||||
constant const int* shape [[buffer(5)]],
|
||||
constant const size_t* x_strides [[buffer(6)]],
|
||||
constant const size_t* y_strides [[buffer(7)]],
|
||||
constant const int& ndim [[buffer(8)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
|
||||
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
|
||||
out[index] =
|
||||
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
[[kernel]] void axpby_contiguous(
|
||||
device const T* x [[buffer(0)]],
|
||||
device const T* y [[buffer(1)]],
|
||||
device T* out [[buffer(2)]],
|
||||
constant const float& alpha [[buffer(3)]],
|
||||
constant const float& beta [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
out[index] =
|
||||
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
|
||||
}
|
||||
|
||||
#define instantiate_axpby(type_name, type) \
|
||||
template [[host_name("axpby_general_" #type_name)]] \
|
||||
[[kernel]] void axpby_general<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
constant const int* shape [[buffer(5)]], \
|
||||
constant const size_t* x_strides [[buffer(6)]], \
|
||||
constant const size_t* y_strides [[buffer(7)]], \
|
||||
constant const int& ndim [[buffer(8)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("axpby_contiguous_" #type_name)]] \
|
||||
[[kernel]] void axpby_contiguous<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
instantiate_axpby(float32, float);
|
||||
instantiate_axpby(float16, half);
|
||||
instantiate_axpby(bflot16, bfloat16_t);
|
||||
instantiate_axpby(complex64, complex64_t);
|
||||
2
examples/extensions/mlx_sample_extensions/__init__.py
Normal file
2
examples/extensions/mlx_sample_extensions/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
import mlx.core as mx
|
||||
from .mlx_sample_extensions import *
|
||||
16
examples/extensions/setup.py
Normal file
16
examples/extensions/setup.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from mlx import extension
|
||||
from setuptools import setup
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup(
|
||||
name="mlx_sample_extensions",
|
||||
version="0.0.0",
|
||||
description="Sample C++ and Metal extensions for MLX primitives.",
|
||||
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
|
||||
cmdclass={"build_ext": extension.CMakeBuild},
|
||||
packages=["mlx_sample_extensions"],
|
||||
package_dir={"": "."},
|
||||
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
|
||||
zip_safe=False,
|
||||
python_requires=">=3.7",
|
||||
)
|
||||
43
examples/python/linear_regression.py
Normal file
43
examples/python/linear_regression.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import mlx.core as mx
|
||||
import time
|
||||
|
||||
num_features = 100
|
||||
num_examples = 1_000
|
||||
num_iters = 10_000
|
||||
lr = 0.01
|
||||
|
||||
# True parameters
|
||||
w_star = mx.random.normal((num_features,))
|
||||
|
||||
# Input examples (design matrix)
|
||||
X = mx.random.normal((num_examples, num_features))
|
||||
|
||||
# Noisy labels
|
||||
eps = 1e-2 * mx.random.normal((num_examples,))
|
||||
y = X @ w_star + eps
|
||||
|
||||
# Initialize random parameters
|
||||
w = 1e-2 * mx.random.normal((num_features,))
|
||||
|
||||
|
||||
def loss_fn(w):
|
||||
return 0.5 * mx.mean(mx.square(X @ w - y))
|
||||
|
||||
|
||||
grad_fn = mx.grad(loss_fn)
|
||||
|
||||
tic = time.time()
|
||||
for _ in range(num_iters):
|
||||
grad = grad_fn(w)
|
||||
w = w - lr * grad
|
||||
mx.eval(w)
|
||||
toc = time.time()
|
||||
|
||||
loss = loss_fn(w)
|
||||
error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
|
||||
throughput = num_iters / (toc - tic)
|
||||
|
||||
print(
|
||||
f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
|
||||
f"Throughput {throughput:.5f} (it/s)"
|
||||
)
|
||||
46
examples/python/logistic_regression.py
Normal file
46
examples/python/logistic_regression.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import mlx.core as mx
|
||||
import time
|
||||
|
||||
num_features = 100
|
||||
num_examples = 1_000
|
||||
num_iters = 10_000
|
||||
lr = 0.1
|
||||
|
||||
# True parameters
|
||||
w_star = mx.random.normal((num_features,))
|
||||
|
||||
# Input examples
|
||||
X = mx.random.normal((num_examples, num_features))
|
||||
|
||||
# Labels
|
||||
y = (X @ w_star) > 0
|
||||
|
||||
|
||||
# Initialize random parameters
|
||||
w = 1e-2 * mx.random.normal((num_features,))
|
||||
|
||||
|
||||
def loss_fn(w):
|
||||
logits = X @ w
|
||||
return mx.mean(mx.logaddexp(0.0, logits) - y * logits)
|
||||
|
||||
|
||||
grad_fn = mx.grad(loss_fn)
|
||||
|
||||
tic = time.time()
|
||||
for _ in range(num_iters):
|
||||
grad = grad_fn(w)
|
||||
w = w - lr * grad
|
||||
mx.eval(w)
|
||||
|
||||
toc = time.time()
|
||||
|
||||
loss = loss_fn(w)
|
||||
final_preds = (X @ w) > 0
|
||||
acc = mx.mean(final_preds == y)
|
||||
|
||||
throughput = num_iters / (toc - tic)
|
||||
print(
|
||||
f"Loss {loss.item():.5f}, Accuracy {acc.item():.5f} "
|
||||
f"Throughput {throughput:.5f} (it/s)"
|
||||
)
|
||||
Reference in New Issue
Block a user