awni's commit files

This commit is contained in:
Awni Hannun
2023-11-29 10:30:41 -08:00
parent e411fcae68
commit 8ca7f9e8e9
130 changed files with 30159 additions and 0 deletions

View File

@@ -0,0 +1,66 @@
cmake_minimum_required(VERSION 3.24)
project(mlx_sample_extensions LANGUAGES CXX)
# ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
# ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED)
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG REQUIRED)
# ----------------------------- Extensions -----------------------------
# Add library
add_library(mlx_ext)
# Add sources
target_sources(
mlx_ext
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)
# Add include headers
target_include_directories(
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)
# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)
# ----------------------------- Metal -----------------------------
# Build metallib
if(MLX_BUILD_METAL)
mlx_build_metallib(
TARGET mlx_ext_metallib
TITLE mlx_ext
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
)
add_dependencies(
mlx_ext
mlx_ext_metallib
)
endif()
# ----------------------------- Pybind -----------------------------
pybind11_add_module(
mlx_sample_extensions
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
endif()

View File

@@ -0,0 +1,84 @@
#pragma once
#include "mlx/ops.h"
#include "mlx/primitives.h"
namespace mlx::core {
///////////////////////////////////////////////////////////////////////////////
// Operation
///////////////////////////////////////////////////////////////////////////////
/**
* Scale and sum two vectors elementwise
* z = alpha * x + beta * y
*
* Follow numpy style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s = {} // Stream on which to schedule the operation
);
///////////////////////////////////////////////////////////////////////////////
// Primitive
///////////////////////////////////////////////////////////////////////////////
class Axpby : public Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta){};
/**
* A primitive must know how to evaluate itself on the CPU/GPU
* for the given inputs and populate the output array.
*
* To avoid unecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
/** The Jacobian-vector product. */
array jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums) override;
/**
* The primitive must know how to vectorize itself accross
* the given axes. The output is a pair containing the array
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<array, int> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
/** Print the primitive. */
void print(std::ostream& os) override {
os << "Axpby";
}
/** Equivalence check **/
bool is_equivalent(const Primitive& other) const override;
private:
float alpha_;
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, array& out);
};
} // namespace mlx::core

View File

@@ -0,0 +1,61 @@
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T>
[[kernel]] void axpby_general(
device const T* x [[buffer(0)]],
device const T* y [[buffer(1)]],
device T* out [[buffer(2)]],
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
constant const int* shape [[buffer(5)]],
constant const size_t* x_strides [[buffer(6)]],
constant const size_t* y_strides [[buffer(7)]],
constant const int& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) {
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
template <typename T>
[[kernel]] void axpby_contiguous(
device const T* x [[buffer(0)]],
device const T* y [[buffer(1)]],
device T* out [[buffer(2)]],
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
out[index] =
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
}
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] \
[[kernel]] void axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] \
[[kernel]] void axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bflot16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);

View File

@@ -0,0 +1,2 @@
import mlx.core as mx
from .mlx_sample_extensions import *

View File

@@ -0,0 +1,16 @@
from mlx import extension
from setuptools import setup
if __name__ == "__main__":
setup(
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"],
package_dir={"": "."},
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
zip_safe=False,
python_requires=">=3.7",
)