mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-09 02:24:34 +08:00
Compare commits
12 Commits
packed-qua
...
socket-dis
Author | SHA1 | Date | |
---|---|---|---|
![]() |
d82699f0f1 | ||
![]() |
6fc00d2c10 | ||
![]() |
44f0de2854 | ||
![]() |
29ec3539ed | ||
![]() |
e94f0028c3 | ||
![]() |
e5354fcddb | ||
![]() |
34dd079a64 | ||
![]() |
16975815e9 | ||
![]() |
a8b3da7946 | ||
![]() |
060e1c9f92 | ||
![]() |
0b04742985 | ||
![]() |
c3ccd4919f |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -76,9 +76,6 @@ build/
|
||||
*.out
|
||||
*.app
|
||||
|
||||
# Debug symbols
|
||||
*.pdb
|
||||
|
||||
# VSCode
|
||||
.vscode/
|
||||
.DS_Store
|
||||
|
@@ -1,14 +1,13 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v19.1.4
|
||||
rev: v18.1.8
|
||||
hooks:
|
||||
- id: clang-format
|
||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 24.10.0
|
||||
rev: 24.8.0
|
||||
hooks:
|
||||
- id: black
|
||||
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.13.2
|
||||
hooks:
|
||||
|
@@ -20,12 +20,11 @@ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
||||
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
|
||||
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
|
||||
if(NOT MLX_VERSION)
|
||||
set(MLX_VERSION 0.21.1)
|
||||
set(MLX_VERSION 0.19.3)
|
||||
endif()
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
@@ -35,6 +34,8 @@ message(
|
||||
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
|
||||
)
|
||||
|
||||
set(MLX_BUILD_ARM OFF)
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
||||
if(NOT MLX_ENABLE_X64_MAC)
|
||||
@@ -56,6 +57,10 @@ else()
|
||||
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
|
||||
endif()
|
||||
|
||||
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
|
||||
set(MLX_BUILD_ARM ON)
|
||||
endif()
|
||||
|
||||
# ----------------------------- Lib -----------------------------
|
||||
|
||||
include(FetchContent)
|
||||
@@ -84,26 +89,25 @@ elseif(MLX_BUILD_METAL)
|
||||
# Throw an error if xcrun not found
|
||||
execute_process(
|
||||
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||
if(${MACOS_VERSION} LESS 14.0)
|
||||
message(
|
||||
FATAL_ERROR
|
||||
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
|
||||
endif()
|
||||
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
|
||||
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
||||
|
||||
set(METAL_CPP_URL
|
||||
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
|
||||
|
||||
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
|
||||
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
|
||||
)
|
||||
# Get the metal version
|
||||
execute_process(
|
||||
COMMAND
|
||||
zsh "-c"
|
||||
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
||||
|
||||
FetchContent_MakeAvailable(metal_cpp)
|
||||
@@ -111,54 +115,20 @@ elseif(MLX_BUILD_METAL)
|
||||
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
|
||||
$<INSTALL_INTERFACE:include/metal_cpp>)
|
||||
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
|
||||
endif()
|
||||
|
||||
if(WIN32)
|
||||
if(MSVC)
|
||||
# GGUF does not build with MSVC.
|
||||
set(MLX_BUILD_GGUF OFF)
|
||||
# There is no prebuilt OpenBLAS distribution for MSVC.
|
||||
set(MLX_BUILD_BLAS_FROM_SOURCE ON)
|
||||
endif()
|
||||
# Windows implementation of dlfcn.h APIs.
|
||||
FetchContent_Declare(
|
||||
dlfcn-win32
|
||||
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
|
||||
GIT_TAG v1.4.1
|
||||
EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(dlfcn-win32)
|
||||
target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
|
||||
target_link_libraries(mlx PRIVATE dl)
|
||||
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CPU)
|
||||
find_library(ACCELERATE_LIBRARY Accelerate)
|
||||
if(ACCELERATE_LIBRARY)
|
||||
if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
|
||||
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
||||
set(MLX_BUILD_ACCELERATE ON)
|
||||
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
|
||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
||||
else()
|
||||
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_ACCELERATE)
|
||||
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
|
||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
||||
elseif(MLX_BUILD_BLAS_FROM_SOURCE)
|
||||
# Download and build OpenBLAS from source code.
|
||||
FetchContent_Declare(
|
||||
openblas
|
||||
GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
|
||||
GIT_TAG v0.3.28
|
||||
EXCLUDE_FROM_ALL)
|
||||
set(BUILD_STATIC_LIBS ON) # link statically
|
||||
set(NOFORTRAN ON) # msvc has no fortran compiler
|
||||
FetchContent_MakeAvailable(openblas)
|
||||
target_link_libraries(mlx PRIVATE openblas)
|
||||
target_include_directories(
|
||||
mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
|
||||
"${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
|
||||
else()
|
||||
if(${CMAKE_HOST_APPLE})
|
||||
# The blas shipped in macOS SDK is not supported, search homebrew for
|
||||
# openblas instead.
|
||||
@@ -176,7 +146,7 @@ if(MLX_BUILD_CPU)
|
||||
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
||||
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
|
||||
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
|
||||
# List blas after lapack otherwise we may accidentally incldue an old
|
||||
# version of lapack.h from the include dirs of blas.
|
||||
find_package(BLAS REQUIRED)
|
||||
@@ -189,7 +159,7 @@ if(MLX_BUILD_CPU)
|
||||
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
|
||||
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
|
||||
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
|
||||
endif()
|
||||
else()
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
|
@@ -5,35 +5,35 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
using namespace mlx::core;
|
||||
|
||||
void time_value_and_grad() {
|
||||
auto x = mx::ones({200, 1000});
|
||||
mx::eval(x);
|
||||
auto fn = [](mx::array x) {
|
||||
auto x = ones({200, 1000});
|
||||
eval(x);
|
||||
auto fn = [](array x) {
|
||||
for (int i = 0; i < 20; ++i) {
|
||||
x = mx::log(mx::exp(x));
|
||||
x = log(exp(x));
|
||||
}
|
||||
return mx::sum(x);
|
||||
return sum(x);
|
||||
};
|
||||
|
||||
auto grad_fn = mx::grad(fn);
|
||||
auto grad_fn = grad(fn);
|
||||
auto independent_value_and_grad = [&]() {
|
||||
auto value = fn(x);
|
||||
auto dfdx = grad_fn(x);
|
||||
return std::vector<mx::array>{value, dfdx};
|
||||
return std::vector<array>{value, dfdx};
|
||||
};
|
||||
TIME(independent_value_and_grad);
|
||||
|
||||
auto value_and_grad_fn = mx::value_and_grad(fn);
|
||||
auto value_and_grad_fn = value_and_grad(fn);
|
||||
auto combined_value_and_grad = [&]() {
|
||||
auto [value, dfdx] = value_and_grad_fn(x);
|
||||
return std::vector<mx::array>{value, dfdx};
|
||||
return std::vector<array>{value, dfdx};
|
||||
};
|
||||
TIME(combined_value_and_grad);
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||
time_value_and_grad();
|
||||
}
|
||||
|
@@ -4,21 +4,21 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
using namespace mlx::core;
|
||||
|
||||
void time_add_op() {
|
||||
std::vector<int> sizes(1, 1);
|
||||
for (int i = 0; i < 9; ++i) {
|
||||
sizes.push_back(10 * sizes.back());
|
||||
}
|
||||
set_default_device(mx::Device::cpu);
|
||||
set_default_device(Device::cpu);
|
||||
for (auto size : sizes) {
|
||||
auto a = mx::random::uniform({size});
|
||||
auto b = mx::random::uniform({size});
|
||||
mx::eval(a, b);
|
||||
auto a = random::uniform({size});
|
||||
auto b = random::uniform({size});
|
||||
eval(a, b);
|
||||
std::cout << "Size " << size << std::endl;
|
||||
TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
|
||||
TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
|
||||
TIMEM("cpu", add, a, b, Device::cpu);
|
||||
TIMEM("gpu", add, a, b, Device::gpu);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -6,105 +6,105 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
using namespace mlx::core;
|
||||
|
||||
void time_irregular_binary_ops_1D() {
|
||||
auto device = mx::default_device();
|
||||
auto device = default_device();
|
||||
int size = 1000000;
|
||||
int step = 2;
|
||||
auto a = mx::random::uniform({size});
|
||||
auto b = mx::random::uniform({size});
|
||||
mx::eval(a, b);
|
||||
auto a = random::uniform({size});
|
||||
auto b = random::uniform({size});
|
||||
eval(a, b);
|
||||
a = slice(a, {0}, {size}, {step});
|
||||
b = slice(b, {0}, {size}, {step});
|
||||
TIMEM("1D strided", mx::add, a, b, device);
|
||||
TIMEM("1D strided", add, a, b, device);
|
||||
}
|
||||
|
||||
void time_irregular_binary_ops_2D() {
|
||||
auto device = mx::default_device();
|
||||
auto device = default_device();
|
||||
int size = 2048;
|
||||
auto a = mx::random::uniform({size, size});
|
||||
auto b = mx::random::uniform({size, size});
|
||||
mx::eval(a, b);
|
||||
TIMEM("2D regular", mx::add, a, b, device);
|
||||
auto a = random::uniform({size, size});
|
||||
auto b = random::uniform({size, size});
|
||||
eval(a, b);
|
||||
TIMEM("2D regular", add, a, b, device);
|
||||
|
||||
b = mx::transpose(b);
|
||||
mx::eval(b);
|
||||
TIMEM("2D mx::transpose", mx::add, a, b, device);
|
||||
b = transpose(b);
|
||||
eval(b);
|
||||
TIMEM("2D transpose", add, a, b, device);
|
||||
|
||||
b = mx::random::uniform({size});
|
||||
mx::eval(b);
|
||||
TIMEM("2D broadcast dim 0", mx::add, a, b, device);
|
||||
b = random::uniform({size});
|
||||
eval(b);
|
||||
TIMEM("2D broadcast dim 0", add, a, b, device);
|
||||
|
||||
b = mx::reshape(b, {size, 1});
|
||||
mx::eval(b);
|
||||
TIMEM("2D broadcast dim 1", mx::add, a, b, device);
|
||||
b = reshape(b, {size, 1});
|
||||
eval(b);
|
||||
TIMEM("2D broadcast dim 1", add, a, b, device);
|
||||
}
|
||||
|
||||
void time_irregular_binary_ops_3D() {
|
||||
auto device = mx::default_device();
|
||||
auto device = default_device();
|
||||
int d0 = 32;
|
||||
int d1 = 512;
|
||||
int d2 = 512;
|
||||
auto a = mx::random::uniform({d0, d1, d2});
|
||||
auto b = mx::random::uniform({d0, d1, d2});
|
||||
TIMEM("3D regular", mx::add, a, b, device);
|
||||
auto a = random::uniform({d0, d1, d2});
|
||||
auto b = random::uniform({d0, d1, d2});
|
||||
TIMEM("3D regular", add, a, b, device);
|
||||
|
||||
b = mx::transpose(b, {0, 2, 1});
|
||||
TIMEM("3D mx::transpose", mx::add, a, b, device);
|
||||
b = transpose(b, {0, 2, 1});
|
||||
TIMEM("3D transpose", add, a, b, device);
|
||||
|
||||
b = mx::random::uniform({d1, d2});
|
||||
TIMEM("3D broadcast dim 0", mx::add, a, b, device);
|
||||
b = random::uniform({d1, d2});
|
||||
TIMEM("3D broadcast dim 0", add, a, b, device);
|
||||
|
||||
b = mx::random::uniform({d0, 1, d2});
|
||||
TIMEM("3D broadcast dim 1", mx::add, a, b, device);
|
||||
b = random::uniform({d0, 1, d2});
|
||||
TIMEM("3D broadcast dim 1", add, a, b, device);
|
||||
|
||||
b = mx::random::uniform({d0, d1, 1});
|
||||
TIMEM("3D broadcast dim 2", mx::add, a, b, device);
|
||||
b = random::uniform({d0, d1, 1});
|
||||
TIMEM("3D broadcast dim 2", add, a, b, device);
|
||||
|
||||
b = mx::random::uniform({d2});
|
||||
TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
|
||||
b = random::uniform({d2});
|
||||
TIMEM("3D broadcast dims 0, 1", add, a, b, device);
|
||||
|
||||
b = mx::random::uniform({d1, 1});
|
||||
TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
|
||||
b = random::uniform({d1, 1});
|
||||
TIMEM("3D broadcast dims 0, 2", add, a, b, device);
|
||||
|
||||
b = mx::random::uniform({d0, 1, 1});
|
||||
TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
|
||||
b = random::uniform({d0, 1, 1});
|
||||
TIMEM("3D broadcast dims 1, 2", add, a, b, device);
|
||||
}
|
||||
|
||||
void time_irregular_binary_ops_4D() {
|
||||
auto device = mx::default_device();
|
||||
auto device = default_device();
|
||||
std::vector<int> shape = {8, 8, 512, 512};
|
||||
auto a = mx::random::uniform(shape);
|
||||
auto b = mx::random::uniform(shape);
|
||||
auto a = random::uniform(shape);
|
||||
auto b = random::uniform(shape);
|
||||
|
||||
TIMEM("4D regular", mx::add, a, b, device);
|
||||
TIMEM("4D regular", add, a, b, device);
|
||||
|
||||
b = mx::transpose(b, {0, 1, 3, 2});
|
||||
TIMEM("4D mx::transpose", mx::add, a, b, device);
|
||||
b = transpose(b, {0, 1, 3, 2});
|
||||
TIMEM("4D transpose", add, a, b, device);
|
||||
|
||||
std::string om = "4D broadcast dims ";
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
shape[i] = 1;
|
||||
b = mx::random::uniform(shape);
|
||||
b = random::uniform(shape);
|
||||
std::ostringstream msg;
|
||||
msg << om << i;
|
||||
TIMEM(msg.str(), mx::add, a, b, device);
|
||||
TIMEM(msg.str(), add, a, b, device);
|
||||
|
||||
for (int j = i + 1; j < shape.size(); ++j) {
|
||||
shape[j] = 1;
|
||||
std::ostringstream msg;
|
||||
msg << om << i << ", " << j;
|
||||
b = mx::random::uniform(shape);
|
||||
TIMEM(msg.str(), mx::add, a, b, device);
|
||||
b = random::uniform(shape);
|
||||
TIMEM(msg.str(), add, a, b, device);
|
||||
shape[j] = a.shape(j);
|
||||
|
||||
for (int k = j + 1; k < shape.size(); ++k) {
|
||||
shape[k] = 1;
|
||||
std::ostringstream msg;
|
||||
msg << om << i << ", " << j << ", " << k;
|
||||
b = mx::random::uniform(shape);
|
||||
TIMEM(msg.str(), mx::add, a, b, device);
|
||||
b = random::uniform(shape);
|
||||
TIMEM(msg.str(), add, a, b, device);
|
||||
shape[k] = a.shape(k);
|
||||
}
|
||||
}
|
||||
@@ -113,83 +113,83 @@ void time_irregular_binary_ops_4D() {
|
||||
}
|
||||
|
||||
void time_irregular_reshape() {
|
||||
auto device = mx::default_device();
|
||||
auto device = default_device();
|
||||
std::vector<int> shape;
|
||||
auto reshape_fn = [&shape, device](const mx::array& a) {
|
||||
return mx::reshape(a, shape, device);
|
||||
auto reshape_fn = [&shape, device](const array& a) {
|
||||
return reshape(a, shape, device);
|
||||
};
|
||||
|
||||
int size = 64;
|
||||
int d = 2 * size;
|
||||
|
||||
auto a = mx::random::uniform({d, d, d});
|
||||
auto a = random::uniform({d, d, d});
|
||||
|
||||
shape = {8 * size, size, size};
|
||||
TIMEM("3D contiguous", reshape_fn, a);
|
||||
|
||||
a = mx::transpose(a);
|
||||
a = transpose(a);
|
||||
shape = {8 * size, size, size};
|
||||
TIMEM("3D mx::transpose", reshape_fn, a);
|
||||
TIMEM("3D transpose", reshape_fn, a);
|
||||
|
||||
a = mx::transpose(a, {1, 2, 0});
|
||||
a = transpose(a, {1, 2, 0});
|
||||
shape = {8 * size, size, size};
|
||||
TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
|
||||
TIMEM("3D transpose dims 1 2", reshape_fn, a);
|
||||
|
||||
a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
|
||||
a = broadcast_to(random::uniform({d, d}), {d, d, d});
|
||||
TIMEM("3D broadcast dim 0", reshape_fn, a);
|
||||
|
||||
a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
|
||||
a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
|
||||
TIMEM("3D broadcast dim 1", reshape_fn, a);
|
||||
|
||||
a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
|
||||
a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dim 2", reshape_fn, a);
|
||||
|
||||
a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
|
||||
a = broadcast_to(random::uniform({d}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
|
||||
|
||||
a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
|
||||
a = broadcast_to(random::uniform({d, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
|
||||
|
||||
a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
|
||||
a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
|
||||
|
||||
a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
|
||||
a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
|
||||
}
|
||||
|
||||
void time_irregular_astype_1D() {
|
||||
auto device = mx::default_device();
|
||||
auto device = default_device();
|
||||
int size = 1000000;
|
||||
int step = 2;
|
||||
auto a = mx::random::uniform({size});
|
||||
auto a = random::uniform({size});
|
||||
a = slice(a, {0}, {size}, {step});
|
||||
TIMEM("1D strided", mx::astype, a, mx::int32, device);
|
||||
TIMEM("1D strided", astype, a, int32, device);
|
||||
}
|
||||
|
||||
void time_irregular_astype_2D() {
|
||||
auto device = mx::default_device();
|
||||
auto device = default_device();
|
||||
int size = 2048;
|
||||
std::vector<int> shape = {size, size};
|
||||
|
||||
auto a = mx::random::uniform(shape);
|
||||
TIMEM("2D regular", mx::astype, a, mx::int32, device);
|
||||
auto a = random::uniform(shape);
|
||||
TIMEM("2D regular", astype, a, int32, device);
|
||||
|
||||
a = mx::transpose(a);
|
||||
TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
|
||||
a = transpose(a);
|
||||
TIMEM("2D transpose", astype, a, int32, device);
|
||||
|
||||
a = mx::broadcast_to(mx::random::uniform({size}), shape);
|
||||
TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
|
||||
a = broadcast_to(random::uniform({size}), shape);
|
||||
TIMEM("2D broadcast dim 0", astype, a, int32, device);
|
||||
|
||||
a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
|
||||
TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
|
||||
a = broadcast_to(random::uniform({size, 1}), shape);
|
||||
TIMEM("2D broadcast dim 1", astype, a, int32, device);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc > 1) {
|
||||
bool use_gpu = !strcmp(argv[1], "gpu");
|
||||
set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
|
||||
set_default_device(use_gpu ? Device::gpu : Device::cpu);
|
||||
}
|
||||
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||
time_irregular_binary_ops_1D();
|
||||
time_irregular_binary_ops_2D();
|
||||
time_irregular_binary_ops_3D();
|
||||
|
@@ -3,20 +3,20 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
using namespace mlx::core;
|
||||
|
||||
void time_creation_ops() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto shape = {M, N};
|
||||
auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
|
||||
auto full_fp32 = [&]() { return full(shape, 3.3f); };
|
||||
TIME(full_fp32);
|
||||
auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
|
||||
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
|
||||
TIME(zeros_fp32);
|
||||
auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
|
||||
auto ones_fp32 = [&]() { return ones(shape, float32); };
|
||||
TIME(ones_fp32);
|
||||
|
||||
auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
|
||||
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
|
||||
TIME(arange_fp32);
|
||||
}
|
||||
|
||||
@@ -24,196 +24,194 @@ void time_type_conversions() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto shape = {M, N};
|
||||
auto device = mx::default_device();
|
||||
auto device = default_device();
|
||||
|
||||
auto a = mx::zeros(shape, mx::float32);
|
||||
mx::eval(a);
|
||||
TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
|
||||
TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
|
||||
auto a = zeros(shape, float32);
|
||||
eval(a);
|
||||
TIMEM("float32 to int32", astype, a, int32, device);
|
||||
TIMEM("float32 to uint32", astype, a, uint32, device);
|
||||
|
||||
a = mx::zeros(shape, mx::int32);
|
||||
mx::eval(a);
|
||||
TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
|
||||
a = zeros(shape, int32);
|
||||
eval(a);
|
||||
TIMEM("int32 to float32", astype, a, float32, device);
|
||||
|
||||
a = mx::zeros(shape, mx::bool_);
|
||||
mx::eval(a);
|
||||
TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
|
||||
TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
|
||||
TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
|
||||
a = zeros(shape, bool_);
|
||||
eval(a);
|
||||
TIMEM("bool to float32", astype, a, float32, device);
|
||||
TIMEM("bool to int32", astype, a, int32, device);
|
||||
TIMEM("bool to uint32", astype, a, uint32, device);
|
||||
}
|
||||
|
||||
void time_random_generation() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
|
||||
auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
|
||||
auto uniform = [&]() { return random::uniform({M, N}, float32); };
|
||||
TIME(uniform);
|
||||
auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
|
||||
auto normal = [&]() { return random::normal({M, N}, float32); };
|
||||
TIME(normal);
|
||||
}
|
||||
|
||||
void time_unary_ops() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto device = mx::default_device();
|
||||
auto device = default_device();
|
||||
|
||||
auto a = mx::random::normal({M, N});
|
||||
mx::eval(a);
|
||||
auto a = random::normal({M, N});
|
||||
eval(a);
|
||||
TIME(mlx::core::abs, a, device);
|
||||
TIME(mx::negative, a, device);
|
||||
TIME(mx::sign, a, device);
|
||||
TIME(mx::square, a, device);
|
||||
TIME(negative, a, device);
|
||||
TIME(sign, a, device);
|
||||
TIME(square, a, device);
|
||||
TIME(mlx::core::sqrt, a, device);
|
||||
TIME(mx::rsqrt, a, device);
|
||||
TIME(rsqrt, a, device);
|
||||
TIME(mlx::core::exp, a, device);
|
||||
|
||||
a = mx::random::uniform({M, N});
|
||||
a = random::uniform({M, N});
|
||||
TIME(mlx::core::log, a, device);
|
||||
}
|
||||
|
||||
void time_binary_ops() {
|
||||
int M = 1000, N = 100, K = 10;
|
||||
auto condition = mx::random::randint(0, 2, {M, N, K});
|
||||
auto a = mx::random::uniform({M, N, K});
|
||||
auto b = mx::random::uniform({M, N, K});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
auto condition = random::randint(0, 2, {M, N, K});
|
||||
auto a = random::uniform({M, N, K});
|
||||
auto b = random::uniform({M, N, K});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
|
||||
TIME(mx::add, a, b, device);
|
||||
TIME(mx::subtract, a, b, device);
|
||||
TIME(mx::multiply, a, b, device);
|
||||
TIME(mx::divide, a, b, device);
|
||||
TIME(mx::maximum, a, b, device);
|
||||
TIME(mx::minimum, a, b, device);
|
||||
TIME(mx::where, condition, a, b, device);
|
||||
TIME(add, a, b, device);
|
||||
TIME(subtract, a, b, device);
|
||||
TIME(multiply, a, b, device);
|
||||
TIME(divide, a, b, device);
|
||||
TIME(maximum, a, b, device);
|
||||
TIME(minimum, a, b, device);
|
||||
TIME(where, condition, a, b, device);
|
||||
|
||||
condition = mx::array({true});
|
||||
b = mx::random::uniform({1});
|
||||
mx::eval(b);
|
||||
TIMEM("scalar", mx::add, a, b, device);
|
||||
TIMEM("vector-scalar", mx::subtract, a, b, device);
|
||||
TIMEM("scalar-vector", mx::subtract, b, a, device);
|
||||
TIMEM("scalar", mx::multiply, a, b, device);
|
||||
TIMEM("vector-scalar", mx::divide, a, b, device);
|
||||
TIMEM("scalar-vector", mx::divide, b, a, device);
|
||||
TIMEM("scalar-vector", mx::where, condition, a, b, device);
|
||||
condition = array({true});
|
||||
b = random::uniform({1});
|
||||
eval(b);
|
||||
TIMEM("scalar", add, a, b, device);
|
||||
TIMEM("vector-scalar", subtract, a, b, device);
|
||||
TIMEM("scalar-vector", subtract, b, a, device);
|
||||
TIMEM("scalar", multiply, a, b, device);
|
||||
TIMEM("vector-scalar", divide, a, b, device);
|
||||
TIMEM("scalar-vector", divide, b, a, device);
|
||||
TIMEM("scalar-vector", where, condition, a, b, device);
|
||||
|
||||
condition = mx::broadcast_to(mx::array({true}), {1000, 100});
|
||||
a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
|
||||
b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
|
||||
mx::eval(a, b);
|
||||
TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
|
||||
condition = broadcast_to(array({true}), {1000, 100});
|
||||
a = broadcast_to(random::uniform({1}), {1000, 100});
|
||||
b = broadcast_to(random::uniform({1}), {1000, 100});
|
||||
eval(a, b);
|
||||
TIMEM("scalar-scalar broadcast", add, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", divide, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
|
||||
}
|
||||
|
||||
void time_strided_ops() {
|
||||
int M = 50, N = 50, O = 50, P = 50;
|
||||
auto a = mx::random::uniform({M, N, O, P});
|
||||
auto b = mx::random::uniform({M, N, O, P});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
TIMEM("non-strided", mx::add, a, b, device);
|
||||
a = mx::transpose(a, {1, 0, 2, 3});
|
||||
b = mx::transpose(b, {3, 2, 0, 1});
|
||||
mx::eval(a, b);
|
||||
TIMEM("strided", mx::add, a, b, device);
|
||||
auto a = random::uniform({M, N, O, P});
|
||||
auto b = random::uniform({M, N, O, P});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
TIMEM("non-strided", add, a, b, device);
|
||||
a = transpose(a, {1, 0, 2, 3});
|
||||
b = transpose(b, {3, 2, 0, 1});
|
||||
eval(a, b);
|
||||
TIMEM("strided", add, a, b, device);
|
||||
}
|
||||
|
||||
void time_comparisons() {
|
||||
int M = 1000, N = 100, K = 10;
|
||||
auto a = mx::random::uniform({M, N, K});
|
||||
auto b = mx::random::uniform({M, N, K});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
TIME(mx::equal, a, b, device);
|
||||
TIME(mx::greater, a, b, device);
|
||||
TIME(mx::greater_equal, a, b, device);
|
||||
TIME(mx::less, a, b, device);
|
||||
TIME(mx::less_equal, a, b, device);
|
||||
auto a = random::uniform({M, N, K});
|
||||
auto b = random::uniform({M, N, K});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
TIME(equal, a, b, device);
|
||||
TIME(greater, a, b, device);
|
||||
TIME(greater_equal, a, b, device);
|
||||
TIME(less, a, b, device);
|
||||
TIME(less_equal, a, b, device);
|
||||
}
|
||||
|
||||
void time_matvec() {
|
||||
int M = 2000, N = 200;
|
||||
auto a = mx::random::uniform({M, N});
|
||||
auto b = mx::random::uniform({N});
|
||||
auto c = mx::random::uniform({M});
|
||||
mx::eval(a, b, c);
|
||||
auto matvec = [&]() { return mx::matmul(a, b); };
|
||||
auto a = random::uniform({M, N});
|
||||
auto b = random::uniform({N});
|
||||
auto c = random::uniform({M});
|
||||
eval(a, b, c);
|
||||
auto matvec = [&]() { return matmul(a, b); };
|
||||
TIME(matvec);
|
||||
|
||||
auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
|
||||
auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
|
||||
TIME(matvec_transpose);
|
||||
}
|
||||
|
||||
void time_matmul() {
|
||||
int M = 1000, N = 1000, K = 1000;
|
||||
auto a = mx::random::uniform({M, K});
|
||||
auto b = mx::random::uniform({K, N});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
TIME(mx::matmul, a, b, device);
|
||||
auto a = random::uniform({M, K});
|
||||
auto b = random::uniform({K, N});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
TIME(matmul, a, b, device);
|
||||
|
||||
auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
|
||||
auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
|
||||
TIME(transpose_matmul);
|
||||
}
|
||||
|
||||
void time_reductions() {
|
||||
auto a = mx::random::normal({10000, 1000});
|
||||
mx::eval(a);
|
||||
auto sum_all = [&a]() { return mx::sum(a, false); };
|
||||
auto a = random::normal({10000, 1000});
|
||||
eval(a);
|
||||
auto sum_all = [&a]() { return sum(a, false); };
|
||||
TIME(sum_all);
|
||||
|
||||
auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
|
||||
auto sum_along_0 = [&a]() { return sum(a, 0, false); };
|
||||
TIME(sum_along_0);
|
||||
|
||||
auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
|
||||
auto sum_along_1 = [&a]() { return sum(a, 1, false); };
|
||||
TIME(sum_along_1);
|
||||
|
||||
auto prod_all = [&a]() { return mx::prod(a, false); };
|
||||
auto prod_all = [&a]() { return prod(a, false); };
|
||||
TIME(prod_all);
|
||||
|
||||
auto all_true = [&a]() { return mx::all(a, false); };
|
||||
auto all_true = [&a]() { return all(a, false); };
|
||||
TIME(all_true);
|
||||
|
||||
auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
|
||||
auto all_along_0 = [&a]() { return all(a, 0, false); };
|
||||
TIME(all_along_0);
|
||||
|
||||
auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
|
||||
auto all_along_1 = [&a]() { return all(a, 1, false); };
|
||||
TIME(all_along_1);
|
||||
|
||||
auto any_true = [&a]() { return mx::any(a, false); };
|
||||
auto any_true = [&a]() { return any(a, false); };
|
||||
TIME(any_true);
|
||||
|
||||
auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
|
||||
auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
|
||||
TIME(argmin_along_0);
|
||||
|
||||
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
||||
auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
|
||||
TIME(argmin_along_1);
|
||||
}
|
||||
|
||||
void time_gather_scatter() {
|
||||
auto a = mx::random::normal({1000, 768});
|
||||
mx::eval(a);
|
||||
auto indices = mx::random::randint(0, 1000, {256});
|
||||
mx::eval(indices);
|
||||
auto a = random::normal({1000, 768});
|
||||
eval(a);
|
||||
auto indices = random::randint(0, 1000, {256});
|
||||
eval(indices);
|
||||
|
||||
auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
|
||||
auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
|
||||
TIME(embedding_lookup);
|
||||
|
||||
indices = mx::random::randint(0, 768 * 1000, {256 * 768});
|
||||
mx::eval(indices);
|
||||
indices = random::randint(0, 768 * 1000, {256 * 768});
|
||||
eval(indices);
|
||||
|
||||
auto single_element_lookup = [&a, &indices]() {
|
||||
return mx::take(a, indices);
|
||||
};
|
||||
auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
|
||||
TIME(single_element_lookup);
|
||||
|
||||
indices = mx::random::randint(0, 1000, {256});
|
||||
auto updates = mx::random::normal({256, 1, 768});
|
||||
mx::eval(indices, updates);
|
||||
indices = random::randint(0, 1000, {256});
|
||||
auto updates = random::normal({256, 1, 768});
|
||||
eval(indices, updates);
|
||||
|
||||
auto embedding_update = [&a, &indices, &updates]() {
|
||||
return scatter(a, indices, updates, 0);
|
||||
@@ -225,10 +223,10 @@ void time_gather_scatter() {
|
||||
};
|
||||
TIME(embedding_add);
|
||||
|
||||
a = mx::reshape(a, {-1});
|
||||
indices = mx::random::randint(0, 768 * 1000, {768 * 256});
|
||||
updates = mx::random::normal({256 * 768, 1});
|
||||
mx::eval(a, indices, updates);
|
||||
a = reshape(a, {-1});
|
||||
indices = random::randint(0, 768 * 1000, {768 * 256});
|
||||
updates = random::normal({256 * 768, 1});
|
||||
eval(a, indices, updates);
|
||||
|
||||
auto single_element_update = [&a, &indices, &updates]() {
|
||||
return scatter(a, indices, updates, 0);
|
||||
@@ -242,21 +240,21 @@ void time_gather_scatter() {
|
||||
}
|
||||
|
||||
void time_divmod() {
|
||||
auto a = mx::random::normal({1000});
|
||||
auto b = mx::random::normal({1000});
|
||||
mx::eval({a, b});
|
||||
auto a = random::normal({1000});
|
||||
auto b = random::normal({1000});
|
||||
eval({a, b});
|
||||
|
||||
auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
|
||||
auto divmod_fused = [&a, &b]() { return divmod(a, b); };
|
||||
TIME(divmod_fused);
|
||||
|
||||
auto divmod_separate = [&a, &b]() {
|
||||
return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
|
||||
return std::vector<array>{floor_divide(a, b), remainder(a, b)};
|
||||
};
|
||||
TIME(divmod_separate);
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||
time_creation_ops();
|
||||
time_type_conversions();
|
||||
time_unary_ops();
|
||||
|
@@ -1,74 +0,0 @@
|
||||
import argparse
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
B = 1024
|
||||
D = 1024
|
||||
M = 4 * D
|
||||
group_size = 64
|
||||
bits = 4
|
||||
dtype = mx.float16
|
||||
loops = 10
|
||||
|
||||
|
||||
def qmm_(x, wq1, wq2, q_type):
|
||||
for i in range(loops):
|
||||
x = mx.quantized_matmul(
|
||||
x,
|
||||
*wq1,
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
quantization_type=q_type,
|
||||
)
|
||||
x = mx.quantized_matmul(
|
||||
x,
|
||||
*wq2,
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
quantization_type=q_type,
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
def affine_qmm(x, wq1, wq2):
|
||||
return qmm_(x, wq1, wq2, "affine")
|
||||
|
||||
|
||||
def affine_packed_qmm(x, wq1, wq2):
|
||||
return qmm_(x, wq1, wq2, "affine-packed")
|
||||
|
||||
|
||||
def time_qmm():
|
||||
mx.random.seed(3)
|
||||
x = mx.random.normal(shape=(B, D)).astype(dtype)
|
||||
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
|
||||
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
|
||||
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
|
||||
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
|
||||
mx.eval(x, wq1, wq2)
|
||||
time_fn(affine_qmm, x, wq1, wq2)
|
||||
|
||||
|
||||
def time_packed_qmm():
|
||||
mx.random.seed(3)
|
||||
x = mx.random.normal(shape=(B, D)).astype(dtype)
|
||||
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
|
||||
wq1 = mx.quantize(
|
||||
w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
|
||||
)
|
||||
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
|
||||
wq2 = mx.quantize(
|
||||
w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
|
||||
)
|
||||
mx.eval(x, wq1, wq2)
|
||||
time_fn(affine_packed_qmm, x, wq1, wq2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for b in [2, 4, 8]:
|
||||
bits = b
|
||||
print(f"Bits {bits}:")
|
||||
time_qmm()
|
||||
time_packed_qmm()
|
@@ -1,189 +1,62 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from time_utils import time_fn
|
||||
|
||||
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||
device_name = device_name.decode("utf-8").strip("\n")
|
||||
|
||||
N_warmup = 5
|
||||
N_iter_bench = 40
|
||||
N_iter_func = 8
|
||||
MAX_SEQ = 300
|
||||
START_SEQ = 100
|
||||
SEQ_INCREMENT = 50
|
||||
|
||||
|
||||
def bench(f, *args):
|
||||
for i in range(N_warmup):
|
||||
f(*args)
|
||||
def time_self_attention_primitives():
|
||||
mx.random.seed(3)
|
||||
B = 2
|
||||
H = 38
|
||||
D = 64
|
||||
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
|
||||
q = mx.random.uniform(shape=(B, H, R, D))
|
||||
k = mx.random.uniform(shape=(B, H, R, D))
|
||||
v = mx.random.uniform(shape=(B, H, R, D))
|
||||
scale = 1.0 / math.sqrt(float(D))
|
||||
mx.eval(q, k, v)
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(*args)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
def sdpa_primitives(qs, ks, vs, alpha):
|
||||
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
|
||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||
o = p @ vs
|
||||
return o
|
||||
|
||||
time_fn(sdpa_primitives, q, k, v, scale)
|
||||
|
||||
|
||||
def mlx_sdpa_fused_inner(q, k, v, scale):
|
||||
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
|
||||
def time_self_attention_sdpa():
|
||||
mx.random.seed(3)
|
||||
B = 2
|
||||
H = 38
|
||||
D = 64
|
||||
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
|
||||
q = mx.random.uniform(shape=(B, H, R, D))
|
||||
k = mx.random.uniform(shape=(B, H, R, D))
|
||||
v = mx.random.uniform(shape=(B, H, R, D))
|
||||
scale = 1.0 / math.sqrt(float(D))
|
||||
mx.eval(q, k, v)
|
||||
|
||||
def sdpa_fused(qs, ks, vs, alpha):
|
||||
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
|
||||
return o
|
||||
|
||||
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
|
||||
q_dtype = q.dtype
|
||||
q = q * mx.array(scale, q_dtype)
|
||||
n_q_heads = q.shape[-3]
|
||||
n_kv_heads = k.shape[-3]
|
||||
n_repeats = n_q_heads // n_kv_heads
|
||||
|
||||
B = q.shape[0]
|
||||
L = q.shape[2]
|
||||
|
||||
if n_repeats > 1:
|
||||
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
|
||||
k = mx.expand_dims(k, 2)
|
||||
v = mx.expand_dims(v, 2)
|
||||
|
||||
scores = q @ mx.swapaxes(k, -1, -2)
|
||||
if f32softmax:
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
|
||||
else:
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
|
||||
out = scores @ v
|
||||
if n_repeats > 1:
|
||||
out = mx.reshape(out, [B, n_q_heads, L, -1])
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def mlx_spda_unfused(q, k, v, scale, transpose):
|
||||
q_out = q
|
||||
if transpose:
|
||||
k = mx.transpose(k, (0, 2, 1, 3))
|
||||
v = mx.transpose(v, (0, 2, 1, 3))
|
||||
|
||||
for i in range(N_iter_func):
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
|
||||
mx.eval(q_out)
|
||||
return q_out
|
||||
|
||||
|
||||
def mlx_spda_fused(q, k, v, scale, transpose):
|
||||
q_out = q
|
||||
if transpose:
|
||||
k = mx.transpose(k, (0, 2, 1, 3))
|
||||
v = mx.transpose(v, (0, 2, 1, 3))
|
||||
|
||||
for i in range(N_iter_func):
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
|
||||
mx.eval(q_out)
|
||||
return q_out
|
||||
|
||||
|
||||
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
|
||||
shape_q = (
|
||||
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
|
||||
)
|
||||
shape_kv = (
|
||||
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
|
||||
)
|
||||
|
||||
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
|
||||
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
|
||||
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
|
||||
|
||||
scale = math.sqrt(1.0 / head_dim)
|
||||
|
||||
q_mx = mx.array(q_np)
|
||||
k_mx = mx.array(k_np)
|
||||
v_mx = mx.array(v_np)
|
||||
|
||||
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
|
||||
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
|
||||
|
||||
if transpose:
|
||||
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
|
||||
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
|
||||
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
|
||||
|
||||
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
|
||||
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
|
||||
|
||||
atol = 1e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
|
||||
print(
|
||||
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
|
||||
)
|
||||
|
||||
return time_mlx_fused, time_mlx_unfused
|
||||
|
||||
|
||||
def get_gflop_count(B, M, N, K):
|
||||
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
|
||||
time_fn(sdpa_fused, q, k, v, scale)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
||||
parser = argparse.ArgumentParser("MLX benchmarks.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
args = parser.parse_args()
|
||||
if args.gpu:
|
||||
mx.set_default_device(mx.gpu)
|
||||
else:
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
dtypes = ("float16", "float32")[:1]
|
||||
transposes = (False,)
|
||||
|
||||
# fmt: off
|
||||
shapes_64 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 32, 32, 64, 32, 32),
|
||||
( 1, 64, 64, 64, 32, 32),
|
||||
( 1, 128, 128, 64, 32, 32),
|
||||
( 1, 256, 256, 64, 32, 32),
|
||||
( 1, 512, 512, 64, 32, 32),
|
||||
( 1, 1024, 1024, 64, 32, 32),
|
||||
( 1, 2048, 2048, 64, 32, 32),
|
||||
( 1, 4096, 4096, 64, 32, 32),
|
||||
)
|
||||
|
||||
shapes_80 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 1024, 1024, 80, 32, 32),
|
||||
( 1, 2048, 2048, 80, 32, 32),
|
||||
( 1, 4096, 4096, 80, 32, 32),
|
||||
)
|
||||
|
||||
shapes_128 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 1024, 1024, 128, 32, 32),
|
||||
( 1, 2048, 2048, 128, 32, 32),
|
||||
( 1, 4096, 4096, 128, 32, 32),
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
shapes = shapes_64 + shapes_80 + shapes_128
|
||||
|
||||
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
|
||||
|
||||
for dtype in dtypes:
|
||||
for transpose in transposes:
|
||||
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx_fused, time_mlx_unfused = bench_shape(
|
||||
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
|
||||
)
|
||||
diff = time_mlx_unfused / time_mlx_fused - 1.0
|
||||
t_str = 1 if transpose else 0
|
||||
print(
|
||||
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
time_self_attention_sdpa()
|
||||
time_self_attention_primitives()
|
||||
|
@@ -4,51 +4,42 @@ import math
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
L = 16384
|
||||
L = 1024
|
||||
H = 32
|
||||
H_k = H // 4
|
||||
H_k = 32 // 4
|
||||
D = 128
|
||||
dtype = mx.float16
|
||||
loops = 10
|
||||
|
||||
|
||||
def attention(q, k, v):
|
||||
def _sdpa(q, k, v):
|
||||
B, Hq, L, D = q.shape
|
||||
_, Hk, S, _ = k.shape
|
||||
q = q.reshape(B, Hk, Hq // Hk, L, D)
|
||||
k = k[:, :, None, :, :]
|
||||
v = v[:, :, None, :, :]
|
||||
s = q @ k.transpose(0, 1, 2, 4, 3)
|
||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||
o = p @ v
|
||||
return o.reshape(B, Hq, L, D)
|
||||
|
||||
for i in range(loops):
|
||||
q = _sdpa(q, k, v)
|
||||
return q
|
||||
B, Hq, L, D = q.shape
|
||||
_, Hk, S, _ = k.shape
|
||||
q = q.reshape(B, Hk, Hq // Hk, L, D)
|
||||
k = k[:, :, None, :, :]
|
||||
v = v[:, :, None, :, :]
|
||||
s = q @ k.transpose(0, 1, 2, 4, 3)
|
||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||
o = p @ v
|
||||
return o.reshape(B, Hq, L, D)
|
||||
|
||||
|
||||
def sdpa(q, k, v):
|
||||
for i in range(loops):
|
||||
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
|
||||
return q
|
||||
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
|
||||
|
||||
|
||||
def time_self_attention_primitives():
|
||||
mx.random.seed(3)
|
||||
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
q = mx.random.uniform(shape=(1, H, 1, D))
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
mx.eval(q, k, v)
|
||||
time_fn(attention, q, k, v)
|
||||
|
||||
|
||||
def time_self_attention_sdpa():
|
||||
mx.random.seed(3)
|
||||
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
q = mx.random.uniform(shape=(1, H, 1, D))
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
mx.eval(q, k, v)
|
||||
time_fn(sdpa, q, k, v)
|
||||
|
||||
|
@@ -420,8 +420,8 @@ element in the output.
|
||||
constant const float& alpha [[buffer(3)]],
|
||||
constant const float& beta [[buffer(4)]],
|
||||
constant const int* shape [[buffer(5)]],
|
||||
constant const int64_t* x_strides [[buffer(6)]],
|
||||
constant const int64_t* y_strides [[buffer(7)]],
|
||||
constant const size_t* x_strides [[buffer(6)]],
|
||||
constant const size_t* y_strides [[buffer(7)]],
|
||||
constant const int& ndim [[buffer(8)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
// Convert linear indices to offsets in array
|
||||
@@ -438,10 +438,24 @@ each instantiation a unique host name so we can identify it.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
instantiate_kernel("axpby_general_float32", axpby_general, float)
|
||||
instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
|
||||
instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
|
||||
instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)
|
||||
#define instantiate_axpby(type_name, type) \
|
||||
template [[host_name("axpby_general_" #type_name)]] \
|
||||
[[kernel]] void axpby_general<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
constant const int* shape [[buffer(5)]], \
|
||||
constant const size_t* x_strides [[buffer(6)]], \
|
||||
constant const size_t* y_strides [[buffer(7)]], \
|
||||
constant const int& ndim [[buffer(8)]], \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
instantiate_axpby(float32, float);
|
||||
instantiate_axpby(float16, half);
|
||||
instantiate_axpby(bfloat16, bfloat16_t);
|
||||
instantiate_axpby(complex64, complex64_t);
|
||||
|
||||
The logic to determine the kernel, set the inputs, resolve the grid dimensions,
|
||||
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
|
||||
@@ -480,7 +494,7 @@ below.
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Kernel parameters are registered with buffer indices corresponding to
|
||||
// those in the kernel declaration at axpby.metal
|
||||
@@ -495,14 +509,14 @@ below.
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode alpha and beta
|
||||
compute_encoder.set_bytes(alpha_, 3);
|
||||
compute_encoder.set_bytes(beta_, 4);
|
||||
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
||||
compute_encoder->setBytes(&beta_, sizeof(float), 4);
|
||||
|
||||
// Encode shape, strides and ndim
|
||||
compute_encoder.set_vector_bytes(x.shape(), 5);
|
||||
compute_encoder.set_vector_bytes(x.strides(), 6);
|
||||
compute_encoder.set_bytes(y.strides(), 7);
|
||||
compute_encoder.set_bytes(ndim, 8);
|
||||
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 8);
|
||||
|
||||
// We launch 1 thread for each input and make sure that the number of
|
||||
// threads in any given threadgroup is not higher than the max allowed
|
||||
@@ -516,7 +530,7 @@ below.
|
||||
|
||||
// Launch the grid with the given number of threads divided among
|
||||
// the given threadgroups
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
|
||||
|
@@ -209,7 +209,7 @@ Metal library by run-time compiling kernels the first time they are used in MLX
|
||||
on a given machine. Note run-time compilation incurs a cold-start cost which can
|
||||
be anywhere from a few hundred milliseconds to a few seconds depending on the
|
||||
application. Once a kernel is compiled, it will be cached by the system. The
|
||||
Metal kernel cache persists across reboots.
|
||||
Metal kernel cache persists accross reboots.
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
|
@@ -12,4 +12,5 @@ Fast
|
||||
layer_norm
|
||||
rope
|
||||
scaled_dot_product_attention
|
||||
affine_quantize
|
||||
metal_kernel
|
||||
|
@@ -12,7 +12,6 @@ Layers
|
||||
ALiBi
|
||||
AvgPool1d
|
||||
AvgPool2d
|
||||
AvgPool3d
|
||||
BatchNorm
|
||||
CELU
|
||||
Conv1d
|
||||
@@ -42,7 +41,6 @@ Layers
|
||||
LSTM
|
||||
MaxPool1d
|
||||
MaxPool2d
|
||||
MaxPool3d
|
||||
Mish
|
||||
MultiHeadAttention
|
||||
PReLU
|
||||
|
@@ -168,7 +168,6 @@ Operations
|
||||
tri
|
||||
tril
|
||||
triu
|
||||
unflatten
|
||||
var
|
||||
view
|
||||
where
|
||||
|
@@ -184,8 +184,8 @@ Let's time these two different versions:
|
||||
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
|
||||
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
|
||||
|
||||
On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
|
||||
vectorized version takes only ``0.024`` seconds, more than 200 times faster.
|
||||
On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
|
||||
vectorized version takes only ``0.025`` seconds, more than ten times faster.
|
||||
|
||||
Of course, this operation is quite contrived. A better approach is to simply do
|
||||
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.
|
||||
|
@@ -4,19 +4,19 @@
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
using namespace mlx::core;
|
||||
|
||||
int main() {
|
||||
if (!mx::distributed::is_available()) {
|
||||
if (!distributed::is_available()) {
|
||||
std::cout << "No communication backend found" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto global_group = mx::distributed::init();
|
||||
auto global_group = distributed::init();
|
||||
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
|
||||
|
||||
mx::array x = mx::ones({10});
|
||||
mx::array out = mx::distributed::all_sum(x, global_group);
|
||||
array x = ones({10});
|
||||
array out = distributed::all_sum(x, global_group);
|
||||
|
||||
std::cout << out << std::endl;
|
||||
}
|
||||
|
@@ -10,7 +10,7 @@
|
||||
/**
|
||||
* An example of linear regression with MLX.
|
||||
*/
|
||||
namespace mx = mlx::core;
|
||||
using namespace mlx::core;
|
||||
|
||||
int main() {
|
||||
int num_features = 100;
|
||||
@@ -19,35 +19,35 @@ int main() {
|
||||
float learning_rate = 0.01;
|
||||
|
||||
// True parameters
|
||||
auto w_star = mx::random::normal({num_features});
|
||||
auto w_star = random::normal({num_features});
|
||||
|
||||
// The input examples (design matrix)
|
||||
auto X = mx::random::normal({num_examples, num_features});
|
||||
auto X = random::normal({num_examples, num_features});
|
||||
|
||||
// Noisy labels
|
||||
auto eps = 1e-2 * mx::random::normal({num_examples});
|
||||
auto y = mx::matmul(X, w_star) + eps;
|
||||
auto eps = 1e-2 * random::normal({num_examples});
|
||||
auto y = matmul(X, w_star) + eps;
|
||||
|
||||
// Initialize random parameters
|
||||
mx::array w = 1e-2 * mx::random::normal({num_features});
|
||||
array w = 1e-2 * random::normal({num_features});
|
||||
|
||||
auto loss_fn = [&](mx::array w) {
|
||||
auto yhat = mx::matmul(X, w);
|
||||
return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));
|
||||
auto loss_fn = [&](array w) {
|
||||
auto yhat = matmul(X, w);
|
||||
return (0.5f / num_examples) * sum(square(yhat - y));
|
||||
};
|
||||
|
||||
auto grad_fn = mx::grad(loss_fn);
|
||||
auto grad_fn = grad(loss_fn);
|
||||
|
||||
auto tic = timer::time();
|
||||
for (int it = 0; it < num_iters; ++it) {
|
||||
auto grads = grad_fn(w);
|
||||
w = w - learning_rate * grads;
|
||||
mx::eval(w);
|
||||
auto grad = grad_fn(w);
|
||||
w = w - learning_rate * grad;
|
||||
eval(w);
|
||||
}
|
||||
auto toc = timer::time();
|
||||
|
||||
auto loss = loss_fn(w);
|
||||
auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());
|
||||
auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
|
||||
auto throughput = num_iters / timer::seconds(toc - tic);
|
||||
std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
|
||||
<< ", Throughput " << throughput << " (it/s)." << std::endl;
|
||||
|
@@ -10,7 +10,7 @@
|
||||
/**
|
||||
* An example of logistic regression with MLX.
|
||||
*/
|
||||
namespace mx = mlx::core;
|
||||
using namespace mlx::core;
|
||||
|
||||
int main() {
|
||||
int num_features = 100;
|
||||
@@ -19,35 +19,35 @@ int main() {
|
||||
float learning_rate = 0.1;
|
||||
|
||||
// True parameters
|
||||
auto w_star = mx::random::normal({num_features});
|
||||
auto w_star = random::normal({num_features});
|
||||
|
||||
// The input examples
|
||||
auto X = mx::random::normal({num_examples, num_features});
|
||||
auto X = random::normal({num_examples, num_features});
|
||||
|
||||
// Labels
|
||||
auto y = mx::matmul(X, w_star) > 0;
|
||||
auto y = matmul(X, w_star) > 0;
|
||||
|
||||
// Initialize random parameters
|
||||
mx::array w = 1e-2 * mx::random::normal({num_features});
|
||||
array w = 1e-2 * random::normal({num_features});
|
||||
|
||||
auto loss_fn = [&](mx::array w) {
|
||||
auto logits = mx::matmul(X, w);
|
||||
auto loss_fn = [&](array w) {
|
||||
auto logits = matmul(X, w);
|
||||
auto scale = (1.0f / num_examples);
|
||||
return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
|
||||
return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
|
||||
};
|
||||
|
||||
auto grad_fn = mx::grad(loss_fn);
|
||||
auto grad_fn = grad(loss_fn);
|
||||
|
||||
auto tic = timer::time();
|
||||
for (int it = 0; it < num_iters; ++it) {
|
||||
auto grads = grad_fn(w);
|
||||
w = w - learning_rate * grads;
|
||||
mx::eval(w);
|
||||
auto grad = grad_fn(w);
|
||||
w = w - learning_rate * grad;
|
||||
eval(w);
|
||||
}
|
||||
auto toc = timer::time();
|
||||
|
||||
auto loss = loss_fn(w);
|
||||
auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
|
||||
auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
|
||||
auto throughput = num_iters / timer::seconds(toc - tic);
|
||||
std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
|
||||
<< throughput << " (it/s)." << std::endl;
|
||||
|
@@ -5,27 +5,27 @@
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
using namespace mlx::core;
|
||||
|
||||
int main() {
|
||||
// To use Metal debugging and profiling:
|
||||
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
|
||||
// 2. Run with MTL_CAPTURE_ENABLED=1.
|
||||
mx::metal::start_capture("mlx_trace.gputrace");
|
||||
metal::start_capture("mlx_trace.gputrace");
|
||||
|
||||
// Start at index two because the default GPU and CPU streams have indices
|
||||
// zero and one, respectively. This naming matches the label assigned to each
|
||||
// stream's command queue.
|
||||
auto s2 = new_stream(mx::Device::gpu);
|
||||
auto s3 = new_stream(mx::Device::gpu);
|
||||
auto s2 = new_stream(Device::gpu);
|
||||
auto s3 = new_stream(Device::gpu);
|
||||
|
||||
auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
|
||||
auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
|
||||
auto x = mx::add(a, a, s2);
|
||||
auto y = mx::add(b, b, s3);
|
||||
auto a = arange(1.f, 10.f, 1.f, float32, s2);
|
||||
auto b = arange(1.f, 10.f, 1.f, float32, s3);
|
||||
auto x = add(a, a, s2);
|
||||
auto y = add(b, b, s3);
|
||||
|
||||
// The multiply will happen on the default stream.
|
||||
std::cout << mx::multiply(x, y) << std::endl;
|
||||
std::cout << multiply(x, y) << std::endl;
|
||||
|
||||
mx::metal::stop_capture();
|
||||
metal::stop_capture();
|
||||
}
|
||||
|
@@ -5,11 +5,11 @@
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
using namespace mlx::core;
|
||||
|
||||
void array_basics() {
|
||||
// Make a scalar array:
|
||||
mx::array x(1.0);
|
||||
array x(1.0);
|
||||
|
||||
// Get the value out of it:
|
||||
auto s = x.item<float>();
|
||||
@@ -29,31 +29,31 @@ void array_basics() {
|
||||
|
||||
// The datatype should be float32:
|
||||
auto dtype = x.dtype();
|
||||
assert(dtype == mx::float32);
|
||||
assert(dtype == float32);
|
||||
|
||||
// Specify the dtype when constructing the array:
|
||||
x = mx::array(1, mx::int32);
|
||||
assert(x.dtype() == mx::int32);
|
||||
x = array(1, int32);
|
||||
assert(x.dtype() == int32);
|
||||
x.item<int>(); // OK
|
||||
// x.item<float>(); // Undefined!
|
||||
|
||||
// Make a multidimensional array:
|
||||
x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
|
||||
x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
|
||||
// mlx is row-major by default so the first row of this array
|
||||
// is [1.0, 2.0] and the second row is [3.0, 4.0]
|
||||
|
||||
// Make an array of shape {2, 2} filled with ones:
|
||||
auto y = mx::ones({2, 2});
|
||||
auto y = ones({2, 2});
|
||||
|
||||
// Pointwise add x and y:
|
||||
auto z = mx::add(x, y);
|
||||
auto z = add(x, y);
|
||||
|
||||
// Same thing:
|
||||
z = x + y;
|
||||
|
||||
// mlx is lazy by default. At this point `z` only
|
||||
// has a shape and a type but no actual data:
|
||||
assert(z.dtype() == mx::float32);
|
||||
assert(z.dtype() == float32);
|
||||
assert(z.shape(0) == 2);
|
||||
assert(z.shape(1) == 2);
|
||||
|
||||
@@ -63,33 +63,33 @@ void array_basics() {
|
||||
// and inputs. When `eval` is called on an array (or arrays), the array and
|
||||
// all of its dependencies are recursively evaluated to produce the result.
|
||||
// Once an array is evaluated, it has data and is detached from its inputs.
|
||||
mx::eval(z);
|
||||
eval(z);
|
||||
|
||||
// Of course the array can still be an input to other operations. You can
|
||||
// even call eval on the array again, this will just be a no-op:
|
||||
mx::eval(z); // no-op
|
||||
// Of course the array can still be an input to other operations. You can even
|
||||
// call eval on the array again, this will just be a no-op:
|
||||
eval(z); // no-op
|
||||
|
||||
// Some functions or methods on arrays implicitly evaluate them. For example
|
||||
// accessing a value in an array or printing the array implicitly evaluate it:
|
||||
z = mx::ones({1});
|
||||
z = ones({1});
|
||||
z.item<float>(); // implicit evaluation
|
||||
|
||||
z = mx::ones({2, 2});
|
||||
z = ones({2, 2});
|
||||
std::cout << z << std::endl; // implicit evaluation
|
||||
}
|
||||
|
||||
void automatic_differentiation() {
|
||||
auto fn = [](mx::array x) { return mx::square(x); };
|
||||
auto fn = [](array x) { return square(x); };
|
||||
|
||||
// Computing the derivative function of a function
|
||||
auto grad_fn = mx::grad(fn);
|
||||
auto grad_fn = grad(fn);
|
||||
// Call grad_fn on the input to get the derivative
|
||||
auto x = mx::array(1.5);
|
||||
auto x = array(1.5);
|
||||
auto dfdx = grad_fn(x);
|
||||
// dfdx is 2 * x
|
||||
|
||||
// Get the second derivative by composing grad with grad
|
||||
auto d2fdx2 = mx::grad(mx::grad(fn))(x);
|
||||
auto d2fdx2 = grad(grad(fn))(x);
|
||||
// d2fdx2 is 2
|
||||
}
|
||||
|
||||
|
@@ -19,7 +19,7 @@
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#endif
|
||||
|
||||
namespace my_ext {
|
||||
namespace mlx::core {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Operation Implementation
|
||||
@@ -32,24 +32,24 @@ namespace my_ext {
|
||||
* Follow numpy style broadcasting between x and y
|
||||
* Inputs are upcasted to floats if needed
|
||||
**/
|
||||
mx::array axpby(
|
||||
const mx::array& x, // Input mx::array x
|
||||
const mx::array& y, // Input mx::array y
|
||||
array axpby(
|
||||
const array& x, // Input array x
|
||||
const array& y, // Input array y
|
||||
const float alpha, // Scaling factor for x
|
||||
const float beta, // Scaling factor for y
|
||||
mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
|
||||
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
|
||||
) {
|
||||
// Promote dtypes between x and y as needed
|
||||
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
||||
|
||||
// Upcast to float32 for non-floating point inputs x and y
|
||||
auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32)
|
||||
auto out_dtype = issubdtype(promoted_dtype, float32)
|
||||
? promoted_dtype
|
||||
: promote_types(promoted_dtype, mx::float32);
|
||||
: promote_types(promoted_dtype, float32);
|
||||
|
||||
// Cast x and y up to the determined dtype (on the same stream s)
|
||||
auto x_casted = mx::astype(x, out_dtype, s);
|
||||
auto y_casted = mx::astype(y, out_dtype, s);
|
||||
auto x_casted = astype(x, out_dtype, s);
|
||||
auto y_casted = astype(y, out_dtype, s);
|
||||
|
||||
// Broadcast the shapes of x and y (on the same stream s)
|
||||
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
|
||||
@@ -57,12 +57,12 @@ mx::array axpby(
|
||||
|
||||
// Construct the array as the output of the Axpby primitive
|
||||
// with the broadcasted and upcasted arrays as inputs
|
||||
return mx::array(
|
||||
return array(
|
||||
/* const std::vector<int>& shape = */ out_shape,
|
||||
/* mx::Dtype dtype = */ out_dtype,
|
||||
/* std::unique_ptr<mx::Primitive> primitive = */
|
||||
/* Dtype dtype = */ out_dtype,
|
||||
/* std::unique_ptr<Primitive> primitive = */
|
||||
std::make_shared<Axpby>(to_stream(s), alpha, beta),
|
||||
/* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
|
||||
/* const std::vector<array>& inputs = */ broadcasted_inputs);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -71,16 +71,16 @@ mx::array axpby(
|
||||
|
||||
template <typename T>
|
||||
void axpby_impl(
|
||||
const mx::array& x,
|
||||
const mx::array& y,
|
||||
mx::array& out,
|
||||
const array& x,
|
||||
const array& y,
|
||||
array& out,
|
||||
float alpha_,
|
||||
float beta_) {
|
||||
// We only allocate memory when we are ready to fill the output
|
||||
// malloc_or_wait synchronously allocates available memory
|
||||
// There may be a wait executed here if the allocation is requested
|
||||
// under memory-pressured conditions
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Collect input and output data pointers
|
||||
const T* x_ptr = x.data<T>();
|
||||
@@ -94,8 +94,8 @@ void axpby_impl(
|
||||
// Do the element-wise operation for each output
|
||||
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
||||
// Map linear indices to offsets in x and y
|
||||
auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides());
|
||||
auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides());
|
||||
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
||||
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
||||
|
||||
// We allocate the output to be contiguous and regularly strided
|
||||
// (defaults to row major) and hence it doesn't need additional mapping
|
||||
@@ -105,8 +105,8 @@ void axpby_impl(
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Check the inputs (registered in the op while constructing the out array)
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
@@ -114,14 +114,14 @@ void Axpby::eval(
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Dispatch to the correct dtype
|
||||
if (out.dtype() == mx::float32) {
|
||||
if (out.dtype() == float32) {
|
||||
return axpby_impl<float>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == mx::float16) {
|
||||
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == mx::bfloat16) {
|
||||
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == mx::complex64) {
|
||||
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == float16) {
|
||||
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == complex64) {
|
||||
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Axpby is only supported for floating point types.");
|
||||
@@ -136,9 +136,9 @@ void Axpby::eval(
|
||||
|
||||
template <typename T>
|
||||
void axpby_impl_accelerate(
|
||||
const mx::array& x,
|
||||
const mx::array& y,
|
||||
mx::array& out,
|
||||
const array& x,
|
||||
const array& y,
|
||||
array& out,
|
||||
float alpha_,
|
||||
float beta_) {
|
||||
// Accelerate library provides catlas_saxpby which does
|
||||
@@ -150,10 +150,10 @@ void axpby_impl_accelerate(
|
||||
// The data in the output array is allocated to match the strides in y
|
||||
// such that x, y, and out are contiguous in the same mode and
|
||||
// no transposition is needed
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// We then copy over the elements using the contiguous vector specialization
|
||||
copy_inplace(y, out, mx::CopyType::Vector);
|
||||
copy_inplace(y, out, CopyType::Vector);
|
||||
|
||||
// Get x and y pointers for catlas_saxpby
|
||||
const T* x_ptr = x.data<T>();
|
||||
@@ -175,15 +175,15 @@ void axpby_impl_accelerate(
|
||||
|
||||
/** Evaluate primitive on CPU using accelerate specializations */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Accelerate specialization for contiguous single precision float arrays
|
||||
if (out.dtype() == mx::float32 &&
|
||||
if (out.dtype() == float32 &&
|
||||
((x.flags().row_contiguous && y.flags().row_contiguous) ||
|
||||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
|
||||
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
|
||||
@@ -198,8 +198,8 @@ void Axpby::eval_cpu(
|
||||
|
||||
/** Evaluate primitive on CPU falling back to common backend */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
@@ -213,8 +213,8 @@ void Axpby::eval_cpu(
|
||||
|
||||
/** Evaluate primitive on GPU */
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Prepare inputs
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
@@ -225,7 +225,7 @@ void Axpby::eval_gpu(
|
||||
// and each stream carries its device identifiers
|
||||
auto& s = stream();
|
||||
// We get the needed metal device using the stream
|
||||
auto& d = mx::metal::device(s.device);
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Prepare to specialize based on contiguity
|
||||
bool contiguous_kernel =
|
||||
@@ -235,12 +235,12 @@ void Axpby::eval_gpu(
|
||||
// Allocate output memory with strides based on specialization
|
||||
if (contiguous_kernel) {
|
||||
out.set_data(
|
||||
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
|
||||
allocator::malloc_or_wait(x.data_size() * out.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
} else {
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
// Resolve name of kernel (corresponds to axpby.metal)
|
||||
@@ -257,7 +257,7 @@ void Axpby::eval_gpu(
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Kernel parameters are registered with buffer indices corresponding to
|
||||
// those in the kernel declaration at axpby.metal
|
||||
@@ -272,15 +272,15 @@ void Axpby::eval_gpu(
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode alpha and beta
|
||||
compute_encoder.set_bytes(alpha_, 3);
|
||||
compute_encoder.set_bytes(beta_, 4);
|
||||
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
||||
compute_encoder->setBytes(&beta_, sizeof(float), 4);
|
||||
|
||||
// Encode shape, strides and ndim if needed
|
||||
if (!contiguous_kernel) {
|
||||
compute_encoder.set_vector_bytes(x.shape(), 5);
|
||||
compute_encoder.set_vector_bytes(x.strides(), 6);
|
||||
compute_encoder.set_bytes(y.strides(), 7);
|
||||
compute_encoder.set_bytes(ndim, 8);
|
||||
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 8);
|
||||
}
|
||||
|
||||
// We launch 1 thread for each input and make sure that the number of
|
||||
@@ -295,15 +295,15 @@ void Axpby::eval_gpu(
|
||||
|
||||
// Launch the grid with the given number of threads divided among
|
||||
// the given threadgroups
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
#else // Metal is not available
|
||||
|
||||
/** Fail evaluation on GPU */
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& out) {
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& out) {
|
||||
throw std::runtime_error("Axpby has no GPU implementation.");
|
||||
}
|
||||
|
||||
@@ -314,9 +314,9 @@ void Axpby::eval_gpu(
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
std::vector<mx::array> Axpby::jvp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& tangents,
|
||||
std::vector<array> Axpby::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
// Forward mode diff that pushes along the tangents
|
||||
// The jvp transform on the primitive can built with ops
|
||||
@@ -328,8 +328,8 @@ std::vector<mx::array> Axpby::jvp(
|
||||
// scaled by beta
|
||||
if (argnums.size() > 1) {
|
||||
auto scale = argnums[0] == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = mx::array(scale, tangents[0].dtype());
|
||||
return {mx::multiply(scale_arr, tangents[0], stream())};
|
||||
auto scale_arr = array(scale, tangents[0].dtype());
|
||||
return {multiply(scale_arr, tangents[0], stream())};
|
||||
}
|
||||
// If, argnums = {0, 1}, we take contributions from both
|
||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||
@@ -339,24 +339,24 @@ std::vector<mx::array> Axpby::jvp(
|
||||
}
|
||||
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<mx::array> Axpby::vjp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& cotangents,
|
||||
std::vector<array> Axpby::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<mx::array>&) {
|
||||
const std::vector<array>&) {
|
||||
// Reverse mode diff
|
||||
std::vector<mx::array> vjps;
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
auto scale = arg == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = mx::array(scale, cotangents[0].dtype());
|
||||
vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream()));
|
||||
auto scale_arr = array(scale, cotangents[0].dtype());
|
||||
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
/** Vectorize primitive along given axis */
|
||||
std::pair<std::vector<mx::array>, std::vector<int>> Axpby::vmap(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
throw std::runtime_error("Axpby has no vmap implementation.");
|
||||
}
|
||||
@@ -367,4 +367,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
|
||||
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
|
||||
}
|
||||
|
||||
} // namespace my_ext
|
||||
} // namespace mlx::core
|
||||
|
@@ -5,9 +5,7 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
namespace my_ext {
|
||||
namespace mlx::core {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Operation
|
||||
@@ -20,22 +18,22 @@ namespace my_ext {
|
||||
* Follow numpy style broadcasting between x and y
|
||||
* Inputs are upcasted to floats if needed
|
||||
**/
|
||||
mx::array axpby(
|
||||
const mx::array& x, // Input array x
|
||||
const mx::array& y, // Input array y
|
||||
array axpby(
|
||||
const array& x, // Input array x
|
||||
const array& y, // Input array y
|
||||
const float alpha, // Scaling factor for x
|
||||
const float beta, // Scaling factor for y
|
||||
mx::StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||
StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||
);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Primitive
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
class Axpby : public mx::Primitive {
|
||||
class Axpby : public Primitive {
|
||||
public:
|
||||
explicit Axpby(mx::Stream stream, float alpha, float beta)
|
||||
: mx::Primitive(stream), alpha_(alpha), beta_(beta) {};
|
||||
explicit Axpby(Stream stream, float alpha, float beta)
|
||||
: Primitive(stream), alpha_(alpha), beta_(beta) {};
|
||||
|
||||
/**
|
||||
* A primitive must know how to evaluate itself on the CPU/GPU
|
||||
@@ -44,25 +42,23 @@ class Axpby : public mx::Primitive {
|
||||
* To avoid unnecessary allocations, the evaluation function
|
||||
* is responsible for allocating space for the array.
|
||||
*/
|
||||
void eval_cpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) override;
|
||||
void eval_gpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) override;
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
std::vector<mx::array> jvp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& tangents,
|
||||
std::vector<array> jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) override;
|
||||
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<mx::array> vjp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& cotangents,
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<mx::array>& outputs) override;
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
/**
|
||||
* The primitive must know how to vectorize itself across
|
||||
@@ -70,8 +66,8 @@ class Axpby : public mx::Primitive {
|
||||
* representing the vectorized computation and the axis which
|
||||
* corresponds to the output vectorized dimension.
|
||||
*/
|
||||
std::pair<std::vector<mx::array>, std::vector<int>> vmap(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
/** Print the primitive. */
|
||||
@@ -80,16 +76,14 @@ class Axpby : public mx::Primitive {
|
||||
}
|
||||
|
||||
/** Equivalence check **/
|
||||
bool is_equivalent(const mx::Primitive& other) const override;
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
float alpha_;
|
||||
float beta_;
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void eval(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs);
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
};
|
||||
|
||||
} // namespace my_ext
|
||||
} // namespace mlx::core
|
||||
|
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
template <typename T>
|
||||
@@ -12,8 +13,8 @@ template <typename T>
|
||||
constant const float& alpha [[buffer(3)]],
|
||||
constant const float& beta [[buffer(4)]],
|
||||
constant const int* shape [[buffer(5)]],
|
||||
constant const int64_t* x_strides [[buffer(6)]],
|
||||
constant const int64_t* y_strides [[buffer(7)]],
|
||||
constant const size_t* x_strides [[buffer(6)]],
|
||||
constant const size_t* y_strides [[buffer(7)]],
|
||||
constant const int& ndim [[buffer(8)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
|
||||
@@ -34,14 +35,29 @@ template <typename T>
|
||||
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_axpby(type_name, type) \
|
||||
instantiate_kernel("axpby_general_" #type_name, axpby_general, type) \
|
||||
instantiate_kernel( \
|
||||
"axpby_contiguous_" #type_name, axpby_contiguous, type)
|
||||
#define instantiate_axpby(type_name, type) \
|
||||
template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
|
||||
axpby_general<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
constant const int* shape [[buffer(5)]], \
|
||||
constant const size_t* x_strides [[buffer(6)]], \
|
||||
constant const size_t* y_strides [[buffer(7)]], \
|
||||
constant const int& ndim [[buffer(8)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
|
||||
axpby_contiguous<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
instantiate_axpby(float32, float);
|
||||
instantiate_axpby(float16, half);
|
||||
instantiate_axpby(bfloat16, bfloat16_t);
|
||||
instantiate_axpby(complex64, complex64_t);
|
||||
// clang-format on
|
||||
instantiate_axpby(complex64, complex64_t);
|
@@ -8,12 +8,14 @@
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
NB_MODULE(_ext, m) {
|
||||
m.doc() = "Sample extension for MLX";
|
||||
|
||||
m.def(
|
||||
"axpby",
|
||||
&my_ext::axpby,
|
||||
&axpby,
|
||||
"x"_a,
|
||||
"y"_a,
|
||||
"alpha"_a,
|
||||
|
@@ -1,4 +1,4 @@
|
||||
setuptools>=42
|
||||
cmake>=3.24
|
||||
mlx>=0.21.0
|
||||
mlx>=0.18.1
|
||||
nanobind==2.2.0
|
||||
|
15
mlx.pc.in
15
mlx.pc.in
@@ -28,19 +28,10 @@ endif()
|
||||
if (@MLX_BUILD_METAL@)
|
||||
set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
|
||||
set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
|
||||
set(MLX_INCLUDE_DIRS
|
||||
"${MLX_INCLUDE_DIRS};"
|
||||
set_and_check(MLX_INCLUDE_DIRS
|
||||
${MLX_INCLUDE_DIRS}
|
||||
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
|
||||
)
|
||||
if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
|
||||
set(MLX_INCLUDE_DIRS
|
||||
"${MLX_INCLUDE_DIRS};"
|
||||
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
|
||||
else()
|
||||
set(MLX_INCLUDE_DIRS
|
||||
"${MLX_INCLUDE_DIRS};"
|
||||
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set_target_properties(mlx PROPERTIES
|
||||
@@ -49,4 +40,4 @@ set_target_properties(mlx PROPERTIES
|
||||
)
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
|
||||
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
|
@@ -18,16 +18,6 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
||||
|
||||
if(MSVC)
|
||||
# Disable some MSVC warnings to speed up compilation.
|
||||
target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
|
||||
endif()
|
||||
|
||||
if(WIN32)
|
||||
# Export symbols by default to behave like macOS/linux.
|
||||
set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CPU)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
||||
else()
|
||||
|
@@ -19,7 +19,7 @@ Buffer malloc(size_t size) {
|
||||
}
|
||||
|
||||
void free(Buffer buffer) {
|
||||
allocator().free(buffer);
|
||||
return allocator().free(buffer);
|
||||
}
|
||||
|
||||
Buffer CommonAllocator::malloc(size_t size, bool) {
|
||||
|
@@ -1,6 +1,5 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/ops.h"
|
||||
@@ -31,7 +30,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
}
|
||||
|
||||
array::array(
|
||||
Shape shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array> inputs)
|
||||
@@ -42,7 +41,7 @@ array::array(
|
||||
std::move(inputs))) {}
|
||||
|
||||
std::vector<array> array::make_arrays(
|
||||
std::vector<Shape> shapes,
|
||||
std::vector<std::vector<int>> shapes,
|
||||
const std::vector<Dtype>& dtypes,
|
||||
const std::shared_ptr<Primitive>& primitive,
|
||||
const std::vector<array>& inputs) {
|
||||
@@ -74,7 +73,11 @@ array::array(std::initializer_list<int> data, Dtype dtype)
|
||||
}
|
||||
|
||||
/* Build an array from a shared buffer */
|
||||
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
|
||||
array::array(
|
||||
allocator::Buffer data,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
deleter_t deleter)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
set_data(data, deleter);
|
||||
}
|
||||
@@ -122,7 +125,7 @@ bool array::is_tracer() const {
|
||||
return array_desc_->is_tracer && in_tracing() || retain_graph();
|
||||
}
|
||||
|
||||
void array::set_data(allocator::Buffer buffer, Deleter d) {
|
||||
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||
array_desc_->data = std::make_shared<Data>(buffer, d);
|
||||
array_desc_->data_ptr = buffer.raw_ptr();
|
||||
array_desc_->data_size = size();
|
||||
@@ -135,9 +138,9 @@ void array::set_data(allocator::Buffer buffer, Deleter d) {
|
||||
void array::set_data(
|
||||
allocator::Buffer buffer,
|
||||
size_t data_size,
|
||||
Strides strides,
|
||||
std::vector<size_t> strides,
|
||||
Flags flags,
|
||||
Deleter d) {
|
||||
deleter_t d) {
|
||||
array_desc_->data = std::make_shared<Data>(buffer, d);
|
||||
array_desc_->data_ptr = buffer.raw_ptr();
|
||||
array_desc_->data_size = data_size;
|
||||
@@ -147,7 +150,7 @@ void array::set_data(
|
||||
|
||||
void array::copy_shared_buffer(
|
||||
const array& other,
|
||||
const Strides& strides,
|
||||
const std::vector<size_t>& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
@@ -166,7 +169,7 @@ void array::copy_shared_buffer(const array& other) {
|
||||
|
||||
void array::move_shared_buffer(
|
||||
array other,
|
||||
const Strides& strides,
|
||||
const std::vector<size_t>& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
@@ -211,8 +214,6 @@ array::~array() {
|
||||
if (do_detach) {
|
||||
for (auto& s : siblings()) {
|
||||
for (auto& ss : s.siblings()) {
|
||||
// Set to null here to avoid descending into array destructor
|
||||
// for siblings
|
||||
ss.array_desc_ = nullptr;
|
||||
}
|
||||
s.array_desc_->siblings.clear();
|
||||
@@ -233,13 +234,13 @@ void array::ArrayDesc::init() {
|
||||
}
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)
|
||||
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
|
||||
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
|
||||
init();
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(
|
||||
Shape shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array> inputs)
|
||||
@@ -291,14 +292,6 @@ array::ArrayDesc::~ArrayDesc() {
|
||||
auto top = std::move(for_deletion.back());
|
||||
for_deletion.pop_back();
|
||||
append_deletable_inputs(*top);
|
||||
|
||||
// Clear out possible siblings to break circular references
|
||||
for (auto& s : top->siblings) {
|
||||
// Set to null here to avoid descending into top-level
|
||||
// array destructor for siblings
|
||||
s.array_desc_ = nullptr;
|
||||
}
|
||||
top->siblings.clear();
|
||||
}
|
||||
}
|
||||
|
||||
|
51
mlx/array.h
51
mlx/array.h
@@ -15,10 +15,7 @@ namespace mlx::core {
|
||||
|
||||
// Forward declaration
|
||||
class Primitive;
|
||||
|
||||
using Deleter = std::function<void(allocator::Buffer)>;
|
||||
using Shape = std::vector<int32_t>;
|
||||
using Strides = std::vector<int64_t>;
|
||||
using deleter_t = std::function<void(allocator::Buffer)>;
|
||||
|
||||
class array {
|
||||
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
||||
@@ -36,7 +33,7 @@ class array {
|
||||
template <typename It>
|
||||
array(
|
||||
It data,
|
||||
Shape shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype =
|
||||
TypeToDtype<typename std::iterator_traits<It>::value_type>());
|
||||
|
||||
@@ -52,15 +49,15 @@ class array {
|
||||
template <typename T>
|
||||
array(
|
||||
std::initializer_list<T> data,
|
||||
Shape shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype = TypeToDtype<T>());
|
||||
|
||||
/* Build an array from a buffer */
|
||||
array(
|
||||
allocator::Buffer data,
|
||||
Shape shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
Deleter deleter = allocator::free);
|
||||
deleter_t deleter = allocator::free);
|
||||
|
||||
/** Assignment to rvalue does not compile. */
|
||||
array& operator=(const array& other) && = delete;
|
||||
@@ -99,7 +96,7 @@ class array {
|
||||
}
|
||||
|
||||
/** The shape of the array as a vector of integers. */
|
||||
const Shape& shape() const {
|
||||
const std::vector<int>& shape() const {
|
||||
return array_desc_->shape;
|
||||
}
|
||||
|
||||
@@ -108,12 +105,12 @@ class array {
|
||||
*
|
||||
* This function supports negative indexing and provides
|
||||
* bounds checking. */
|
||||
auto shape(int dim) const {
|
||||
int shape(int dim) const {
|
||||
return shape().at(dim < 0 ? dim + ndim() : dim);
|
||||
}
|
||||
|
||||
/** The strides of the array. */
|
||||
const Strides& strides() const {
|
||||
const std::vector<size_t>& strides() const {
|
||||
return array_desc_->strides;
|
||||
}
|
||||
|
||||
@@ -122,7 +119,7 @@ class array {
|
||||
*
|
||||
* This function supports negative indexing and provides
|
||||
* bounds checking. */
|
||||
auto strides(int dim) const {
|
||||
size_t strides(int dim) const {
|
||||
return strides().at(dim < 0 ? dim + ndim() : dim);
|
||||
}
|
||||
|
||||
@@ -187,13 +184,13 @@ class array {
|
||||
*/
|
||||
|
||||
array(
|
||||
Shape shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array> inputs);
|
||||
|
||||
static std::vector<array> make_arrays(
|
||||
std::vector<Shape> shapes,
|
||||
std::vector<std::vector<int>> shapes,
|
||||
const std::vector<Dtype>& dtypes,
|
||||
const std::shared_ptr<Primitive>& primitive,
|
||||
const std::vector<array>& inputs);
|
||||
@@ -210,8 +207,8 @@ class array {
|
||||
|
||||
struct Data {
|
||||
allocator::Buffer buffer;
|
||||
Deleter d;
|
||||
Data(allocator::Buffer buffer, Deleter d = allocator::free)
|
||||
deleter_t d;
|
||||
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
|
||||
: buffer(buffer), d(d) {}
|
||||
// Not copyable
|
||||
Data(const Data& d) = delete;
|
||||
@@ -400,18 +397,18 @@ class array {
|
||||
// Check if the array is a tracer array
|
||||
bool is_tracer() const;
|
||||
|
||||
void set_data(allocator::Buffer buffer, Deleter d = allocator::free);
|
||||
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
|
||||
|
||||
void set_data(
|
||||
allocator::Buffer buffer,
|
||||
size_t data_size,
|
||||
Strides strides,
|
||||
std::vector<size_t> strides,
|
||||
Flags flags,
|
||||
Deleter d = allocator::free);
|
||||
deleter_t d = allocator::free);
|
||||
|
||||
void copy_shared_buffer(
|
||||
const array& other,
|
||||
const Strides& strides,
|
||||
const std::vector<size_t>& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
@@ -420,7 +417,7 @@ class array {
|
||||
|
||||
void move_shared_buffer(
|
||||
array other,
|
||||
const Strides& strides,
|
||||
const std::vector<size_t>& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
@@ -439,8 +436,8 @@ class array {
|
||||
void init(const It src);
|
||||
|
||||
struct ArrayDesc {
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
size_t size;
|
||||
Dtype dtype;
|
||||
std::shared_ptr<Primitive> primitive;
|
||||
@@ -474,10 +471,10 @@ class array {
|
||||
// The arrays position in the output list
|
||||
uint32_t position{0};
|
||||
|
||||
explicit ArrayDesc(Shape shape, Dtype dtype);
|
||||
explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
|
||||
|
||||
explicit ArrayDesc(
|
||||
Shape shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array> inputs);
|
||||
@@ -505,7 +502,7 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
|
||||
template <typename It>
|
||||
array::array(
|
||||
It data,
|
||||
Shape shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
|
||||
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
init(data);
|
||||
@@ -524,7 +521,7 @@ array::array(
|
||||
template <typename T>
|
||||
array::array(
|
||||
std::initializer_list<T> data,
|
||||
Shape shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype /* = TypeToDtype<T>() */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
if (data.size() != size()) {
|
||||
|
@@ -43,7 +43,6 @@ DEFAULT(NumberOfElements)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(ExpandDims)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Gather)
|
||||
@@ -66,6 +65,7 @@ DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT_MULTI(QRF)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reshape)
|
||||
DEFAULT(Remainder)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scatter)
|
||||
@@ -76,7 +76,6 @@ DEFAULT(Slice)
|
||||
DEFAULT(SliceUpdate)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT(Squeeze)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT_MULTI(SVD)
|
||||
DEFAULT(Transpose)
|
||||
|
@@ -5,21 +5,13 @@ else()
|
||||
set(COMPILER ${CMAKE_CXX_COMPILER})
|
||||
endif()
|
||||
|
||||
if(MSVC)
|
||||
set(SHELL_EXT ps1)
|
||||
set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)
|
||||
else()
|
||||
set(SHELL_EXT sh)
|
||||
set(SHELL_CMD /bin/bash)
|
||||
endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT compiled_preamble.cpp
|
||||
COMMAND
|
||||
${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}
|
||||
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
|
||||
${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
|
||||
DEPENDS make_compiled_preamble.${SHELL_EXT}
|
||||
${PROJECT_SOURCE_DIR} ${CLANG}
|
||||
DEPENDS make_compiled_preamble.sh
|
||||
compiled_preamble.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
|
||||
|
@@ -13,8 +13,8 @@ template <typename InT, typename OpT>
|
||||
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||
auto axis_size = in.shape()[axis];
|
||||
auto axis_stride = in.strides()[axis];
|
||||
Strides strides = in.strides();
|
||||
Shape shape = in.shape();
|
||||
std::vector<size_t> strides = in.strides();
|
||||
std::vector<int> shape = in.shape();
|
||||
strides.erase(strides.begin() + axis);
|
||||
shape.erase(shape.begin() + axis);
|
||||
for (uint32_t i = 0; i < out.size(); ++i) {
|
||||
|
@@ -178,10 +178,10 @@ void binary_op_dims(
|
||||
const T* b,
|
||||
U* out,
|
||||
Op op,
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
const Strides& out_strides,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& a_strides,
|
||||
const std::vector<size_t>& b_strides,
|
||||
const std::vector<size_t>& out_strides,
|
||||
int axis) {
|
||||
auto stride_a = a_strides[axis];
|
||||
auto stride_b = b_strides[axis];
|
||||
@@ -212,10 +212,10 @@ void binary_op_dispatch_dims(
|
||||
array& out,
|
||||
Op op,
|
||||
int dim,
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
const Strides& out_strides) {
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& a_strides,
|
||||
const std::vector<size_t>& b_strides,
|
||||
const std::vector<size_t>& out_strides) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
@@ -258,10 +258,10 @@ void binary_op_dispatch_dims(
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator a_it(shape, a_strides, dim - 3);
|
||||
ContiguousIterator b_it(shape, b_strides, dim - 3);
|
||||
auto stride = out_strides[dim - 4];
|
||||
for (int64_t elem = 0; elem < a.size(); elem += stride) {
|
||||
ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
|
||||
ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
|
||||
size_t stride = out_strides[dim - 4];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
binary_op_dims<T, U, Op, 3, Strided>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
@@ -327,7 +327,7 @@ void binary_op(
|
||||
const auto& strides = new_strides[2];
|
||||
|
||||
// Get the left-most dim such that the array is row contiguous after
|
||||
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
|
||||
auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
|
||||
}
|
||||
@@ -337,7 +337,7 @@ void binary_op(
|
||||
auto b_rc_dim = leftmost_rc_dim(b_strides);
|
||||
|
||||
// Get the left-most dim such that the array is a broadcasted "scalar" after
|
||||
auto leftmost_s_dim = [](const auto& arr_strides) {
|
||||
auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == 0; d--) {
|
||||
}
|
||||
|
@@ -16,10 +16,10 @@ void binary_op_dims(
|
||||
U* out_a,
|
||||
U* out_b,
|
||||
Op op,
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
const Strides& out_strides,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& a_strides,
|
||||
const std::vector<size_t>& b_strides,
|
||||
const std::vector<size_t>& out_strides,
|
||||
int axis) {
|
||||
auto stride_a = a_strides[axis];
|
||||
auto stride_b = b_strides[axis];
|
||||
@@ -96,9 +96,9 @@ void binary_op_dispatch_dims(
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator b_it(shape, b_strides, ndim - 2);
|
||||
auto stride = out_strides[ndim - 3];
|
||||
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
|
||||
size_t stride = out_strides[ndim - 3];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
binary_op_dims<T, U, Op, 2>(
|
||||
a_ptr + a_it.loc,
|
||||
|
@@ -39,7 +39,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||
// rely on data_size anyway.
|
||||
size_t data_size = out.size();
|
||||
|
||||
return move_or_copy(in, out, strides_, flags, data_size, offset_);
|
||||
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
||||
}
|
||||
|
||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -49,7 +49,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
Strides strides(out.ndim(), 0);
|
||||
std::vector<size_t> strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
@@ -58,12 +58,12 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
move_or_copy(in, out, strides, flags, in.data_size());
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
move_or_copy(inputs[0], out);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void CustomTransforms::eval(
|
||||
@@ -72,7 +72,7 @@ void CustomTransforms::eval(
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
|
||||
i++, j++) {
|
||||
move_or_copy(inputs[j], outputs[i]);
|
||||
outputs[i].copy_shared_buffer(inputs[j]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,20 +81,10 @@ void Depends::eval(
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
move_or_copy(inputs[i], outputs[i]);
|
||||
outputs[i].copy_shared_buffer(inputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
auto strides = in.strides();
|
||||
for (auto ax : axes_) {
|
||||
strides.insert(strides.begin() + ax, 1);
|
||||
}
|
||||
move_or_copy(in, out, strides, in.flags(), in.data_size());
|
||||
}
|
||||
|
||||
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
@@ -151,7 +141,9 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
|
||||
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
|
||||
const array& in,
|
||||
const array& out) {
|
||||
// Special case for empty arrays or row contiguous arrays
|
||||
if (in.size() == 0 || in.flags().row_contiguous) {
|
||||
return {false, out.strides()};
|
||||
@@ -159,7 +151,8 @@ std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
|
||||
|
||||
// Special case for scalars
|
||||
if (in.ndim() == 0) {
|
||||
return {false, Strides(out.ndim(), 0)};
|
||||
std::vector<size_t> out_strides(out.ndim(), 0);
|
||||
return {false, out_strides};
|
||||
}
|
||||
|
||||
// Firstly let's collapse all the contiguous dimensions of the input
|
||||
@@ -167,7 +160,7 @@ std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
|
||||
|
||||
// If shapes fit exactly in the contiguous dims then no copy is necessary so
|
||||
// let's check.
|
||||
Strides out_strides;
|
||||
std::vector<size_t> out_strides;
|
||||
bool copy_necessary = false;
|
||||
int j = 0;
|
||||
for (int i = 0; i < out.ndim(); i++) {
|
||||
@@ -188,9 +181,9 @@ std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
|
||||
return {copy_necessary, out_strides};
|
||||
}
|
||||
|
||||
void shared_buffer_reshape(
|
||||
void Reshape::shared_buffer_reshape(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
const std::vector<size_t>& out_strides,
|
||||
array& out) {
|
||||
auto flags = in.flags();
|
||||
if (flags.row_contiguous) {
|
||||
@@ -201,7 +194,7 @@ void shared_buffer_reshape(
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
}
|
||||
move_or_copy(in, out, out_strides, flags, in.data_size());
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Split::eval(
|
||||
@@ -256,28 +249,26 @@ void Split::eval(
|
||||
}
|
||||
}
|
||||
|
||||
void Squeeze::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
Strides strides;
|
||||
for (int i = 0, j = 0; i < in.ndim(); ++i) {
|
||||
if (j < axes_.size() && i == axes_[j]) {
|
||||
j++;
|
||||
} else {
|
||||
strides.push_back(in.strides(i));
|
||||
}
|
||||
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices_[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides_[i];
|
||||
}
|
||||
move_or_copy(in, out, strides, in.flags(), in.data_size());
|
||||
|
||||
return std::make_tuple(data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
move_or_copy(inputs[0], out);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
Strides out_strides(out.ndim());
|
||||
std::vector<size_t> out_strides(out.ndim());
|
||||
auto& in = inputs[0];
|
||||
for (int ax = 0; ax < axes_.size(); ++ax) {
|
||||
out_strides[ax] = in.strides()[axes_[ax]];
|
||||
@@ -294,8 +285,8 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
// true, they stay true)
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous && in.data_size() == in.size()) {
|
||||
int64_t f_stride = 1;
|
||||
int64_t b_stride = 1;
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
flags.col_contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
|
||||
@@ -306,7 +297,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
b_stride *= out.shape(ri);
|
||||
}
|
||||
}
|
||||
move_or_copy(in, out, out_strides, flags, in.data_size());
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -165,7 +165,7 @@ void compiled_allocate_outputs(
|
||||
bool move_buffers /* = false */) {
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
Strides strides;
|
||||
std::vector<size_t> strides;
|
||||
size_t data_size;
|
||||
array::Flags flags;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
|
@@ -47,7 +47,7 @@ bool compile_available_for_device(const Device& device) {
|
||||
} // namespace detail
|
||||
|
||||
std::string get_temp_file(const std::string& name) {
|
||||
return std::filesystem::temp_directory_path().append(name).string();
|
||||
return std::filesystem::temp_directory_path().append(name);
|
||||
}
|
||||
|
||||
// Return a pointer to a compiled function
|
||||
@@ -279,7 +279,7 @@ void Compiled::eval_cpu(
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& shape = outputs[0].shape();
|
||||
auto contiguous = compiled_check_contiguity(inputs, shape);
|
||||
bool contiguous = compiled_check_contiguity(inputs, shape);
|
||||
|
||||
// Handle all broadcasting and collect function input arguments
|
||||
std::vector<void*> args;
|
||||
|
@@ -746,9 +746,9 @@ void explicit_gemm_conv_1D_cpu(
|
||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
|
||||
|
||||
// Make strided view
|
||||
Shape strided_shape = {N, oH, wH, C};
|
||||
std::vector<int> strided_shape = {N, oH, wH, C};
|
||||
|
||||
Strides strided_strides = {
|
||||
std::vector<size_t> strided_strides = {
|
||||
in_padded.strides()[0],
|
||||
in_padded.strides()[1] * wt_strides[0],
|
||||
in_padded.strides()[1],
|
||||
@@ -865,9 +865,9 @@ void explicit_gemm_conv_2D_cpu(
|
||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
|
||||
|
||||
// Make strided view
|
||||
Shape strided_shape = {N, oH, oW, wH, wW, C};
|
||||
std::vector<int> strided_shape = {N, oH, oW, wH, wW, C};
|
||||
|
||||
Strides strided_strides = {
|
||||
std::vector<size_t> strided_strides = {
|
||||
in_padded.strides()[0],
|
||||
in_padded.strides()[1] * wt_strides[0],
|
||||
in_padded.strides()[2] * wt_strides[1],
|
||||
@@ -974,7 +974,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
|
||||
|
||||
// Make strided view
|
||||
Shape strided_shape(oDim.size() + wDim.size() + 2);
|
||||
std::vector<int> strided_shape(oDim.size() + wDim.size() + 2);
|
||||
strided_shape.front() = N;
|
||||
for (size_t i = 0; i < oDim.size(); i++) {
|
||||
strided_shape[i + 1] = oDim[i];
|
||||
@@ -984,7 +984,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
}
|
||||
strided_shape.back() = C;
|
||||
|
||||
Strides strided_strides(in.shape().size() * 2 - 2);
|
||||
std::vector<size_t> strided_strides(in.shape().size() * 2 - 2);
|
||||
strided_strides[0] = in_padded.strides()[0];
|
||||
for (size_t i = 0; i < wt_strides.size(); i++) {
|
||||
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
|
||||
@@ -1000,7 +1000,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||
|
||||
// Materialize strided view
|
||||
Shape strided_reshape = {N, C};
|
||||
std::vector<int> strided_reshape = {N, C};
|
||||
for (const auto& o : oDim) {
|
||||
strided_reshape[0] *= o;
|
||||
}
|
||||
|
@@ -26,13 +26,13 @@ void copy_vector(const array& src, array& dst) {
|
||||
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, int D>
|
||||
template <typename SrcT, typename DstT, typename StrideT, int D>
|
||||
inline void copy_dims(
|
||||
const SrcT* src,
|
||||
DstT* dst,
|
||||
const Shape& shape,
|
||||
const Strides& i_strides,
|
||||
const Strides& o_strides,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<StrideT>& i_strides,
|
||||
const std::vector<StrideT>& o_strides,
|
||||
int axis) {
|
||||
auto stride_src = i_strides[axis];
|
||||
auto stride_dst = o_strides[axis];
|
||||
@@ -40,7 +40,7 @@ inline void copy_dims(
|
||||
|
||||
for (int i = 0; i < N; i++) {
|
||||
if constexpr (D > 1) {
|
||||
copy_dims<SrcT, DstT, D - 1>(
|
||||
copy_dims<SrcT, DstT, StrideT, D - 1>(
|
||||
src, dst, shape, i_strides, o_strides, axis + 1);
|
||||
} else {
|
||||
*dst = static_cast<DstT>(*src);
|
||||
@@ -50,13 +50,13 @@ inline void copy_dims(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
template <typename SrcT, typename DstT, typename StrideT>
|
||||
void copy_general_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides& o_strides,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<StrideT>& i_strides,
|
||||
const std::vector<StrideT>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset) {
|
||||
if (data_shape.empty()) {
|
||||
@@ -65,30 +65,30 @@ void copy_general_general(
|
||||
*dst_ptr = val;
|
||||
return;
|
||||
}
|
||||
auto [shape, strides] =
|
||||
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
data_shape, std::vector<std::vector<StrideT>>{i_strides, o_strides});
|
||||
auto src_ptr = src.data<SrcT>() + i_offset;
|
||||
auto dst_ptr = dst.data<DstT>() + o_offset;
|
||||
int ndim = shape.size();
|
||||
if (ndim == 1) {
|
||||
copy_dims<SrcT, DstT, 1>(
|
||||
copy_dims<SrcT, DstT, StrideT, 1>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
return;
|
||||
} else if (ndim == 2) {
|
||||
copy_dims<SrcT, DstT, 2>(
|
||||
copy_dims<SrcT, DstT, StrideT, 2>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
return;
|
||||
} else if (ndim == 3) {
|
||||
copy_dims<SrcT, DstT, 3>(
|
||||
copy_dims<SrcT, DstT, StrideT, 3>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
return;
|
||||
}
|
||||
ContiguousIterator in(shape, strides[0], ndim - 3);
|
||||
ContiguousIterator out(shape, strides[1], ndim - 3);
|
||||
auto stride = std::accumulate(
|
||||
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
|
||||
for (int64_t elem = 0; elem < src.size(); elem += stride) {
|
||||
copy_dims<SrcT, DstT, 3>(
|
||||
ContiguousIterator<StrideT> in(shape, strides[0], ndim - 3);
|
||||
ContiguousIterator<StrideT> out(shape, strides[1], ndim - 3);
|
||||
StrideT stride = std::accumulate(
|
||||
shape.end() - 3, shape.end(), 1, std::multiplies<StrideT>());
|
||||
for (StrideT elem = 0; elem < src.size(); elem += stride) {
|
||||
copy_dims<SrcT, DstT, StrideT, 3>(
|
||||
src_ptr + in.loc,
|
||||
dst_ptr + out.loc,
|
||||
shape,
|
||||
@@ -102,37 +102,37 @@ void copy_general_general(
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general_general(const array& src, array& dst) {
|
||||
copy_general_general<SrcT, DstT>(
|
||||
copy_general_general<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
template <typename SrcT, typename DstT, typename StrideT>
|
||||
void copy_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides&,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<StrideT>& i_strides,
|
||||
const std::vector<StrideT>&,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset) {
|
||||
copy_general_general<SrcT, DstT>(
|
||||
copy_general_general<SrcT, DstT, StrideT>(
|
||||
src,
|
||||
dst,
|
||||
data_shape,
|
||||
i_strides,
|
||||
make_contiguous_strides(data_shape),
|
||||
make_contiguous_strides<StrideT>(data_shape),
|
||||
i_offset,
|
||||
o_offset);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general(const array& src, array& dst) {
|
||||
copy_general_general<SrcT, DstT>(
|
||||
copy_general_general<SrcT, DstT, size_t>(
|
||||
src,
|
||||
dst,
|
||||
src.shape(),
|
||||
src.strides(),
|
||||
make_contiguous_strides(src.shape()),
|
||||
make_contiguous_strides<size_t>(src.shape()),
|
||||
0,
|
||||
0);
|
||||
}
|
||||
@@ -282,12 +282,13 @@ void copy(const array& src, array& dst, CopyType ctype) {
|
||||
copy_inplace(src, dst, ctype);
|
||||
}
|
||||
|
||||
template <typename StrideT>
|
||||
void copy_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides& o_strides,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<StrideT>& i_strides,
|
||||
const std::vector<StrideT>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype) {
|
||||
@@ -310,4 +311,24 @@ void copy_inplace(
|
||||
}
|
||||
}
|
||||
|
||||
template void copy_inplace<size_t>(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<size_t>& i_strides,
|
||||
const std::vector<size_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype);
|
||||
|
||||
template void copy_inplace<int64_t>(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<int64_t>& i_strides,
|
||||
const std::vector<int64_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -26,12 +26,13 @@ enum class CopyType {
|
||||
void copy(const array& src, array& dst, CopyType ctype);
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype);
|
||||
|
||||
template <typename stride_t>
|
||||
void copy_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides& o_strides,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype);
|
||||
|
@@ -57,7 +57,6 @@ DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(Exp)
|
||||
DEFAULT(ExpandDims)
|
||||
DEFAULT(Expm1)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
@@ -87,6 +86,7 @@ DEFAULT_MULTI(QRF)
|
||||
DEFAULT(QuantizedMatmul)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reduce)
|
||||
DEFAULT(Reshape)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scan)
|
||||
DEFAULT(Scatter)
|
||||
@@ -101,7 +101,6 @@ DEFAULT(Softmax)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Square)
|
||||
DEFAULT(Squeeze)
|
||||
DEFAULT(Sqrt)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT(Subtract)
|
||||
@@ -131,7 +130,7 @@ inline void matmul_common_general(
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
stx = arr.shape(-1);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
|
@@ -32,7 +32,7 @@ void gather(
|
||||
const std::vector<array>& inds,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
const Shape& slice_sizes) {
|
||||
const std::vector<int>& slice_sizes) {
|
||||
// If the array is row contiguous then we can do a contiguous copy given
|
||||
// two conditions on the slice size:
|
||||
// - Any number of leading ones in the slice sizes are allowed
|
||||
@@ -80,10 +80,11 @@ void gather(
|
||||
T* dst_ptr = out.data<T>();
|
||||
size_t out_idx = 0;
|
||||
|
||||
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
|
||||
ContiguousIterator src_it;
|
||||
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
|
||||
ContiguousIterator<size_t> src_it;
|
||||
if (!can_copy && src.ndim() > 0) {
|
||||
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
|
||||
src_it = std::move(
|
||||
ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
|
||||
}
|
||||
for (int idx = 0; idx < ind_size; idx++) {
|
||||
size_t src_idx = 0;
|
||||
@@ -118,7 +119,7 @@ void dispatch_gather(
|
||||
const std::vector<array>& inds,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
const Shape& size) {
|
||||
const std::vector<int>& size) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
gather<bool, IdxT>(src, inds, out, axes, size);
|
||||
@@ -222,16 +223,16 @@ void scatter(
|
||||
auto inds_ndim = updates.ndim() - out.ndim();
|
||||
size_t n_updates = nind ? inds[0].size() : 1;
|
||||
|
||||
Shape update_shape(
|
||||
std::vector<int> update_shape(
|
||||
updates.shape().begin() + inds_ndim, updates.shape().end());
|
||||
size_t update_size = 1;
|
||||
for (auto us : update_shape) {
|
||||
update_size *= us;
|
||||
}
|
||||
|
||||
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
|
||||
ContiguousIterator update_it(updates);
|
||||
ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
|
||||
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
|
||||
ContiguousIterator<size_t> update_it(updates);
|
||||
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
|
||||
|
||||
for (int i = 0; i < n_updates; ++i) {
|
||||
size_t out_offset = 0;
|
||||
|
@@ -2,15 +2,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
// Required for Visual Studio.
|
||||
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
|
||||
#ifdef _MSC_VER
|
||||
#include <complex>
|
||||
#define LAPACK_COMPLEX_CUSTOM
|
||||
#define lapack_complex_float std::complex<float>
|
||||
#define lapack_complex_double std::complex<double>
|
||||
#endif
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
|
@@ -1,38 +0,0 @@
|
||||
# This script generates a C++ function that provides the CPU
|
||||
# code for use with kernel generation.
|
||||
#
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
$OUTPUT_FILE = $args[0]
|
||||
$CL = $args[1]
|
||||
$SRCDIR = $args[2]
|
||||
|
||||
# Get command result as array.
|
||||
$CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/compiled_preamble.h"
|
||||
# Remove empty lines.
|
||||
# Otherwise there will be too much empty lines making the result unreadable.
|
||||
$CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }
|
||||
# Concatenate to string.
|
||||
$CONTENT = $CONTENT -join '`n'
|
||||
|
||||
# Append extra content.
|
||||
$CONTENT = @"
|
||||
$($CONTENT)
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::detail;
|
||||
"@
|
||||
|
||||
# Convert each char to ASCII code.
|
||||
# Unlike the unix script that outputs string literal directly, the output from
|
||||
# MSVC is way too large to be embedded as string and compilation will fail, so
|
||||
# we store it as static array instead.
|
||||
$CHARCODES = ([System.Text.Encoding]::ASCII.GetBytes($CONTENT) -join ', ') + ', 0'
|
||||
|
||||
$OUTPUT = @"
|
||||
const char* get_kernel_preamble() {
|
||||
static char preamble[] = { $CHARCODES };
|
||||
return preamble;
|
||||
}
|
||||
"@
|
||||
|
||||
Set-Content -Path $OUTPUT_FILE -Value $OUTPUT
|
@@ -10,16 +10,15 @@ OUTPUT_FILE=$1
|
||||
GCC=$2
|
||||
SRCDIR=$3
|
||||
CLANG=$4
|
||||
ARCH=$5
|
||||
|
||||
if [ "$CLANG" = "TRUE" ]; then
|
||||
read -r -d '' INCLUDES <<- EOM
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
EOM
|
||||
CC_FLAGS="-arch ${ARCH}"
|
||||
CC_FLAGS=""
|
||||
else
|
||||
CC_FLAGS="-std=c++17"
|
||||
fi
|
||||
|
@@ -19,10 +19,10 @@ inline void mask_matrix(
|
||||
int block_size,
|
||||
const int X,
|
||||
const int Y,
|
||||
const int64_t X_data_str,
|
||||
const int64_t Y_data_str,
|
||||
const int64_t X_mask_str,
|
||||
const int64_t Y_mask_str,
|
||||
const size_t X_data_str,
|
||||
const size_t Y_data_str,
|
||||
const size_t X_mask_str,
|
||||
const size_t Y_mask_str,
|
||||
const size_t mask_offset) {
|
||||
int tX = (X + block_size - 1) / block_size;
|
||||
int tY = (Y + block_size - 1) / block_size;
|
||||
@@ -84,7 +84,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
int64_t stx = arr.shape(-1);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
@@ -117,13 +117,13 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
int Y,
|
||||
size_t X_data_str,
|
||||
size_t Y_data_str) {
|
||||
auto mask_offset = elem_to_loc(
|
||||
size_t mask_offset = elem_to_loc(
|
||||
mask.shape(-1) * mask.shape(-2) * batch_idx,
|
||||
mask.shape(),
|
||||
mask.strides());
|
||||
|
||||
auto X_mask_str = mask.strides()[mask.ndim() - 2];
|
||||
auto Y_mask_str = mask.strides()[mask.ndim() - 1];
|
||||
size_t X_mask_str = mask.strides()[mask.ndim() - 2];
|
||||
size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
|
||||
|
||||
if (mask.dtype() == bool_) {
|
||||
return mask_matrix(
|
||||
@@ -230,7 +230,7 @@ void GatherMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
int64_t stx = arr.shape(-1);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
@@ -262,13 +262,13 @@ void GatherMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
auto& lhs_indices = inputs[2];
|
||||
auto& rhs_indices = inputs[3];
|
||||
|
||||
auto batch_shape = get_batch_dims(out.shape());
|
||||
std::vector<int> batch_shape = get_batch_dims(out.shape());
|
||||
int batch_ndim = batch_shape.size();
|
||||
|
||||
auto batch_shape_A = get_batch_dims(a.shape());
|
||||
auto batch_strides_A = get_batch_dims(a.strides());
|
||||
auto batch_shape_B = get_batch_dims(b.shape());
|
||||
auto batch_strides_B = get_batch_dims(b.strides());
|
||||
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
|
||||
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
|
||||
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
|
||||
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
|
||||
|
||||
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
|
||||
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
|
||||
|
@@ -500,12 +500,7 @@ struct Equal {
|
||||
struct NaNEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
// isnan always returns false for integers, and MSVC refuses to compile.
|
||||
return x == y;
|
||||
} else {
|
||||
return x == y || (std::isnan(x) && std::isnan(y));
|
||||
}
|
||||
return x == y || (std::isnan(x) && std::isnan(y));
|
||||
}
|
||||
};
|
||||
|
||||
|
@@ -19,16 +19,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void reshape(const array& in, array& out) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
copy_inplace(in, out, CopyType::General);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Abs::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@@ -169,17 +159,6 @@ void Conjugate::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy(in, out, CopyType::General);
|
||||
}
|
||||
}
|
||||
|
||||
void Cos::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
@@ -268,14 +247,6 @@ void Expm1::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Flatten::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
}
|
||||
|
||||
void Unflatten::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
}
|
||||
|
||||
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@@ -435,8 +406,18 @@ void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
|
||||
}
|
||||
|
||||
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
copy_inplace(in, out, CopyType::General);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Round::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -506,17 +487,34 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < end_indices_.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
auto [copy_needed, data_offset, inp_strides] =
|
||||
prepare_slice(in, start_indices_, strides_);
|
||||
|
||||
// Do copy if needed
|
||||
if (copy_needed) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
|
||||
copy_inplace<int64_t>(
|
||||
/* const array& src = */ in,
|
||||
/* array& dst = */ out,
|
||||
/* const std::vector<int>& data_shape = */ out.shape(),
|
||||
/* const std::vector<stride_t>& i_strides = */ inp_strides,
|
||||
/* const std::vector<stride_t>& o_strides = */ ostrides,
|
||||
/* int64_t i_offset = */ data_offset,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::General);
|
||||
} else {
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < end_indices_.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
|
||||
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
Strides ostrides{inp_strides.begin(), inp_strides.end()};
|
||||
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
|
||||
}
|
||||
|
||||
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -541,11 +539,11 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
|
||||
auto [data_offset, out_strides] = prepare_slice(out);
|
||||
|
||||
// Do copy
|
||||
Strides upd_strides{upd.strides().begin(), upd.strides().end()};
|
||||
copy_inplace(
|
||||
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
|
||||
copy_inplace<int64_t>(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
/* const std::vector<int>& data_shape = */ upd.shape(),
|
||||
@@ -608,7 +606,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
|
||||
in.flags().row_contiguous) {
|
||||
auto strides = in.strides();
|
||||
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
|
||||
for (int i = 0; i < strides.size() - 1; ++i) {
|
||||
strides[i] *= ibytes;
|
||||
strides[i] /= obytes;
|
||||
}
|
||||
|
@@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r) {
|
||||
// Copy the input to be column contiguous
|
||||
flags.col_contiguous = num_matrices == 1;
|
||||
flags.row_contiguous = false;
|
||||
auto strides = in.strides();
|
||||
std::vector<size_t> strides = in.strides();
|
||||
strides[in.ndim() - 2] = 1;
|
||||
strides[in.ndim() - 1] = M;
|
||||
in.set_data(
|
||||
|
@@ -2,38 +2,13 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, int bits>
|
||||
void extract_bits(const uint8_t* w_in, T* w_out) {
|
||||
assert(bits == 3 || bits == 6);
|
||||
if (bits == 3) {
|
||||
w_out[0] = static_cast<T>(w_in[0] & 0x7);
|
||||
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
|
||||
w_out[2] = static_cast<T>(((w_in[0] & 0xc0) >> 6) + ((w_in[1] & 0x1) << 2));
|
||||
w_out[3] = static_cast<T>((w_in[1] & 0xe) >> 1);
|
||||
w_out[4] = static_cast<T>((w_in[1] & 0x70) >> 4);
|
||||
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
|
||||
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
|
||||
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
|
||||
} else if (bits == 6) {
|
||||
w_out[0] = static_cast<T>(w_in[0] & 0x3f);
|
||||
w_out[1] =
|
||||
static_cast<T>(((w_in[0] >> 6) & 0x03) + ((w_in[1] & 0x0f) << 2));
|
||||
w_out[2] =
|
||||
static_cast<T>(((w_in[1] >> 4) & 0x0f) + ((w_in[2] & 0x03) << 4));
|
||||
w_out[3] = static_cast<T>((w_in[2] >> 2) & 0x3f);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int bits, int group_size>
|
||||
void _qmm(
|
||||
T* result,
|
||||
@@ -45,12 +20,13 @@ void _qmm(
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
||||
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
const int Ng = N / group_size;
|
||||
const int Nw = N / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint8_t* w_local = (const uint8_t*)w;
|
||||
const uint32_t* w_local = w;
|
||||
const T* scales_local = scales;
|
||||
const T* biases_local = biases;
|
||||
|
||||
@@ -64,25 +40,13 @@ void _qmm(
|
||||
T scale = *scales_local++;
|
||||
T bias = *biases_local++;
|
||||
for (int ng = 0; ng < packs_in_group; ng++) {
|
||||
if (bits == 3 || bits == 6) {
|
||||
T wl[pack_factor];
|
||||
extract_bits<T, bits>(w_local, wl);
|
||||
#pragma clang loop unroll(full)
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
(*result_local++) += xi * (scale * wl[p] + bias);
|
||||
}
|
||||
w_local += bytes_per_pack;
|
||||
uint32_t wi = *w_local++;
|
||||
|
||||
} else {
|
||||
uint8_t wi = *w_local++;
|
||||
#pragma clang loop unroll(full)
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
(*result_local++) +=
|
||||
xi * (scale * static_cast<T>(wi & bitmask) + bias);
|
||||
if (bits != 8) {
|
||||
wi >>= bits;
|
||||
}
|
||||
}
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
(*result_local++) +=
|
||||
xi * (scale * static_cast<T>(wi & bitmask) + bias);
|
||||
wi >>= bits;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -103,12 +67,13 @@ void _qmm_t(
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
||||
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
const int Kg = K / group_size;
|
||||
const int Kw = K / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint8_t* w_local = (const uint8_t*)w;
|
||||
const uint32_t* w_local = w;
|
||||
const T* scales_local = scales;
|
||||
const T* biases_local = biases;
|
||||
|
||||
@@ -120,26 +85,12 @@ void _qmm_t(
|
||||
T bias = *biases_local++;
|
||||
|
||||
for (int kw = 0; kw < packs_in_group; kw++) {
|
||||
if (bits == 3 || bits == 6) {
|
||||
T wl[pack_factor];
|
||||
extract_bits<T, bits>(w_local, wl);
|
||||
#pragma clang loop unroll(full)
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
sum += x_local[p] * (scale * wl[p] + bias);
|
||||
}
|
||||
w_local += bytes_per_pack;
|
||||
x_local += pack_factor;
|
||||
uint32_t wi = *w_local++;
|
||||
|
||||
} else {
|
||||
uint8_t wi = *w_local++;
|
||||
#pragma clang loop unroll(full)
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
sum +=
|
||||
(*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
|
||||
if (bits != 8) {
|
||||
wi >>= bits;
|
||||
}
|
||||
}
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
|
||||
wi >>= bits;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -151,55 +102,6 @@ void _qmm_t(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int bits, int group_size>
|
||||
void _qmm_dispatch_transpose(
|
||||
T* result,
|
||||
const T* x,
|
||||
const uint32_t* w,
|
||||
const T* scales,
|
||||
const T* biases,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
bool transposed_w) {
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int bits>
|
||||
void _qmm_dispatch_group(
|
||||
T* result,
|
||||
const T* x,
|
||||
const uint32_t* w,
|
||||
const T* scales,
|
||||
const T* biases,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int group_size,
|
||||
bool transposed_w) {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
_qmm_dispatch_transpose<T, bits, 32>(
|
||||
result, x, w, scales, biases, M, N, K, transposed_w);
|
||||
break;
|
||||
case 64:
|
||||
_qmm_dispatch_transpose<T, bits, 64>(
|
||||
result, x, w, scales, biases, M, N, K, transposed_w);
|
||||
break;
|
||||
case 128:
|
||||
_qmm_dispatch_transpose<T, bits, 128>(
|
||||
result, x, w, scales, biases, M, N, K, transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"Quantization group size must be 32, 64 or 128.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void _qmm_dispatch_typed(
|
||||
T* result,
|
||||
@@ -214,29 +116,79 @@ void _qmm_dispatch_typed(
|
||||
int bits,
|
||||
bool transposed_w) {
|
||||
switch (bits) {
|
||||
case 2:
|
||||
_qmm_dispatch_group<T, 2>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
break;
|
||||
case 3:
|
||||
_qmm_dispatch_group<T, 3>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
break;
|
||||
case 4:
|
||||
_qmm_dispatch_group<T, 4>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
break;
|
||||
case 6:
|
||||
_qmm_dispatch_group<T, 6>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
break;
|
||||
case 8:
|
||||
_qmm_dispatch_group<T, 8>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8.");
|
||||
case 2: {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 128:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
}
|
||||
}
|
||||
case 4: {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 128:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
}
|
||||
}
|
||||
case 8: {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 128:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
std::ostringstream msg;
|
||||
msg << "Quantization type not supported. Provided bits=" << bits
|
||||
<< " and group_size=" << group_size
|
||||
<< ". The supported options are bits in "
|
||||
<< "{2, 4, 8} and group_size in {64, 128}.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
void _qmm_dispatch(
|
||||
@@ -452,114 +404,4 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
transpose_);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
void quantize(
|
||||
const array& w_,
|
||||
array& out_,
|
||||
array& scales_,
|
||||
array& biases_,
|
||||
int bits,
|
||||
int group_size) {
|
||||
const T* w = w_.data<T>();
|
||||
|
||||
auto out = out_.data<U>();
|
||||
T* scales = scales_.data<T>();
|
||||
T* biases = biases_.data<T>();
|
||||
|
||||
T n_bins = (1 << bits) - 1;
|
||||
T eps = 1e-7;
|
||||
bool power_of_2_bits = is_power_of_2(bits);
|
||||
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
|
||||
int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
||||
int int_per_group = group_size * bytes_per_pack / el_per_int;
|
||||
size_t n_groups = w_.size() / group_size;
|
||||
|
||||
for (size_t i = 0; i < n_groups; ++i) {
|
||||
size_t w_idx = i * group_size;
|
||||
T w_min = std::numeric_limits<float>::infinity();
|
||||
T w_max = -w_min;
|
||||
for (int j = 0; j < group_size; ++j) {
|
||||
w_max = std::max(w_max, w[w_idx + j]);
|
||||
w_min = std::min(w_min, w[w_idx + j]);
|
||||
}
|
||||
bool mask = std::abs(w_min) > std::abs(w_max);
|
||||
T scale = std::max(T((w_max - w_min) / n_bins), eps);
|
||||
scale = mask ? scale : -scale;
|
||||
|
||||
auto edge = mask ? w_min : w_max;
|
||||
auto q0 = std::rint(edge / scale);
|
||||
if (q0 == 0) {
|
||||
scales[i] = scale;
|
||||
biases[i] = 0;
|
||||
} else {
|
||||
scales[i] = edge / q0;
|
||||
biases[i] = edge;
|
||||
}
|
||||
size_t out_idx = i * int_per_group;
|
||||
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
|
||||
uint32_t out_el = 0;
|
||||
for (int k = 0; k < el_per_int; ++k) {
|
||||
T w_el = w[w_idx + j * el_per_int + k];
|
||||
w_el = std::rint((w_el - biases[i]) / scales[i]);
|
||||
w_el = std::min(std::max(w_el, T(0)), n_bins);
|
||||
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
|
||||
}
|
||||
if (power_of_2_bits) {
|
||||
out[out_idx + j] = out_el;
|
||||
} else {
|
||||
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
|
||||
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
|
||||
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void fast::AffineQuantize::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto ensure_row_contiguous = [](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
return arr_copy;
|
||||
}
|
||||
};
|
||||
auto w = ensure_row_contiguous(inputs[0]);
|
||||
|
||||
auto& out = outputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& scales = outputs[1];
|
||||
auto& biases = outputs[2];
|
||||
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
|
||||
if (w.dtype() == float16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else if (w.dtype() == bfloat16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
quantize<bfloat16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else if (w.dtype() == float32) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -120,53 +120,45 @@ struct MinReduce {
|
||||
};
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_and_or(
|
||||
void reduce_dispatch_out(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::And) {
|
||||
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
|
||||
} else {
|
||||
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_sum_prod(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Sum) {
|
||||
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 0, op);
|
||||
switch (rtype) {
|
||||
case Reduce::And: {
|
||||
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
auto op = [](auto y, auto x) { (*y) *= x; };
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t>(in, out, axes, 1, op);
|
||||
} else {
|
||||
case Reduce::Or: {
|
||||
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
|
||||
break;
|
||||
}
|
||||
case Reduce::Sum: {
|
||||
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
|
||||
if (out.dtype() == int32) {
|
||||
// special case since the input type can be bool
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 0, op);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Reduce::Prod: {
|
||||
auto op = [](auto y, auto x) { (*y) *= x; };
|
||||
reduction_op<InT, InT>(in, out, axes, 1, op);
|
||||
break;
|
||||
}
|
||||
case Reduce::Max: {
|
||||
auto init = Limits<InT>::min;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
|
||||
break;
|
||||
}
|
||||
case Reduce::Min: {
|
||||
auto init = Limits<InT>::max;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_min_max(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Max) {
|
||||
auto init = Limits<InT>::min;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
|
||||
} else {
|
||||
auto init = Limits<InT>::max;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,19 +166,19 @@ void reduce_dispatch_min_max(
|
||||
|
||||
void nd_loop(
|
||||
std::function<void(int)> callback,
|
||||
const Shape& shape,
|
||||
const Strides& strides) {
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides) {
|
||||
std::function<void(int, int)> loop_inner;
|
||||
loop_inner = [&](int dim, int offset) {
|
||||
if (dim < shape.size() - 1) {
|
||||
auto size = shape[dim];
|
||||
auto stride = strides[dim];
|
||||
int size = shape[dim];
|
||||
size_t stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
loop_inner(dim + 1, offset + i * stride);
|
||||
}
|
||||
} else {
|
||||
auto size = shape[dim];
|
||||
auto stride = strides[dim];
|
||||
int size = shape[dim];
|
||||
size_t stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
callback(offset + i * stride);
|
||||
}
|
||||
@@ -198,114 +190,46 @@ void nd_loop(
|
||||
void Reduce::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
switch (reduce_type_) {
|
||||
case Reduce::And:
|
||||
case Reduce::Or: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
case float16:
|
||||
case bfloat16:
|
||||
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
case int32:
|
||||
case float32:
|
||||
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
case int64:
|
||||
case complex64:
|
||||
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
reduce_dispatch_out<bool>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
case Reduce::Sum:
|
||||
case Reduce::Prod: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
case uint8:
|
||||
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
case Reduce::Max:
|
||||
case Reduce::Min: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
case uint16:
|
||||
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_out<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_out<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
reduce_dispatch_out<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
reduce_dispatch_out<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_out<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_out<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_out<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_out<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -38,10 +38,13 @@ enum ReductionOpType {
|
||||
|
||||
struct ReductionPlan {
|
||||
ReductionOpType type;
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
|
||||
ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)
|
||||
ReductionPlan(
|
||||
ReductionOpType type_,
|
||||
std::vector<int> shape_,
|
||||
std::vector<size_t> strides_)
|
||||
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
|
||||
ReductionPlan(ReductionOpType type_) : type(type_) {}
|
||||
};
|
||||
@@ -52,10 +55,10 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
|
||||
// Should this be in utils?
|
||||
void nd_loop(
|
||||
std::function<void(int)> callback,
|
||||
const Shape& shape,
|
||||
const Strides& strides);
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides);
|
||||
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes);
|
||||
|
||||
@@ -110,6 +113,9 @@ void reduction_op(
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
|
||||
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
|
||||
int reduction_size = plan.shape[0];
|
||||
const T* x_ptr = x.data<T>();
|
||||
@@ -129,7 +135,7 @@ void reduction_op(
|
||||
U* out_ptr = out.data<U>();
|
||||
// Unrolling the following loop (and implementing it in order for
|
||||
// ContiguousReduce) should hold extra performance boost.
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
@@ -175,7 +181,7 @@ void reduction_op(
|
||||
plan.strides.pop_back();
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
@@ -205,7 +211,7 @@ void reduction_op(
|
||||
if (plan.type == GeneralReduce) {
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
U val = init;
|
||||
|
@@ -4,11 +4,11 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes) {
|
||||
auto shape = x.shape();
|
||||
auto strides = x.strides();
|
||||
std::vector<int> shape = x.shape();
|
||||
std::vector<size_t> strides = x.strides();
|
||||
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int a = axes[i];
|
||||
@@ -29,8 +29,8 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
// Row contiguous input so the output is row contiguous
|
||||
if (x.flags().row_contiguous) {
|
||||
// Merge consecutive axes
|
||||
Shape shape = {x.shape(axes[0])};
|
||||
Strides strides = {x.strides()[axes[0]]};
|
||||
std::vector<int> shape = {x.shape(axes[0])};
|
||||
std::vector<size_t> strides = {x.strides()[axes[0]]};
|
||||
for (int i = 1; i < axes.size(); i++) {
|
||||
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
|
||||
shape.back() *= x.shape(axes[i]);
|
||||
@@ -69,7 +69,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
|
||||
// Sort reduction axes by stride in order to merge them and figure out if we
|
||||
// have a contiguous reduction.
|
||||
std::vector<std::pair<int, int64_t>> reductions;
|
||||
std::vector<std::pair<int, size_t>> reductions;
|
||||
for (auto a : axes) {
|
||||
if (x.shape(a) > 1) {
|
||||
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
|
||||
@@ -93,8 +93,8 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
}
|
||||
}
|
||||
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
for (auto r : reductions) {
|
||||
shape.push_back(r.first);
|
||||
strides.push_back(r.second);
|
||||
@@ -109,15 +109,15 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
// Delegate to the general strided reduction op if the axes after
|
||||
// strides.back() are contiguous.
|
||||
if (strides.back() > 1) {
|
||||
int64_t size = 1;
|
||||
int size = 1;
|
||||
bool have_expand = false;
|
||||
for (int i = x.ndim() - 1; i >= 0; i--) {
|
||||
if (axes.back() == i) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto stride_i = x.strides()[i];
|
||||
auto shape_i = x.shape(i);
|
||||
size_t stride_i = x.strides()[i];
|
||||
int shape_i = x.shape(i);
|
||||
if (stride_i == 0) {
|
||||
if (shape_i == 1) {
|
||||
continue;
|
||||
|
@@ -4,22 +4,24 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::tuple<int64_t, Strides> prepare_slice(
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
|
||||
const array& in,
|
||||
const Shape& start_indices,
|
||||
const Shape& strides) {
|
||||
const std::vector<int>& start_indices,
|
||||
const std::vector<int>& strides) {
|
||||
int64_t data_offset = 0;
|
||||
Strides inp_strides(in.ndim(), 0);
|
||||
bool copy_needed = false;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides[i];
|
||||
copy_needed |= strides[i] < 0;
|
||||
}
|
||||
return std::make_tuple(data_offset, inp_strides);
|
||||
return std::make_tuple(copy_needed, data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void shared_buffer_slice(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
const std::vector<size_t>& out_strides,
|
||||
size_t data_offset,
|
||||
size_t data_size,
|
||||
array& out) {
|
||||
@@ -32,7 +34,7 @@ void shared_buffer_slice(
|
||||
flags.col_contiguous = is_col_contiguous;
|
||||
flags.contiguous = (no_bsx_size == data_size);
|
||||
|
||||
move_or_copy(in, out, out_strides, flags, data_size, data_offset);
|
||||
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -6,14 +6,14 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::tuple<int64_t, Strides> prepare_slice(
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
|
||||
const array& in,
|
||||
const Shape& start_indices,
|
||||
const Shape& strides);
|
||||
const std::vector<int>& start_indices,
|
||||
const std::vector<int>& strides);
|
||||
|
||||
void shared_buffer_slice(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
const std::vector<size_t>& out_strides,
|
||||
size_t data_offset,
|
||||
size_t data_size,
|
||||
array& out);
|
||||
|
@@ -25,7 +25,7 @@ struct StridedIterator {
|
||||
// Constructors
|
||||
StridedIterator() = default;
|
||||
|
||||
explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
|
||||
explicit StridedIterator(T* ptr, size_t stride, difference_type offset = 0)
|
||||
: ptr_(ptr + offset * stride), stride_(stride) {}
|
||||
|
||||
explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
|
||||
@@ -99,7 +99,7 @@ struct StridedIterator {
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t stride_;
|
||||
size_t stride_;
|
||||
T* ptr_;
|
||||
};
|
||||
|
||||
@@ -120,11 +120,11 @@ void sort(const array& in, array& out, int axis) {
|
||||
auto remaining_strides = out.strides();
|
||||
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||
|
||||
auto axis_stride = out.strides()[axis];
|
||||
auto axis_size = out.shape(axis);
|
||||
size_t axis_stride = out.strides()[axis];
|
||||
int axis_size = out.shape(axis);
|
||||
|
||||
// Perform sorting in place
|
||||
ContiguousIterator src_it(
|
||||
ContiguousIterator<size_t> src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out.data<T>() + src_it.loc;
|
||||
@@ -158,14 +158,14 @@ void argsort(const array& in, array& out, int axis) {
|
||||
auto out_remaining_strides = out.strides();
|
||||
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
|
||||
|
||||
auto in_stride = in.strides()[axis];
|
||||
auto out_stride = out.strides()[axis];
|
||||
auto axis_size = in.shape(axis);
|
||||
size_t in_stride = in.strides()[axis];
|
||||
size_t out_stride = out.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
|
||||
// Perform sorting
|
||||
ContiguousIterator in_it(
|
||||
ContiguousIterator<size_t> in_it(
|
||||
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
|
||||
ContiguousIterator out_it(
|
||||
ContiguousIterator<size_t> out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in.data<T>() + in_it.loc;
|
||||
@@ -208,13 +208,13 @@ void partition(const array& in, array& out, int axis, int kth) {
|
||||
auto remaining_strides = in.strides();
|
||||
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||
|
||||
auto axis_stride = in.strides()[axis];
|
||||
size_t axis_stride = in.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
|
||||
kth = kth < 0 ? kth + axis_size : kth;
|
||||
|
||||
// Perform partition in place
|
||||
ContiguousIterator src_it(
|
||||
ContiguousIterator<size_t> src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out.data<T>() + src_it.loc;
|
||||
@@ -249,16 +249,16 @@ void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
auto out_remaining_strides = out.strides();
|
||||
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
|
||||
|
||||
auto in_stride = in.strides()[axis];
|
||||
auto out_stride = out.strides()[axis];
|
||||
auto axis_size = in.shape(axis);
|
||||
size_t in_stride = in.strides()[axis];
|
||||
size_t out_stride = out.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
|
||||
kth = kth < 0 ? kth + axis_size : kth;
|
||||
|
||||
// Perform partition
|
||||
ContiguousIterator in_it(
|
||||
ContiguousIterator<size_t> in_it(
|
||||
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
|
||||
ContiguousIterator out_it(
|
||||
ContiguousIterator<size_t> out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in.data<T>() + in_it.loc;
|
||||
|
@@ -78,11 +78,11 @@ void ternary_op_dims(
|
||||
const T3* c,
|
||||
U* out,
|
||||
Op op,
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
const Strides& c_strides,
|
||||
const Strides& out_strides,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& a_strides,
|
||||
const std::vector<size_t>& b_strides,
|
||||
const std::vector<size_t>& c_strides,
|
||||
const std::vector<size_t>& out_strides,
|
||||
int axis) {
|
||||
auto stride_a = a_strides[axis];
|
||||
auto stride_b = b_strides[axis];
|
||||
@@ -164,10 +164,10 @@ void ternary_op_dispatch_dims(
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator b_it(shape, b_strides, ndim - 2);
|
||||
ContiguousIterator c_it(shape, c_strides, ndim - 2);
|
||||
auto stride = out_strides[ndim - 3];
|
||||
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
|
||||
ContiguousIterator<size_t> c_it(shape, c_strides, ndim - 2);
|
||||
size_t stride = out_strides[ndim - 3];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
ternary_op_dims<T1, T2, T3, U, Op, 2>(
|
||||
a_ptr + a_it.loc,
|
||||
|
@@ -4,35 +4,15 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void move_or_copy(const array& in, array& out) {
|
||||
if (in.is_donatable()) {
|
||||
out.move_shared_buffer(in);
|
||||
} else {
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void move_or_copy(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Strides& strides,
|
||||
array::Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
if (in.is_donatable()) {
|
||||
out.move_shared_buffer(in, strides, flags, data_size, offset);
|
||||
} else {
|
||||
out.copy_shared_buffer(in, strides, flags, data_size, offset);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||
const Shape& shape,
|
||||
const std::vector<Strides>& strides,
|
||||
int64_t size_cap) {
|
||||
template <typename StrideT>
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>>
|
||||
collapse_contiguous_dims_impl(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<std::vector<StrideT>>& strides,
|
||||
StrideT size_cap) {
|
||||
// Make a vector that has axes separated with -1. Collapse all axes between
|
||||
// -1.
|
||||
Shape to_collapse;
|
||||
std::vector<int> to_collapse;
|
||||
if (shape.size() > 0) {
|
||||
if (shape[0] != 1) {
|
||||
to_collapse.push_back(0);
|
||||
@@ -41,7 +21,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||
for (int i = 1; i < shape.size(); i++) {
|
||||
bool contiguous = true;
|
||||
size *= shape[i];
|
||||
for (const auto& st : strides) {
|
||||
for (const std::vector<StrideT>& st : strides) {
|
||||
if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
|
||||
contiguous = false;
|
||||
size = shape[i];
|
||||
@@ -58,8 +38,8 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||
to_collapse.push_back(-1);
|
||||
}
|
||||
|
||||
Shape out_shape;
|
||||
std::vector<Strides> out_strides(strides.size());
|
||||
std::vector<int> out_shape;
|
||||
std::vector<std::vector<StrideT>> out_strides(strides.size());
|
||||
for (int i = 0;;) {
|
||||
while (i < to_collapse.size() && to_collapse[i] == -1) {
|
||||
++i;
|
||||
@@ -74,7 +54,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||
}
|
||||
out_shape.push_back(current_shape);
|
||||
for (int j = 0; j < strides.size(); j++) {
|
||||
const auto& st = strides[j];
|
||||
const std::vector<StrideT>& st = strides[j];
|
||||
out_strides[j].push_back(st[to_collapse[k - 1]]);
|
||||
}
|
||||
i = k + 1;
|
||||
@@ -89,12 +69,29 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||
return std::make_tuple(out_shape, out_strides);
|
||||
}
|
||||
|
||||
std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
int64_t size_cap) {
|
||||
Shape collapsed_shape;
|
||||
Strides collapsed_strides;
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
|
||||
collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<std::vector<int64_t>>& strides,
|
||||
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
|
||||
return collapse_contiguous_dims_impl(shape, strides, size_cap);
|
||||
}
|
||||
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<std::vector<size_t>>& strides,
|
||||
size_t size_cap /* = std::numeric_limits<int32>::max() */) {
|
||||
return collapse_contiguous_dims_impl(shape, strides, size_cap);
|
||||
}
|
||||
|
||||
template <typename StrideT>
|
||||
std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<StrideT>& strides,
|
||||
StrideT size_cap) {
|
||||
std::vector<int> collapsed_shape;
|
||||
std::vector<StrideT> collapsed_strides;
|
||||
|
||||
if (shape.size() > 0) {
|
||||
collapsed_shape.push_back(shape[0]);
|
||||
@@ -104,7 +101,7 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
continue;
|
||||
} else if (
|
||||
strides[i] * shape[i] != collapsed_strides.back() ||
|
||||
collapsed_shape.back() * static_cast<int64_t>(shape[i]) > size_cap) {
|
||||
collapsed_shape.back() * static_cast<StrideT>(shape[i]) > size_cap) {
|
||||
collapsed_shape.push_back(shape[i]);
|
||||
collapsed_strides.push_back(strides[i]);
|
||||
} else {
|
||||
@@ -117,10 +114,25 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
return std::make_pair(collapsed_shape, collapsed_strides);
|
||||
}
|
||||
|
||||
std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<int64_t>& strides,
|
||||
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
|
||||
return collapse_contiguous_dims_impl<int64_t>(shape, strides, size_cap);
|
||||
}
|
||||
|
||||
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides,
|
||||
size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
|
||||
return collapse_contiguous_dims_impl<size_t>(shape, strides, size_cap);
|
||||
}
|
||||
|
||||
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
|
||||
const array& a,
|
||||
int64_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
|
||||
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
|
||||
size_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
|
||||
return collapse_contiguous_dims_impl<size_t>(
|
||||
a.shape(), a.strides(), size_cap);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -8,9 +8,12 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline int64_t
|
||||
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
||||
int64_t loc = 0;
|
||||
template <typename StrideT>
|
||||
inline StrideT elem_to_loc(
|
||||
int elem,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<StrideT>& strides) {
|
||||
StrideT loc = 0;
|
||||
for (int i = shape.size() - 1; i >= 0; --i) {
|
||||
auto q_and_r = ldiv(elem, shape[i]);
|
||||
loc += q_and_r.rem * strides[i];
|
||||
@@ -19,15 +22,16 @@ elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
||||
return loc;
|
||||
}
|
||||
|
||||
inline int64_t elem_to_loc(int elem, const array& a) {
|
||||
inline size_t elem_to_loc(int elem, const array& a) {
|
||||
if (a.flags().row_contiguous) {
|
||||
return elem;
|
||||
}
|
||||
return elem_to_loc(elem, a.shape(), a.strides());
|
||||
}
|
||||
|
||||
inline Strides make_contiguous_strides(const Shape& shape) {
|
||||
Strides strides(shape.size(), 1);
|
||||
template <typename StrideT>
|
||||
std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) {
|
||||
std::vector<StrideT> strides(shape.size(), 1);
|
||||
for (int i = shape.size() - 1; i > 0; i--) {
|
||||
strides[i - 1] = strides[i] * shape[i];
|
||||
}
|
||||
@@ -40,15 +44,22 @@ inline Strides make_contiguous_strides(const Shape& shape) {
|
||||
//
|
||||
// When multiple arrays are passed they should all have the same shape. The
|
||||
// collapsed axes are also the same so one shape is returned.
|
||||
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||
const Shape& shape,
|
||||
const std::vector<Strides>& strides,
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
|
||||
collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<std::vector<int64_t>>& strides,
|
||||
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<std::vector<size_t>>& strides,
|
||||
size_t size_cap = std::numeric_limits<int32_t>::max());
|
||||
|
||||
inline std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(
|
||||
const std::vector<array>& xs,
|
||||
size_t size_cap = std::numeric_limits<int32_t>::max()) {
|
||||
std::vector<Strides> strides;
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
for (auto& x : xs) {
|
||||
strides.emplace_back(x.strides());
|
||||
}
|
||||
@@ -62,14 +73,19 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
|
||||
}
|
||||
|
||||
// The single array version of the above.
|
||||
std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<int64_t>& strides,
|
||||
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
||||
std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides,
|
||||
size_t size_cap = std::numeric_limits<int32_t>::max());
|
||||
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
|
||||
const array& a,
|
||||
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
||||
size_t size_cap = std::numeric_limits<int32_t>::max());
|
||||
|
||||
template <typename StrideT>
|
||||
struct ContiguousIterator {
|
||||
inline void step() {
|
||||
int dims = shape_.size();
|
||||
@@ -86,7 +102,7 @@ struct ContiguousIterator {
|
||||
loc += strides_[i];
|
||||
}
|
||||
|
||||
void seek(int64_t n) {
|
||||
void seek(StrideT n) {
|
||||
loc = 0;
|
||||
for (int i = shape_.size() - 1; i >= 0; --i) {
|
||||
auto q_and_r = ldiv(n, shape_[i]);
|
||||
@@ -112,29 +128,32 @@ struct ContiguousIterator {
|
||||
}
|
||||
|
||||
explicit ContiguousIterator(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<StrideT>& strides,
|
||||
int dims)
|
||||
: shape_(shape.begin(), shape.begin() + dims),
|
||||
strides_(strides.begin(), strides.begin() + dims) {
|
||||
if (!shape_.empty()) {
|
||||
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
|
||||
pos_ = Shape(shape_.size(), 0);
|
||||
pos_ = std::vector<int>(shape_.size(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t loc{0};
|
||||
StrideT loc{0};
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
Strides strides_;
|
||||
Shape pos_;
|
||||
std::vector<int> shape_;
|
||||
std::vector<StrideT> strides_;
|
||||
std::vector<int> pos_;
|
||||
};
|
||||
|
||||
inline auto check_contiguity(const Shape& shape, const Strides& strides) {
|
||||
template <typename StrideT>
|
||||
inline auto check_contiguity(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<StrideT>& strides) {
|
||||
size_t no_broadcast_data_size = 1;
|
||||
int64_t f_stride = 1;
|
||||
int64_t b_stride = 1;
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
bool is_row_contiguous = true;
|
||||
bool is_col_contiguous = true;
|
||||
|
||||
@@ -159,19 +178,4 @@ inline bool is_donatable(const array& in, const array& out) {
|
||||
in.buffer_size() <= out.nbytes() + donation_extra;
|
||||
}
|
||||
|
||||
void move_or_copy(const array& in, array& out);
|
||||
void move_or_copy(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Strides& strides,
|
||||
array::Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
|
||||
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
|
||||
|
||||
void shared_buffer_reshape(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
array& out);
|
||||
} // namespace mlx::core
|
||||
|
@@ -14,21 +14,14 @@ function(make_jit_source SRC_FILE)
|
||||
COMMAND
|
||||
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||
${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}
|
||||
${SRC_FILE}
|
||||
${SRC_FILE} "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
|
||||
DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN})
|
||||
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
|
||||
add_dependencies(mlx ${SRC_NAME})
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
|
||||
endfunction(make_jit_source)
|
||||
|
||||
make_jit_source(
|
||||
utils
|
||||
kernels/jit/bf16.h
|
||||
kernels/metal_3_0/bf16.h
|
||||
kernels/metal_3_1/bf16.h
|
||||
kernels/bf16_math.h
|
||||
kernels/complex.h
|
||||
kernels/defines.h)
|
||||
make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h)
|
||||
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
|
||||
make_jit_source(binary_ops)
|
||||
make_jit_source(ternary_ops)
|
||||
|
@@ -30,7 +30,7 @@ BufferCache::BufferCache(MTL::Device* device)
|
||||
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
|
||||
|
||||
BufferCache::~BufferCache() {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto thread_pool = metal::new_scoped_memory_pool();
|
||||
clear();
|
||||
}
|
||||
|
||||
@@ -155,13 +155,11 @@ MetalAllocator::MetalAllocator()
|
||||
}
|
||||
|
||||
size_t MetalAllocator::set_cache_limit(size_t limit) {
|
||||
std::unique_lock lk(mutex_);
|
||||
std::swap(limit, max_pool_size_);
|
||||
return limit;
|
||||
};
|
||||
|
||||
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
|
||||
std::unique_lock lk(mutex_);
|
||||
std::swap(limit, block_limit_);
|
||||
relaxed_ = relaxed;
|
||||
gc_limit_ = std::min(
|
||||
@@ -171,7 +169,6 @@ size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
|
||||
};
|
||||
|
||||
size_t MetalAllocator::set_wired_limit(size_t limit) {
|
||||
std::unique_lock lk(mutex_);
|
||||
std::swap(limit, wired_limit_);
|
||||
residency_set_.resize(wired_limit_);
|
||||
return limit;
|
||||
@@ -208,7 +205,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
return Buffer{nullptr};
|
||||
}
|
||||
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto thread_pool = metal::new_scoped_memory_pool();
|
||||
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
// try to reclaim memory from the cache
|
||||
@@ -229,7 +226,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
|
||||
// Maintain the cache below the requested limit
|
||||
if (get_cache_memory() >= max_pool_size_) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto thread_pool = metal::new_scoped_memory_pool();
|
||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
|
||||
@@ -240,15 +237,11 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
|
||||
void MetalAllocator::clear_cache() {
|
||||
std::unique_lock lk(mutex_);
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
buffer_cache_.clear();
|
||||
}
|
||||
|
||||
void MetalAllocator::free(Buffer buffer) {
|
||||
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
|
||||
if (buf == nullptr) {
|
||||
return;
|
||||
}
|
||||
std::unique_lock lk(mutex_);
|
||||
residency_set_.erase(buf);
|
||||
active_memory_ -= buf->length();
|
||||
@@ -256,7 +249,7 @@ void MetalAllocator::free(Buffer buffer) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
} else {
|
||||
lk.unlock();
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto thread_pool = metal::new_scoped_memory_pool();
|
||||
buf->release();
|
||||
}
|
||||
}
|
||||
|
@@ -22,37 +22,37 @@ std::string get_kernel_name(
|
||||
BinaryOpType bopt,
|
||||
const std::string& op,
|
||||
const array& a,
|
||||
bool large,
|
||||
bool use_2d,
|
||||
int ndim,
|
||||
int work_per_thread) {
|
||||
std::string kname;
|
||||
std::ostringstream kname;
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
kname = "ss";
|
||||
kname << "ss";
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname = (large ? "sv2" : "sv");
|
||||
kname << (use_2d ? "sv2" : "sv");
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname = (large ? "vs2" : "vs");
|
||||
kname << (use_2d ? "vs2" : "vs");
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
kname = (large ? "vv2" : "vv");
|
||||
kname << (use_2d ? "vv2" : "vv");
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
kname = "g";
|
||||
kname << "g";
|
||||
if (ndim <= 3) {
|
||||
kname += std::to_string(ndim);
|
||||
kname << ndim;
|
||||
} else {
|
||||
concatenate(kname, "n", std::to_string(work_per_thread));
|
||||
}
|
||||
if (large) {
|
||||
kname += "large";
|
||||
kname << "n";
|
||||
if (work_per_thread > 1) {
|
||||
kname << work_per_thread;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
concatenate(kname, "_", op, type_to_name(a));
|
||||
return kname;
|
||||
kname << "_" << op << type_to_name(a);
|
||||
return kname.str();
|
||||
}
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
@@ -75,32 +75,24 @@ void binary_op_gpu_inplace(
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
return std::make_tuple(shape, strides[0], strides[1], strides[2]);
|
||||
} else {
|
||||
decltype(a.strides()) e{};
|
||||
return std::make_tuple(decltype(a.shape()){}, e, e, e);
|
||||
std::vector<size_t> e;
|
||||
return std::make_tuple(std::vector<int>{}, e, e, e);
|
||||
}
|
||||
};
|
||||
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
|
||||
|
||||
bool large;
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
auto ndim = shape.size();
|
||||
int work_per_thread;
|
||||
if (bopt == BinaryOpType::General) {
|
||||
large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
out.size() > INT32_MAX;
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
large = out.data_size() > UINT32_MAX;
|
||||
work_per_thread = 1;
|
||||
}
|
||||
int work_per_thread = (bopt == BinaryOpType::General) ? 4 : 1;
|
||||
std::string kernel_name =
|
||||
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
|
||||
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = outputs.size() == 2
|
||||
? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
|
||||
: get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// - If a is donated it goes to the first output
|
||||
// - If b is donated it goes to the first output if a was not donated
|
||||
@@ -125,15 +117,19 @@ void binary_op_gpu_inplace(
|
||||
size_t rest = out.size() / (dim0 * dim1);
|
||||
|
||||
if (ndim > 3) {
|
||||
compute_encoder.set_vector_bytes(shape, arg_idx++);
|
||||
compute_encoder.set_vector_bytes(strides_a, arg_idx++);
|
||||
compute_encoder.set_vector_bytes(strides_b, arg_idx++);
|
||||
compute_encoder.set_bytes<int>(ndim, arg_idx++);
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
|
||||
compute_encoder->setBytes(
|
||||
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
|
||||
compute_encoder->setBytes(
|
||||
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
} else {
|
||||
// The shape is implicit in the grid for <= 3D
|
||||
compute_encoder.set_vector_bytes(strides_a, arg_idx++);
|
||||
compute_encoder.set_vector_bytes(strides_b, arg_idx++);
|
||||
compute_encoder->setBytes(
|
||||
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
|
||||
compute_encoder->setBytes(
|
||||
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
|
||||
}
|
||||
|
||||
if (thread_group_size != 1024) {
|
||||
@@ -141,7 +137,7 @@ void binary_op_gpu_inplace(
|
||||
}
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D or 2D grid of threads
|
||||
size_t nthreads = out.data_size();
|
||||
@@ -149,9 +145,9 @@ void binary_op_gpu_inplace(
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,6 +1,5 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <fmt/format.h>
|
||||
#include <iostream> //TODO
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
@@ -12,12 +11,12 @@
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
using namespace fmt::literals;
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int WORK_PER_THREAD = 4;
|
||||
|
||||
inline void build_kernel(
|
||||
std::string& os,
|
||||
std::ostream& os,
|
||||
const std::string& kernel_name,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
@@ -42,8 +41,8 @@ inline void build_kernel(
|
||||
int cnt = 0;
|
||||
|
||||
// Start the kernel
|
||||
os += fmt::format(
|
||||
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
|
||||
os << "[[host_name(\"" << kernel_name << "\")]]\n"
|
||||
<< "[[kernel]] void " << kernel_name << "(\n";
|
||||
|
||||
// Add the input arguments
|
||||
for (auto& x : inputs) {
|
||||
@@ -55,61 +54,51 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Scalars and contiguous need no strides
|
||||
if (!is_scalar(x) && !contiguous) {
|
||||
if (is_scalar(x) || contiguous) {
|
||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||
<< " [[buffer(" << cnt++ << ")]],\n";
|
||||
} else {
|
||||
add_indices = true;
|
||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||
<< " [[buffer(" << cnt++ << ")]],\n";
|
||||
}
|
||||
os += fmt::format(
|
||||
" device const {0}* {1} [[buffer({2})]],\n",
|
||||
get_type_string(x.dtype()),
|
||||
xname,
|
||||
cnt++);
|
||||
}
|
||||
|
||||
if (add_indices) {
|
||||
os += fmt::format(
|
||||
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
|
||||
os << " constant const size_t* in_strides [[buffer(" << cnt++
|
||||
<< ")]],\n";
|
||||
}
|
||||
|
||||
// Add the output arguments
|
||||
for (auto& x : outputs) {
|
||||
os += fmt::format(
|
||||
" device {0}* {1} [[buffer({2})]],\n",
|
||||
get_type_string(x.dtype()),
|
||||
namer.get_name(x),
|
||||
cnt++);
|
||||
os << " device " << get_type_string(x.dtype()) << "* "
|
||||
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]],\n";
|
||||
}
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os += fmt::format(
|
||||
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
|
||||
os += fmt::format(
|
||||
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
|
||||
os << " constant const size_t* output_strides [[buffer(" << cnt++
|
||||
<< ")]],\n"
|
||||
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],\n";
|
||||
}
|
||||
if (dynamic_dims) {
|
||||
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
|
||||
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],\n";
|
||||
}
|
||||
|
||||
// The thread index in the whole grid
|
||||
os += " uint3 pos [[thread_position_in_grid]],\n";
|
||||
os += " uint3 grid [[threads_per_grid]]) {\n";
|
||||
os << " uint3 pos [[thread_position_in_grid]],\n"
|
||||
<< " uint3 grid [[threads_per_grid]]) {\n";
|
||||
|
||||
std::string idx_type = use_big_index ? "int64_t" : "uint";
|
||||
if (contiguous && use_big_index) {
|
||||
if (use_big_index) {
|
||||
// This is only used for contiguous kernels which don't have
|
||||
// a third grid dimension
|
||||
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
|
||||
os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
|
||||
} else if (work_per_thread > 1) {
|
||||
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
|
||||
os += fmt::format(
|
||||
" int xshape = output_shape[{0}];\n",
|
||||
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
|
||||
os += fmt::format(
|
||||
" {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
|
||||
idx_type);
|
||||
os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n"
|
||||
<< " int xshape = output_shape["
|
||||
<< (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
|
||||
<< " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
|
||||
} else {
|
||||
os += fmt::format(
|
||||
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
|
||||
idx_type);
|
||||
os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
|
||||
}
|
||||
|
||||
// Read constant / contiguous inputs in tmps
|
||||
@@ -120,19 +109,16 @@ inline void build_kernel(
|
||||
|
||||
if (is_constant(x)) {
|
||||
auto type_str = get_type_string(x.dtype());
|
||||
std::ostringstream ss;
|
||||
print_constant(ss, x);
|
||||
os += fmt::format(
|
||||
" auto tmp_{0} = static_cast<{1}>({2});\n",
|
||||
xname,
|
||||
get_type_string(x.dtype()),
|
||||
ss.str());
|
||||
os << " auto tmp_" << xname << " = static_cast<"
|
||||
<< get_type_string(x.dtype()) << ">(";
|
||||
print_constant(os, x);
|
||||
os << ");\n";
|
||||
} else if (is_scalar(x)) {
|
||||
os += fmt::format(
|
||||
" {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[0];\n";
|
||||
} else if (contiguous) {
|
||||
os += fmt::format(
|
||||
" {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[index];\n";
|
||||
} else {
|
||||
nc_inputs.push_back(x);
|
||||
}
|
||||
@@ -141,96 +127,83 @@ inline void build_kernel(
|
||||
// Initialize the indices for non-contiguous inputs
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& xname = namer.get_name(nc_inputs[i]);
|
||||
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
|
||||
if (ndim == 1) {
|
||||
int offset = i * ndim;
|
||||
os +=
|
||||
fmt::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
|
||||
os << " size_t index_" << xname << " = elem_to_loc_1(pos.x, "
|
||||
<< "in_strides[" << offset << "]);\n";
|
||||
} else if (ndim == 2) {
|
||||
int offset = i * ndim;
|
||||
os += fmt::format(
|
||||
"elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n",
|
||||
idx_type,
|
||||
offset);
|
||||
os << " size_t index_" << xname << " = elem_to_loc_2({pos.x, pos.y}, "
|
||||
<< "in_strides + " << offset << ");\n";
|
||||
} else if (ndim == 3) {
|
||||
int offset = i * ndim;
|
||||
os += fmt::format(
|
||||
"elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset);
|
||||
os << " size_t index_" << xname << " = elem_to_loc_3(pos, "
|
||||
<< "in_strides + " << offset << ");\n";
|
||||
} else if (!dynamic_dims) {
|
||||
int offset = (i + 1) * ndim;
|
||||
os += fmt::format(
|
||||
"N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
|
||||
idx_type,
|
||||
offset - 1,
|
||||
offset - 2);
|
||||
int offset = i * ndim;
|
||||
os << " size_t index_" << xname << " = N_ * pos.x * in_strides["
|
||||
<< offset + ndim - 1 << "]"
|
||||
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
|
||||
} else {
|
||||
os += fmt::format(
|
||||
"N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
|
||||
idx_type,
|
||||
i);
|
||||
os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * "
|
||||
<< i << " + ndim - 1]"
|
||||
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
|
||||
os += " uint zpos = pos.z;\n";
|
||||
os << " uint zpos = pos.z;\n";
|
||||
if (dynamic_dims) {
|
||||
os += " for (int d = ndim - 3; d >= 0; --d) {\n";
|
||||
os << " for (int d = ndim - 3; d >= 0; --d) {\n";
|
||||
} else {
|
||||
os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
|
||||
os << " for (int d = " << ndim - 3 << "; d >= 0; --d) {\n";
|
||||
}
|
||||
os += " uint l = zpos % output_shape[d];\n";
|
||||
os << " uint l = zpos % output_shape[d];\n";
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& xname = namer.get_name(nc_inputs[i]);
|
||||
os += fmt::format(" index_{0} += ", xname);
|
||||
os << " index_" << xname << " += ";
|
||||
if (dynamic_dims) {
|
||||
os +=
|
||||
fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
|
||||
os << "l * in_strides[" << i << " * ndim + d];\n";
|
||||
} else {
|
||||
os +=
|
||||
fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
|
||||
os << "l * in_strides[" << i * ndim << " + d];\n";
|
||||
}
|
||||
}
|
||||
os += " zpos /= output_shape[d];\n }\n";
|
||||
os << " zpos /= output_shape[d];\n }\n";
|
||||
}
|
||||
|
||||
// Open per-thread loop
|
||||
if (work_per_thread > 1) {
|
||||
os +=
|
||||
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
|
||||
os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
|
||||
}
|
||||
|
||||
// Read non-contiguous inputs into tmps
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& x = nc_inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
os += fmt::format(
|
||||
" {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[index_" << xname << "];\n";
|
||||
}
|
||||
|
||||
// Actually write the computation
|
||||
for (auto& x : tape) {
|
||||
os += fmt::format(
|
||||
" {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
|
||||
<< " = ";
|
||||
if (is_static_cast(x.primitive())) {
|
||||
os += fmt::format(
|
||||
"static_cast<{0}>(tmp_{1});\n",
|
||||
get_type_string(x.dtype()),
|
||||
namer.get_name(x.inputs()[0]));
|
||||
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
||||
<< namer.get_name(x.inputs()[0]) << ");\n";
|
||||
} else {
|
||||
std::ostringstream ss;
|
||||
x.primitive().print(ss);
|
||||
os += ss.str();
|
||||
os += "()(";
|
||||
x.primitive().print(os);
|
||||
os << "()(";
|
||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
|
||||
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
||||
}
|
||||
os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
|
||||
os << "tmp_" << namer.get_name(x.inputs().back()) << ");\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Write the outputs from tmps
|
||||
for (auto& x : outputs) {
|
||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
|
||||
<< ";\n";
|
||||
}
|
||||
// Increment indices and close per thread loop
|
||||
if (work_per_thread > 1) {
|
||||
@@ -238,18 +211,18 @@ inline void build_kernel(
|
||||
auto& x = nc_inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
if (!dynamic_dims) {
|
||||
os += fmt::format(
|
||||
" index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
|
||||
os << " index_" << xname << " += "
|
||||
<< "in_strides[" << i * ndim + ndim - 1 << "];\n";
|
||||
} else {
|
||||
os += fmt::format(
|
||||
" index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
|
||||
os << " index_" << xname << " += "
|
||||
<< "in_strides[" << i << " * ndim + ndim - 1];\n";
|
||||
}
|
||||
}
|
||||
os += " index++;\n }\n";
|
||||
os << " index++;\n }\n";
|
||||
}
|
||||
|
||||
// Finish the kernel
|
||||
os += "}\n";
|
||||
os << "}\n";
|
||||
|
||||
if (cnt > 31) {
|
||||
std::ostringstream msg;
|
||||
@@ -273,9 +246,9 @@ void Compiled::eval_gpu(
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto lib = d.get_library(kernel_lib_, [&]() {
|
||||
std::string kernel = metal::utils();
|
||||
concatenate(
|
||||
kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
|
||||
std::ostringstream kernel;
|
||||
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
|
||||
<< metal::ternary_ops();
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous",
|
||||
@@ -288,7 +261,7 @@ void Compiled::eval_gpu(
|
||||
/* dynamic_dims = */ false);
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous_large",
|
||||
kernel_lib_ + "_contiguous_big",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
@@ -309,21 +282,7 @@ void Compiled::eval_gpu(
|
||||
/* ndim = */ i,
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ false,
|
||||
/* work_per_thread = */ i > 3 ? 2 : 1);
|
||||
if (i > 1) {
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_strided_" + std::to_string(i) + "_large",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ i,
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ true,
|
||||
/* work_per_thread = */ i > 3 ? 4 : 1);
|
||||
}
|
||||
/* work_per_thread = */ i > 3 ? WORK_PER_THREAD : 1);
|
||||
}
|
||||
build_kernel(
|
||||
kernel,
|
||||
@@ -336,32 +295,20 @@ void Compiled::eval_gpu(
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ true,
|
||||
/* use_big_index = */ false,
|
||||
/* work_per_thread = */ 2);
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_strided_dynamic_large",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ true,
|
||||
/* use_big_index = */ true,
|
||||
/* work_per_thread = */ 4);
|
||||
return kernel;
|
||||
/* work_per_thread = */ WORK_PER_THREAD);
|
||||
return kernel.str();
|
||||
});
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& output_shape = outputs[0].shape();
|
||||
auto contiguous = compiled_check_contiguity(inputs, output_shape);
|
||||
bool contiguous = compiled_check_contiguity(inputs, output_shape);
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
std::vector<Strides> initial_strides;
|
||||
std::vector<std::vector<size_t>> initial_strides;
|
||||
initial_strides.push_back(outputs[0].strides());
|
||||
Shape shape;
|
||||
std::vector<Strides> strides;
|
||||
std::vector<int> shape;
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
if (!contiguous) {
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
// Skip constants.
|
||||
@@ -376,7 +323,7 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
|
||||
// Broadcast the inputs to the output shape.
|
||||
Strides xstrides;
|
||||
std::vector<size_t> xstrides;
|
||||
int j = 0;
|
||||
for (; j < output_shape.size() - x.ndim(); j++) {
|
||||
if (output_shape[j] == 1) {
|
||||
@@ -402,19 +349,13 @@ void Compiled::eval_gpu(
|
||||
collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
|
||||
}
|
||||
|
||||
bool large;
|
||||
bool use_2d = false;
|
||||
if (contiguous) {
|
||||
size_t max_size = 0;
|
||||
for (auto& in : inputs) {
|
||||
max_size = std::max(max_size, in.data_size());
|
||||
}
|
||||
large = (max_size > UINT32_MAX);
|
||||
} else {
|
||||
size_t max_size = 0;
|
||||
for (auto& o : outputs) {
|
||||
max_size = std::max(max_size, o.size());
|
||||
}
|
||||
large = (max_size > UINT32_MAX);
|
||||
use_2d = (max_size > UINT32_MAX);
|
||||
}
|
||||
|
||||
// Get the kernel from the lib
|
||||
@@ -427,18 +368,17 @@ void Compiled::eval_gpu(
|
||||
} else {
|
||||
kernel_name += std::to_string(shape.size());
|
||||
}
|
||||
}
|
||||
if (large) {
|
||||
kernel_name += "_large";
|
||||
} else if (use_2d) {
|
||||
kernel_name += "_big";
|
||||
}
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Put the inputs in
|
||||
int cnt = 0;
|
||||
int stride_idx = 1; // idx 0 is the output strides
|
||||
Strides in_strides;
|
||||
std::vector<size_t> in_strides;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
continue;
|
||||
@@ -454,7 +394,8 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
}
|
||||
if (!in_strides.empty()) {
|
||||
compute_encoder.set_vector_bytes(in_strides, cnt++);
|
||||
compute_encoder->setBytes(
|
||||
in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
|
||||
}
|
||||
|
||||
compiled_allocate_outputs(
|
||||
@@ -467,13 +408,14 @@ void Compiled::eval_gpu(
|
||||
|
||||
// Put the output shape and strides in
|
||||
if (!contiguous) {
|
||||
compute_encoder.set_vector_bytes(strides[0], cnt++);
|
||||
compute_encoder.set_vector_bytes(shape, cnt++);
|
||||
compute_encoder->setBytes(
|
||||
strides[0].data(), strides[0].size() * sizeof(size_t), cnt++);
|
||||
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++);
|
||||
}
|
||||
|
||||
// Put the number of dims in if it is dynamic
|
||||
if (dynamic) {
|
||||
compute_encoder.set_bytes(ndim, cnt++);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), cnt++);
|
||||
}
|
||||
|
||||
// Launch the kernel
|
||||
@@ -482,15 +424,15 @@ void Compiled::eval_gpu(
|
||||
MTL::Size group_dims(
|
||||
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||
|
||||
MTL::Size grid_dims = large
|
||||
MTL::Size grid_dims = use_2d
|
||||
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
size_t rest = outputs[0].size() / (dim0 * dim1);
|
||||
int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
|
||||
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
int pow2;
|
||||
@@ -503,7 +445,7 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -44,28 +44,27 @@ void explicit_gemm_conv_ND_gpu(
|
||||
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(in_unfolded, 1);
|
||||
|
||||
compute_encoder.set_bytes(conv_params, 2);
|
||||
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
|
||||
|
||||
// Launch unfolding kernel
|
||||
size_t tgp_x = std::min(conv_params.C, 64);
|
||||
int tgp_x = std::min(conv_params.C, 64);
|
||||
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
|
||||
size_t tgp_y = 256 / tgp_x;
|
||||
int tgp_y = 256 / tgp_x;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
|
||||
MTL::Size group_dims = MTL::Size(
|
||||
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Reshape weight
|
||||
Shape wt_reshape{implicit_K, implicit_N};
|
||||
Strides wt_restride{1, implicit_K};
|
||||
std::vector<int> wt_reshape{implicit_K, implicit_N};
|
||||
std::vector<size_t> wt_restride{1, static_cast<size_t>(implicit_K)};
|
||||
array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
|
||||
auto wt_flags = wt.flags();
|
||||
wt_flags.row_contiguous = false;
|
||||
@@ -123,31 +122,33 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
<< N;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(in_unfolded, 1);
|
||||
|
||||
compute_encoder.set_bytes(conv_params, 2);
|
||||
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
|
||||
|
||||
// Launch unfolding kernel
|
||||
size_t tgp_x = std::min(conv_params.C, 64);
|
||||
int tgp_x = std::min(conv_params.C, 64);
|
||||
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
|
||||
size_t tgp_y = 256 / tgp_x;
|
||||
int tgp_y = 256 / tgp_x;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
|
||||
MTL::Size group_dims = MTL::Size(
|
||||
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Transpose kernel weights so that we can slice them by contiguous chunks
|
||||
// of channel groups.
|
||||
array wt_view(
|
||||
{wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
|
||||
wt_view.copy_shared_buffer(
|
||||
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
|
||||
wt,
|
||||
{wt.strides(0), 1, static_cast<size_t>(C_per_group)},
|
||||
wt.flags(),
|
||||
wt.size());
|
||||
|
||||
// Materialize
|
||||
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
|
||||
@@ -236,7 +237,7 @@ void slow_conv_2D_gpu(
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
|
||||
|
||||
@@ -251,8 +252,8 @@ void slow_conv_2D_gpu(
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
compute_encoder.set_bytes(conv_params, 3);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void implicit_gemm_conv_2D_gpu(
|
||||
@@ -351,7 +352,7 @@ void implicit_gemm_conv_2D_gpu(
|
||||
wn,
|
||||
n_channel_specialization,
|
||||
small_filter);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Deduce grid launch dimensions
|
||||
int tile = 1 << swizzle_log;
|
||||
@@ -367,11 +368,11 @@ void implicit_gemm_conv_2D_gpu(
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode params
|
||||
compute_encoder.set_bytes(conv_params, 3);
|
||||
compute_encoder.set_bytes(gemm_params, 4);
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
|
||||
|
||||
// Launch kernel
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void implicit_gemm_conv_2D_general_gpu(
|
||||
@@ -505,7 +506,7 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel =
|
||||
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Deduce grid launch dimensions
|
||||
int tile = 1 << swizzle_log;
|
||||
@@ -522,15 +523,17 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode params
|
||||
compute_encoder.set_bytes(conv_params, 3);
|
||||
compute_encoder.set_bytes(gemm_params, 4);
|
||||
compute_encoder.set_bytes(jump_params, 5);
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
|
||||
compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5);
|
||||
|
||||
compute_encoder.set_vector_bytes(base_h, 6);
|
||||
compute_encoder.set_vector_bytes(base_w, 7);
|
||||
compute_encoder->setBytes(
|
||||
base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6);
|
||||
compute_encoder->setBytes(
|
||||
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
|
||||
|
||||
// Launch kernel
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void winograd_conv_2D_gpu(
|
||||
@@ -619,18 +622,18 @@ void winograd_conv_2D_gpu(
|
||||
<< bc;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
compute_encoder.set_input_array(wt, 0);
|
||||
compute_encoder.set_output_array(filt_wg, 1);
|
||||
|
||||
compute_encoder.set_bytes(C_c, 2);
|
||||
compute_encoder.set_bytes(O_c, 3);
|
||||
compute_encoder->setBytes(&C_c, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&O_c, sizeof(int), 3);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, bo, 1);
|
||||
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Do input transform
|
||||
@@ -647,17 +650,18 @@ void winograd_conv_2D_gpu(
|
||||
<< bc;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in_padded, 0);
|
||||
compute_encoder.set_output_array(inp_wg, 1);
|
||||
|
||||
compute_encoder.set_bytes(conv_params_updated, 2);
|
||||
compute_encoder->setBytes(
|
||||
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Do batched gemm
|
||||
@@ -694,17 +698,18 @@ void winograd_conv_2D_gpu(
|
||||
<< bc;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
compute_encoder.set_input_array(out_wg, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
compute_encoder.set_bytes(conv_params_updated, 2);
|
||||
compute_encoder->setBytes(
|
||||
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -43,12 +43,13 @@ void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& data_shape,
|
||||
const Strides& strides_in_pre,
|
||||
const Strides& strides_out_pre,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& strides_in_pre,
|
||||
const std::vector<stride_t>& strides_out_pre,
|
||||
int64_t inp_offset,
|
||||
int64_t out_offset,
|
||||
CopyType ctype,
|
||||
@@ -67,52 +68,50 @@ void copy_gpu_inplace(
|
||||
/* size_cap = */ INT32_MAX);
|
||||
return std::make_tuple(shape, strides[0], strides[1]);
|
||||
} else {
|
||||
Strides e{};
|
||||
return std::make_tuple(Shape{}, e, e);
|
||||
std::vector<stride_t> e;
|
||||
return std::make_tuple(std::vector<int>{}, e, e);
|
||||
}
|
||||
};
|
||||
auto [shape, strides_in_, strides_out_] = maybe_collapse();
|
||||
int ndim = shape.size();
|
||||
bool large;
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
// Allow for negative strides
|
||||
large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
} else {
|
||||
large = out.data_size() > UINT32_MAX;
|
||||
}
|
||||
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
auto& d = metal::device(s.device);
|
||||
int work_per_thread = 1;
|
||||
std::string kernel_name;
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
kernel_name = (large ? "s2" : "s");
|
||||
break;
|
||||
case CopyType::Vector:
|
||||
kernel_name = (large ? "v2" : "v");
|
||||
break;
|
||||
case CopyType::General:
|
||||
kernel_name = "g";
|
||||
break;
|
||||
case CopyType::GeneralGeneral:
|
||||
kernel_name = "gg";
|
||||
break;
|
||||
}
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
|
||||
kernel_name += std::to_string(shape.size());
|
||||
} else {
|
||||
work_per_thread = large ? 4 : 2;
|
||||
concatenate(kernel_name, "n", std::to_string(work_per_thread));
|
||||
{
|
||||
std::ostringstream kname;
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
kname << (use_2d ? "s2" : "s");
|
||||
break;
|
||||
case CopyType::Vector:
|
||||
kname << (use_2d ? "v2" : "v");
|
||||
break;
|
||||
case CopyType::General:
|
||||
kname << "g";
|
||||
break;
|
||||
case CopyType::GeneralGeneral:
|
||||
kname << "gg";
|
||||
break;
|
||||
}
|
||||
if (large) {
|
||||
kernel_name += "large";
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
|
||||
kname << shape.size();
|
||||
} else {
|
||||
work_per_thread = 4;
|
||||
kname << "n4";
|
||||
}
|
||||
}
|
||||
kname << "_copy";
|
||||
kname << type_to_name(in) << type_to_name(out);
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
|
||||
|
||||
auto kernel = get_copy_kernel(d, kernel_name, in, out);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
bool donate_in = in.data_shared_ptr() == nullptr;
|
||||
|
||||
inp_offset *= size_of(in.dtype());
|
||||
@@ -123,26 +122,26 @@ void copy_gpu_inplace(
|
||||
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
Strides strides_in{strides_in_.begin(), strides_in_.end()};
|
||||
Strides strides_out{strides_out_.begin(), strides_out_.end()};
|
||||
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
|
||||
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
|
||||
if (ndim > 3) {
|
||||
compute_encoder.set_vector_bytes(shape, ndim, 2);
|
||||
set_vector_bytes(compute_encoder, shape, ndim, 2);
|
||||
}
|
||||
compute_encoder.set_vector_bytes(strides_in, ndim, 3);
|
||||
set_vector_bytes(compute_encoder, strides_in, ndim, 3);
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
compute_encoder.set_vector_bytes(strides_out, ndim, 4);
|
||||
set_vector_bytes(compute_encoder, strides_out, ndim, 4);
|
||||
}
|
||||
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
|
||||
size_t data_size = 1;
|
||||
for (auto& s : shape)
|
||||
data_size *= s;
|
||||
size_t rest = data_size / (dim0 * dim1);
|
||||
int rest = data_size / (dim0 * dim1);
|
||||
|
||||
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
|
||||
compute_encoder.set_bytes(ndim, 5);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 5);
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
}
|
||||
|
||||
@@ -153,16 +152,16 @@ void copy_gpu_inplace(
|
||||
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t nthreads = out.data_size();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,13 +178,14 @@ void copy_gpu_inplace(
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Strides& istride,
|
||||
const std::vector<int64_t>& istride,
|
||||
int64_t ioffset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), istride, out.strides(), ioffset, 0, ctype, s);
|
||||
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
|
||||
}
|
||||
|
||||
void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
@@ -193,13 +193,13 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
return;
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
bool large = out.data_size() > UINT32_MAX;
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
auto& d = metal::device(s.device);
|
||||
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
|
||||
std::string kernel_name = std::string(use_2d ? "s2" : "s") + "_copy" +
|
||||
type_to_name(val) + type_to_name(out);
|
||||
auto kernel = get_copy_kernel(d, kernel_name, val, out);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
compute_encoder.set_input_array(val, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
@@ -210,9 +210,9 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -8,12 +8,13 @@
|
||||
namespace mlx::core {
|
||||
|
||||
// Generic copy inplace
|
||||
template <typename stride_t>
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides& o_strides,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype,
|
||||
@@ -31,7 +32,7 @@ void copy_gpu_inplace(
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Strides& istride,
|
||||
const std::vector<int64_t>& istride,
|
||||
int64_t ioffset,
|
||||
CopyType ctype,
|
||||
const Stream& s);
|
||||
|
@@ -43,7 +43,7 @@ void CustomKernel::eval_gpu(
|
||||
d.get_library(lib_name, [this] { return metal::utils() + source_; });
|
||||
auto kernel = d.get_kernel(name_, lib);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
int index = 0;
|
||||
for (int i = 0; i < checked_inputs.size(); i++) {
|
||||
const array& in = checked_inputs[i];
|
||||
@@ -53,15 +53,15 @@ void CustomKernel::eval_gpu(
|
||||
if (in.ndim() > 0) {
|
||||
int ndim = in.ndim();
|
||||
if (shape_info.shape) {
|
||||
compute_encoder.set_vector_bytes(in.shape(), ndim, index);
|
||||
set_vector_bytes(compute_encoder, in.shape(), ndim, index);
|
||||
index++;
|
||||
}
|
||||
if (shape_info.strides) {
|
||||
compute_encoder.set_vector_bytes(in.strides(), ndim, index);
|
||||
set_vector_bytes(compute_encoder, in.strides(), ndim, index);
|
||||
index++;
|
||||
}
|
||||
if (shape_info.ndim) {
|
||||
compute_encoder.set_bytes(ndim, index);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), index);
|
||||
index++;
|
||||
}
|
||||
}
|
||||
@@ -72,11 +72,10 @@ void CustomKernel::eval_gpu(
|
||||
}
|
||||
|
||||
const auto [tx, ty, tz] = threadgroup_;
|
||||
MTL::Size group_dims = MTL::Size(tx, ty, tz);
|
||||
const auto [gx, gy, gz] = grid_;
|
||||
MTL::Size group_dims =
|
||||
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
|
||||
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
@@ -23,18 +23,14 @@ constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
||||
|
||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
|
||||
auto get_metal_version() {
|
||||
auto get_metal_version_ = []() {
|
||||
if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {
|
||||
return MTL::LanguageVersion3_2;
|
||||
} else if (__builtin_available(macOS 14, iOS 17, tvOS 17, visionOS 1, *)) {
|
||||
return MTL::LanguageVersion3_1;
|
||||
} else {
|
||||
return MTL::LanguageVersion3_0;
|
||||
}
|
||||
};
|
||||
static auto metal_version_ = get_metal_version_();
|
||||
return metal_version_;
|
||||
constexpr auto get_metal_version() {
|
||||
#if (MLX_METAL_VERSION >= 320)
|
||||
return MTL::LanguageVersion3_2;
|
||||
#elif (MLX_METAL_VERSION >= 310)
|
||||
return MTL::LanguageVersion3_1;
|
||||
#else
|
||||
return MTL::LanguageVersion3_0;
|
||||
#endif
|
||||
}
|
||||
|
||||
auto load_device() {
|
||||
@@ -175,14 +171,14 @@ void CommandEncoder::maybeInsertBarrier() {
|
||||
next_outputs_.clear();
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatch_threadgroups(
|
||||
void CommandEncoder::dispatchThreadgroups(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
maybeInsertBarrier();
|
||||
enc_->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatch_threads(
|
||||
void CommandEncoder::dispatchThreads(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
maybeInsertBarrier();
|
||||
@@ -280,7 +276,7 @@ void Device::end_encoding(int index) {
|
||||
// - Update the map of outputs to include this command encoder's outputs.
|
||||
// - Always signal this command encoders fence.
|
||||
// - Add a completion handler for this command encoder that removes outputs
|
||||
// from the map to limit the growth of the map and avoid unnecessary waits
|
||||
// from the map to limit the growth of the map and avoid unecessary waits
|
||||
// - Temporaries are a special case as they do not cross command encoder
|
||||
// boundaries. These can be removed early from the encoders inputs and
|
||||
// outputs since they don't need synchronization.
|
||||
@@ -302,7 +298,7 @@ void Device::end_encoding(int index) {
|
||||
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
|
||||
// If we've already waited on a fence, don't wait on it again.
|
||||
if (waiting_on.find(it->second) == waiting_on.end()) {
|
||||
enc.wait_for_fence(it->second->fence);
|
||||
enc->waitForFence(it->second->fence);
|
||||
waiting_on.insert(it->second);
|
||||
}
|
||||
}
|
||||
@@ -311,7 +307,7 @@ void Device::end_encoding(int index) {
|
||||
stream.outputs[out] = stream.fence;
|
||||
}
|
||||
}
|
||||
enc.update_fence(stream.fence->fence);
|
||||
enc->updateFence(stream.fence->fence);
|
||||
stream.buffer->addCompletedHandler(
|
||||
[&stream,
|
||||
waiting_on = std::move(waiting_on),
|
||||
@@ -645,27 +641,21 @@ void new_stream(Stream stream) {
|
||||
|
||||
std::unordered_map<std::string, std::variant<std::string, size_t>>
|
||||
device_info() {
|
||||
auto init_device_info = []()
|
||||
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto raw_device = device(default_device()).mtl_device();
|
||||
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
||||
auto raw_device = device(default_device()).mtl_device();
|
||||
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
||||
|
||||
int mib[] = {CTL_HW, HW_MEMSIZE};
|
||||
size_t memsize = 0;
|
||||
size_t length = sizeof(memsize);
|
||||
int mib[] = {CTL_HW, HW_MEMSIZE};
|
||||
size_t memsize = 0;
|
||||
size_t length = sizeof(memsize);
|
||||
|
||||
sysctl(mib, 2, &memsize, &length, NULL, 0);
|
||||
sysctl(mib, 2, &memsize, &length, NULL, 0);
|
||||
|
||||
return {
|
||||
{"architecture", arch},
|
||||
{"max_buffer_length", raw_device->maxBufferLength()},
|
||||
{"max_recommended_working_set_size",
|
||||
raw_device->recommendedMaxWorkingSetSize()},
|
||||
{"memory_size", memsize}};
|
||||
};
|
||||
static auto device_info_ = init_device_info();
|
||||
return device_info_;
|
||||
return {
|
||||
{"architecture", arch},
|
||||
{"max_buffer_length", raw_device->maxBufferLength()},
|
||||
{"max_recommended_working_set_size",
|
||||
raw_device->recommendedMaxWorkingSetSize()},
|
||||
{"memory_size", memsize}};
|
||||
}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -58,43 +58,16 @@ struct CommandEncoder {
|
||||
CommandEncoder& enc;
|
||||
};
|
||||
|
||||
MTL::ComputeCommandEncoder* operator->() {
|
||||
return enc_;
|
||||
}
|
||||
|
||||
void set_input_array(const array& a, int idx, int64_t offset = 0);
|
||||
void set_output_array(array& a, int idx, int64_t offset = 0);
|
||||
void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
|
||||
void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
|
||||
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
|
||||
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
|
||||
void maybeInsertBarrier();
|
||||
|
||||
void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
|
||||
enc_->setComputePipelineState(kernel);
|
||||
}
|
||||
|
||||
void wait_for_fence(MTL::Fence* fence) {
|
||||
enc_->waitForFence(fence);
|
||||
}
|
||||
|
||||
void update_fence(MTL::Fence* fence) {
|
||||
enc_->updateFence(fence);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
|
||||
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
|
||||
}
|
||||
template <typename T>
|
||||
void set_vector_bytes(const std::vector<T>& vec, int idx) {
|
||||
return set_vector_bytes(vec, vec.size(), idx);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void set_bytes(const T* v, int n, int idx) {
|
||||
return enc_->setBytes(v, n * sizeof(T), idx);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void set_bytes(const T& v, int idx) {
|
||||
return enc_->setBytes(&v, sizeof(T), idx);
|
||||
}
|
||||
|
||||
ConcurrentContext start_concurrent() {
|
||||
return ConcurrentContext(*this);
|
||||
}
|
||||
|
@@ -363,7 +363,7 @@ void multi_upload_bluestein_fft(
|
||||
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
|
||||
|
||||
// Broadcast w_q and w_k to the batch size
|
||||
Strides b_strides(in.ndim(), 0);
|
||||
std::vector<size_t> b_strides(in.ndim(), 0);
|
||||
b_strides[axis] = 1;
|
||||
array w_k_broadcast({}, complex64, nullptr, {});
|
||||
array w_q_broadcast({}, complex64, nullptr, {});
|
||||
@@ -386,8 +386,8 @@ void multi_upload_bluestein_fft(
|
||||
copies.push_back(slice_temp);
|
||||
copies.push_back(conj_temp);
|
||||
|
||||
Shape rstarts(in.ndim(), 0);
|
||||
Shape rstrides(in.ndim(), 1);
|
||||
std::vector<int> rstarts(in.ndim(), 0);
|
||||
std::vector<int> rstrides(in.ndim(), 1);
|
||||
rstarts[axis] = in.shape(axis) - back_offset;
|
||||
rstrides[axis] = -1;
|
||||
unary_op_gpu({in}, conj_temp, "Conjugate", s);
|
||||
@@ -431,19 +431,19 @@ void multi_upload_bluestein_fft(
|
||||
s);
|
||||
|
||||
int offset = plan.bluestein_n - (2 * n - 1);
|
||||
Shape starts(in.ndim(), 0);
|
||||
Shape strides(in.ndim(), 1);
|
||||
std::vector<int> starts(in.ndim(), 0);
|
||||
std::vector<int> strides(in.ndim(), 1);
|
||||
starts[axis] = plan.bluestein_n - offset - n;
|
||||
slice_gpu(pad_temp1, temp, starts, strides, s);
|
||||
|
||||
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
|
||||
|
||||
if (real && !inverse) {
|
||||
Shape rstarts(in.ndim(), 0);
|
||||
Shape rstrides(in.ndim(), 1);
|
||||
std::vector<int> rstarts(in.ndim(), 0);
|
||||
std::vector<int> rstrides(in.ndim(), 1);
|
||||
slice_gpu(temp1, out, rstarts, strides, s);
|
||||
} else if (real && inverse) {
|
||||
Strides b_strides(in.ndim(), 0);
|
||||
std::vector<size_t> b_strides(in.ndim(), 0);
|
||||
auto inv_n = array({1.0f / n}, {1}, float32);
|
||||
array temp_float(out.shape(), out.dtype(), nullptr, {});
|
||||
copies.push_back(temp_float);
|
||||
@@ -531,8 +531,8 @@ void fft_op(
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
Strides strides;
|
||||
int64_t cur_stride = x.shape(axis);
|
||||
std::vector<size_t> strides;
|
||||
size_t cur_stride = x.shape(axis);
|
||||
for (int a = 0; a < x.ndim(); a++) {
|
||||
if (a == axis) {
|
||||
strides.push_back(1);
|
||||
@@ -699,7 +699,7 @@ void fft_op(
|
||||
auto kernel =
|
||||
get_fft_kernel(d, base_name, hash_name, func_consts, template_def);
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(in_contiguous, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
@@ -711,9 +711,9 @@ void fft_op(
|
||||
|
||||
compute_encoder.set_input_array(w_q, 2); // w_q
|
||||
compute_encoder.set_input_array(w_k, 3); // w_k
|
||||
compute_encoder.set_bytes(n, 4);
|
||||
compute_encoder.set_bytes(plan.bluestein_n, 5);
|
||||
compute_encoder.set_bytes(total_batch_size, 6);
|
||||
compute_encoder->setBytes(&n, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
|
||||
} else if (plan.rader_n > 1) {
|
||||
auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
|
||||
copies.push_back(b_q);
|
||||
@@ -723,22 +723,22 @@ void fft_op(
|
||||
compute_encoder.set_input_array(b_q, 2);
|
||||
compute_encoder.set_input_array(g_q, 3);
|
||||
compute_encoder.set_input_array(g_minus_q, 4);
|
||||
compute_encoder.set_bytes(n, 5);
|
||||
compute_encoder.set_bytes(total_batch_size, 6);
|
||||
compute_encoder.set_bytes(plan.rader_n, 7);
|
||||
compute_encoder->setBytes(&n, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
|
||||
} else if (four_step_params.required) {
|
||||
compute_encoder.set_bytes(four_step_params.n1, 2);
|
||||
compute_encoder.set_bytes(four_step_params.n2, 3);
|
||||
compute_encoder.set_bytes(total_batch_size, 4);
|
||||
compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
|
||||
} else {
|
||||
compute_encoder.set_bytes(n, 2);
|
||||
compute_encoder.set_bytes(total_batch_size, 3);
|
||||
compute_encoder->setBytes(&n, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
|
||||
}
|
||||
|
||||
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
|
||||
auto grid_dims =
|
||||
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
@@ -777,7 +777,7 @@ void nd_fft_op(
|
||||
// Mirror np.fft.(i)rfftn and perform a real transform
|
||||
// only on the final axis.
|
||||
bool step_real = (real && index == axes.size() - 1);
|
||||
auto step_shape = inverse ? out.shape(axis) : in.shape(axis);
|
||||
int step_shape = inverse ? out.shape(axis) : in.shape(axis);
|
||||
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
|
||||
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
|
||||
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
|
||||
|
@@ -137,14 +137,14 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_bytes(scale, 2);
|
||||
compute_encoder->setBytes(&scale, sizeof(float), 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
|
||||
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
};
|
||||
|
||||
if (m > 1) {
|
||||
|
@@ -53,31 +53,27 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
size_t ndim = src.ndim();
|
||||
|
||||
bool large_index = nidx && inputs[1].size() > INT32_MAX;
|
||||
bool large_src = src.size() > INT32_MAX;
|
||||
bool large_out = out.size() > INT32_MAX;
|
||||
bool large = large_index || large_src || large_out;
|
||||
|
||||
std::string lib_name;
|
||||
std::string kernel_name;
|
||||
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||
std::string kernel_name = fmt::format(
|
||||
"gather{0}{1}_{2}_{3}_{4}",
|
||||
type_to_name(out),
|
||||
idx_type_name,
|
||||
nidx,
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "int");
|
||||
std::string lib_name = kernel_name;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
|
||||
<< "_" << idx_ndim;
|
||||
lib_name = kname.str();
|
||||
kernel_name = lib_name;
|
||||
}
|
||||
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::gather();
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gather();
|
||||
std::string out_type_str = get_type_string(out.dtype());
|
||||
std::string idx_type_str =
|
||||
nidx ? get_type_string(inputs[1].dtype()) : "bool";
|
||||
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
|
||||
|
||||
// Index dimension specializations
|
||||
kernel_source += fmt::format(
|
||||
kernel_source << fmt::format(
|
||||
gather_kernels,
|
||||
type_to_name(out) + idx_type_name,
|
||||
out_type_str,
|
||||
@@ -85,14 +81,13 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nidx,
|
||||
idx_args,
|
||||
idx_arr,
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "int");
|
||||
return kernel_source;
|
||||
idx_ndim);
|
||||
return kernel_source.str();
|
||||
});
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
size_t slice_size = 1;
|
||||
for (auto s : slice_sizes_) {
|
||||
@@ -136,20 +131,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
// Set source info
|
||||
compute_encoder.set_vector_bytes(src.shape(), 2);
|
||||
compute_encoder.set_vector_bytes(src.strides(), 3);
|
||||
compute_encoder.set_bytes(ndim, 4);
|
||||
compute_encoder.set_vector_bytes(slice_sizes_, 5);
|
||||
compute_encoder.set_vector_bytes(axes_, 6);
|
||||
set_vector_bytes(compute_encoder, src.shape(), 2);
|
||||
set_vector_bytes(compute_encoder, src.strides(), 3);
|
||||
compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
|
||||
set_vector_bytes(compute_encoder, slice_sizes_, 5);
|
||||
set_vector_bytes(compute_encoder, axes_, 6);
|
||||
|
||||
// Set index info
|
||||
//
|
||||
// We don't need to check for empty idx_shapes because gather has a
|
||||
// idx_ndim == 0 specialization
|
||||
compute_encoder.set_vector_bytes(idx_shapes, 7);
|
||||
compute_encoder.set_vector_bytes(idx_strides, 8);
|
||||
compute_encoder.set_vector_bytes(idx_contigs, 9);
|
||||
compute_encoder.set_bytes(idx_ndim, 10);
|
||||
set_vector_bytes(compute_encoder, idx_shapes, 7);
|
||||
set_vector_bytes(compute_encoder, idx_strides, 8);
|
||||
set_vector_bytes(compute_encoder, idx_contigs, 9);
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 10);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
@@ -157,7 +152,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -214,6 +209,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nwork = 32;
|
||||
}
|
||||
|
||||
std::string lib_name;
|
||||
std::string kernel_name;
|
||||
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||
std::string op_name;
|
||||
switch (reduce_type_) {
|
||||
@@ -234,24 +231,18 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
break;
|
||||
}
|
||||
auto upd_contig = upd.flags().row_contiguous;
|
||||
bool large_out = out.size() > INT32_MAX;
|
||||
bool large_idx = nidx && (inputs[1].size() > INT32_MAX);
|
||||
bool large_upd = upd.size() > INT32_MAX;
|
||||
bool large = large_out || large_idx || large_upd;
|
||||
std::string kernel_name = fmt::format(
|
||||
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
|
||||
type_to_name(out),
|
||||
idx_type_name,
|
||||
op_name,
|
||||
nidx,
|
||||
upd_contig ? "updc_true" : "updc_false",
|
||||
nwork,
|
||||
large ? "int64_t" : "int");
|
||||
std::string lib_name = kernel_name;
|
||||
|
||||
{
|
||||
std::ostringstream kname;
|
||||
kname << "scatter" << type_to_name(out) << idx_type_name;
|
||||
kname << "_" << op_name << "_" << nidx << "_"
|
||||
<< (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork;
|
||||
lib_name = kname.str();
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::reduce_utils(), metal::scatter());
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::reduce_utils()
|
||||
<< metal::scatter();
|
||||
|
||||
std::string out_type_str = get_type_string(out.dtype());
|
||||
std::string idx_type_str =
|
||||
@@ -279,7 +270,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
|
||||
|
||||
kernel_source += fmt::format(
|
||||
kernel_source << fmt::format(
|
||||
scatter_kernels,
|
||||
type_to_name(out) + idx_type_name + "_" + op_name,
|
||||
out_type_str,
|
||||
@@ -289,9 +280,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
idx_args,
|
||||
idx_arr,
|
||||
upd_contig,
|
||||
nwork,
|
||||
large ? "int64_t" : "int");
|
||||
return kernel_source;
|
||||
nwork);
|
||||
return kernel_source.str();
|
||||
});
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
@@ -299,7 +289,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
size_t nthreads = upd.size();
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Set all the buffers
|
||||
compute_encoder.set_input_array(upd, 1);
|
||||
@@ -312,8 +302,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
upd_size *= upd.shape(i);
|
||||
}
|
||||
// Collect all idx shapes and strides into one place
|
||||
Shape idx_shapes;
|
||||
Strides idx_strides;
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
// To access .data() use char instead of bool
|
||||
// bool is 1 byte in Metal so this is safe
|
||||
std::vector<char> idx_contigs;
|
||||
@@ -332,30 +322,30 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (upd_ndim == 0) {
|
||||
// Need placeholders so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
int64_t stride_ = 0;
|
||||
compute_encoder.set_bytes(shape_, 3);
|
||||
compute_encoder.set_bytes(stride_, 4);
|
||||
size_t stride_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
|
||||
} else {
|
||||
compute_encoder.set_vector_bytes(upd.shape(), 3);
|
||||
compute_encoder.set_vector_bytes(upd.strides(), 4);
|
||||
set_vector_bytes(compute_encoder, upd.shape(), 3);
|
||||
set_vector_bytes(compute_encoder, upd.strides(), 4);
|
||||
}
|
||||
compute_encoder.set_bytes(upd_ndim, 5);
|
||||
compute_encoder.set_bytes(upd_size, 6);
|
||||
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
|
||||
|
||||
// Set output info
|
||||
size_t out_ndim = out.ndim();
|
||||
if (out_ndim == 0) {
|
||||
// Need placeholders so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
int64_t stride_ = 0;
|
||||
compute_encoder.set_bytes(shape_, 7);
|
||||
compute_encoder.set_bytes(stride_, 8);
|
||||
size_t stride_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 7);
|
||||
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
|
||||
} else {
|
||||
compute_encoder.set_vector_bytes(out.shape(), 7);
|
||||
compute_encoder.set_vector_bytes(out.strides(), 8);
|
||||
set_vector_bytes(compute_encoder, out.shape(), 7);
|
||||
set_vector_bytes(compute_encoder, out.strides(), 8);
|
||||
}
|
||||
compute_encoder.set_bytes(out_ndim, 9);
|
||||
compute_encoder.set_vector_bytes(axes_, 10);
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
|
||||
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
|
||||
|
||||
// Set index info
|
||||
if (idx_ndim == 0) {
|
||||
@@ -365,11 +355,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
idx_strides.push_back(0);
|
||||
idx_contigs.push_back(false);
|
||||
}
|
||||
compute_encoder.set_vector_bytes(idx_shapes, 11);
|
||||
compute_encoder.set_vector_bytes(idx_strides, 12);
|
||||
compute_encoder.set_vector_bytes(idx_contigs, 13);
|
||||
compute_encoder.set_bytes(idx_ndim, 14);
|
||||
compute_encoder.set_bytes(idx_size, 15);
|
||||
set_vector_bytes(compute_encoder, idx_shapes, 11);
|
||||
set_vector_bytes(compute_encoder, idx_strides, 12);
|
||||
set_vector_bytes(compute_encoder, idx_contigs, 13);
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 14);
|
||||
compute_encoder->setBytes(&idx_size, sizeof(size_t), 15);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
@@ -385,7 +375,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
|
||||
}
|
||||
MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -11,13 +11,13 @@ gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]],
|
||||
const device {outm_t}* out_mask [[buffer(20)]],
|
||||
const device {opm_t}* mat_mask [[buffer(21)]],
|
||||
const device {opm_t}* vec_mask [[buffer(22)]],
|
||||
const constant int* mask_strides [[buffer(23)]],
|
||||
const constant int64_t* mask_batch_strides [[buffer(24)]],
|
||||
const constant size_t* mask_batch_strides [[buffer(24)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
|
@@ -1,16 +1,16 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view gather_kernels = R"(
|
||||
[[kernel]] void gather{0}_{3}_{6}_{7}(
|
||||
[[kernel]] void gather{0}_{3}_{6}(
|
||||
const device {1}* src [[buffer(0)]],
|
||||
device {1}* out [[buffer(1)]],
|
||||
const constant int* src_shape [[buffer(2)]],
|
||||
const constant int64_t* src_strides [[buffer(3)]],
|
||||
const constant size_t* src_strides [[buffer(3)]],
|
||||
const constant size_t& src_ndim [[buffer(4)]],
|
||||
const constant int* slice_sizes [[buffer(5)]],
|
||||
const constant int* axes [[buffer(6)]],
|
||||
const constant int* idx_shapes [[buffer(7)]],
|
||||
const constant int64_t* idx_strides [[buffer(8)]],
|
||||
const constant size_t* idx_strides [[buffer(8)]],
|
||||
const constant bool* idx_contigs [[buffer(9)]],
|
||||
const constant int& idx_ndim [[buffer(10)]],
|
||||
{4}
|
||||
@@ -19,7 +19,7 @@ constexpr std::string_view gather_kernels = R"(
|
||||
Indices<{2}, {3}> idxs{{
|
||||
{{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
|
||||
|
||||
return gather_impl<{1}, {2}, {3}, {6}, {7}>(
|
||||
return gather_impl<{1}, {2}, {3}, {6}>(
|
||||
src,
|
||||
out,
|
||||
src_shape,
|
||||
@@ -34,19 +34,19 @@ constexpr std::string_view gather_kernels = R"(
|
||||
)";
|
||||
|
||||
constexpr std::string_view scatter_kernels = R"(
|
||||
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}(
|
||||
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* upd_shape [[buffer(3)]],
|
||||
const constant int64_t* upd_strides [[buffer(4)]],
|
||||
const constant size_t* upd_strides [[buffer(4)]],
|
||||
const constant size_t& upd_ndim [[buffer(5)]],
|
||||
const constant size_t& upd_size [[buffer(6)]],
|
||||
const constant int* out_shape [[buffer(7)]],
|
||||
const constant int64_t* out_strides [[buffer(8)]],
|
||||
const constant size_t* out_strides [[buffer(8)]],
|
||||
const constant size_t& out_ndim [[buffer(9)]],
|
||||
const constant int* axes [[buffer(10)]],
|
||||
const constant int* idx_shapes [[buffer(11)]],
|
||||
const constant int64_t* idx_strides [[buffer(12)]],
|
||||
const constant size_t* idx_strides [[buffer(12)]],
|
||||
const constant bool* idx_contigs [[buffer(13)]],
|
||||
const constant int& idx_ndim [[buffer(14)]],
|
||||
const constant size_t& idx_size [[buffer(15)]],
|
||||
@@ -54,7 +54,7 @@ constexpr std::string_view scatter_kernels = R"(
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
|
||||
|
||||
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>(
|
||||
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}>(
|
||||
updates,
|
||||
out,
|
||||
upd_shape,
|
||||
|
@@ -10,12 +10,12 @@ template [[host_name("{name}")]]
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant int64_t* batch_strides [[buffer(7)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
|
||||
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
|
||||
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
|
||||
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
|
||||
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
||||
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
||||
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
@@ -43,7 +43,7 @@ block_masked_gemm<
|
||||
device {itype}* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant int64_t* batch_strides [[buffer(7)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
const device {outmasktype}* out_mask [[buffer(10)]],
|
||||
const device {opmasktype}* lhs_mask [[buffer(11)]],
|
||||
const device {opmasktype}* rhs_mask [[buffer(12)]],
|
||||
|
@@ -1,4 +1,5 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/jit/arange.h"
|
||||
#include "mlx/backend/metal/jit/gemv_masked.h"
|
||||
@@ -45,27 +46,25 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
auto in_t = get_type_string(in_type);
|
||||
auto out_t = get_type_string(out_type);
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::unary_ops(), metal::unary());
|
||||
kernel_source +=
|
||||
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
|
||||
kernel_source +=
|
||||
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
|
||||
kernel_source += get_template_definition(
|
||||
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
|
||||
return kernel_source;
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
|
||||
kernel_source << get_template_definition(
|
||||
"v_" + lib_name, "unary_v", in_t, out_t, op);
|
||||
kernel_source << get_template_definition(
|
||||
"v2_" + lib_name, "unary_v2", in_t, out_t, op);
|
||||
kernel_source << get_template_definition(
|
||||
"gn4_" + lib_name, "unary_g", in_t, out_t, op, 4);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
void append_binary_kernels(
|
||||
void add_binary_kernels(
|
||||
const std::string lib_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op,
|
||||
std::string& kernel_source) {
|
||||
std::ostringstream& kernel_source) {
|
||||
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
|
||||
{"ss", "binary_ss"},
|
||||
{"vs", "binary_vs"},
|
||||
@@ -74,27 +73,27 @@ void append_binary_kernels(
|
||||
{"vs2", "binary_vs2"},
|
||||
{"sv2", "binary_sv2"},
|
||||
{"vv2", "binary_vv2"},
|
||||
{"g1large", "binary_g_nd1"},
|
||||
{"g2large", "binary_g_nd2"},
|
||||
{"g3large", "binary_g_nd3"},
|
||||
{"g1", "binary_g_nd1"},
|
||||
{"g2", "binary_g_nd2"},
|
||||
{"g3", "binary_g_nd3"},
|
||||
}};
|
||||
auto in_t = get_type_string(in_type);
|
||||
auto out_t = get_type_string(out_type);
|
||||
|
||||
for (auto& [name, func] : kernel_types) {
|
||||
kernel_source +=
|
||||
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
|
||||
std::string template_def;
|
||||
template_def = get_template_definition(
|
||||
name + "_" + lib_name,
|
||||
func,
|
||||
get_type_string(in_type),
|
||||
get_type_string(out_type),
|
||||
op);
|
||||
kernel_source << template_def;
|
||||
}
|
||||
kernel_source += get_template_definition(
|
||||
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
|
||||
kernel_source << get_template_definition(
|
||||
"gn4_" + lib_name,
|
||||
"binary_g",
|
||||
get_type_string(in_type),
|
||||
get_type_string(out_type),
|
||||
op,
|
||||
4);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_binary_kernel(
|
||||
@@ -105,11 +104,10 @@ MTL::ComputePipelineState* get_binary_kernel(
|
||||
const std::string op) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source;
|
||||
kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::binary_ops(), metal::binary());
|
||||
append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
||||
return kernel_source;
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
|
||||
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
@@ -122,10 +120,11 @@ MTL::ComputePipelineState* get_binary_two_kernel(
|
||||
const std::string op) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::binary_ops(), metal::binary_two());
|
||||
append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
||||
return kernel_source;
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::binary_ops()
|
||||
<< metal::binary_two();
|
||||
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
@@ -137,31 +136,24 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
const std::string op) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
auto t_str = get_type_string(type);
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
|
||||
std::ostringstream kernel_source;
|
||||
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
|
||||
{"v", "ternary_v"},
|
||||
{"v2", "ternary_v2"},
|
||||
{"g1large", "ternary_g_nd1"},
|
||||
{"g2large", "ternary_g_nd2"},
|
||||
{"g3large", "ternary_g_nd3"},
|
||||
{"g1", "ternary_g_nd1"},
|
||||
{"g2", "ternary_g_nd2"},
|
||||
{"g3", "ternary_g_nd3"},
|
||||
}};
|
||||
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
|
||||
for (auto& [name, func] : kernel_types) {
|
||||
kernel_source +=
|
||||
get_template_definition(name + "_" + lib_name, func, t_str, op);
|
||||
std::string template_def;
|
||||
template_def = get_template_definition(
|
||||
name + "_" + lib_name, func, get_type_string(type), op);
|
||||
kernel_source << template_def;
|
||||
}
|
||||
kernel_source += get_template_definition(
|
||||
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "ternary_g", t_str, op, 4);
|
||||
return kernel_source;
|
||||
kernel_source << get_template_definition(
|
||||
"gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
@@ -173,47 +165,31 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::copy();
|
||||
std::ostringstream kernel_source;
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
kernel_source +=
|
||||
get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
|
||||
kernel_source +=
|
||||
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g3_" + lib_name, "copy_g_nd3", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g1large_" + lib_name, "copy_g_nd1", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g2large_" + lib_name, "copy_g_nd2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
|
||||
kernel_source += get_template_definition(
|
||||
"gg1large_" + lib_name, "copy_gg_nd1", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gg3large_" + lib_name, "copy_gg_nd3", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"ggn4large_" + lib_name, "copy_gg", in_type, out_type, 4);
|
||||
return kernel_source;
|
||||
kernel_source << metal::utils() << metal::copy()
|
||||
<< get_template_definition(
|
||||
"s_" + lib_name, "copy_s", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"v_" + lib_name, "copy_v", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
|
||||
<< get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
@@ -345,17 +321,17 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
|
||||
const std::string& kernel_name,
|
||||
const std::string& func_name,
|
||||
const std::string& op_name,
|
||||
const Dtype& out_type) {
|
||||
const array& out) {
|
||||
auto lib = d.get_library(kernel_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
std::string op_type = op_name;
|
||||
op_type[0] = std::toupper(op_name[0]);
|
||||
auto out_t = get_type_string(out_type);
|
||||
std::string op = op_type + "<" + out_t + ">";
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::reduce_utils();
|
||||
kernel_source += metal::reduce();
|
||||
kernel_source += get_template_definition(kernel_name, func_name, out_t, op);
|
||||
return kernel_source;
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
std::string op = op_type + "<" + out_type + ">";
|
||||
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, func_name, out_type, op);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
@@ -365,31 +341,30 @@ MTL::ComputePipelineState* get_reduce_kernel(
|
||||
const std::string& kernel_name,
|
||||
const std::string& func_name,
|
||||
const std::string& op_name,
|
||||
const Dtype& in_type,
|
||||
const Dtype& out_type,
|
||||
const std::string& idx_t,
|
||||
const array& in,
|
||||
const array& out,
|
||||
int ndim /* = -1 */,
|
||||
int bm /* = -1 */,
|
||||
int bn /* = -1 */) {
|
||||
auto lib = d.get_library(kernel_name, [&]() {
|
||||
std::string op_type = op_name;
|
||||
op_type[0] = std::toupper(op_name[0]);
|
||||
auto in_t = get_type_string(in_type);
|
||||
auto out_t = get_type_string(out_type);
|
||||
std::string op = op_type + "<" + out_t + ">";
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::reduce_utils(), metal::reduce());
|
||||
std::ostringstream kernel_source;
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
std::string op = op_type + "<" + out_type + ">";
|
||||
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
|
||||
if (bm >= 0) {
|
||||
kernel_source += get_template_definition(
|
||||
kernel_name, func_name, in_t, out_t, op, idx_t, ndim, bm, bn);
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, func_name, in_type, out_type, op, ndim, bm, bn);
|
||||
} else if (ndim >= 0) {
|
||||
kernel_source += get_template_definition(
|
||||
kernel_name, func_name, in_t, out_t, op, idx_t, ndim);
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, func_name, in_type, out_type, op, ndim);
|
||||
} else {
|
||||
kernel_source += get_template_definition(
|
||||
kernel_name, func_name, in_t, out_t, op, idx_t);
|
||||
kernel_source << get_template_definition(
|
||||
kernel_name, func_name, in_type, out_type, op);
|
||||
}
|
||||
return kernel_source;
|
||||
return kernel_source.str();
|
||||
});
|
||||
auto st = d.get_kernel(kernel_name, lib);
|
||||
return st;
|
||||
|
@@ -81,16 +81,15 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
|
||||
const std::string& kernel_name,
|
||||
const std::string& func_name,
|
||||
const std::string& op_name,
|
||||
const Dtype& out_type);
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_reduce_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& func_name,
|
||||
const std::string& op_name,
|
||||
const Dtype& in_type,
|
||||
const Dtype& out_type,
|
||||
const std::string& idx_t,
|
||||
const array& in,
|
||||
const array& out,
|
||||
int ndim = -1,
|
||||
int bm = -1,
|
||||
int bn = -1);
|
||||
|
@@ -1,27 +1,13 @@
|
||||
set(BASE_HEADERS
|
||||
metal_3_1/bf16.h
|
||||
metal_3_0/bf16.h
|
||||
bf16_math.h
|
||||
complex.h
|
||||
defines.h
|
||||
expm1f.h
|
||||
utils.h)
|
||||
set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h)
|
||||
|
||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
||||
if(MLX_METAL_DEBUG)
|
||||
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
|
||||
endif()
|
||||
if(MLX_METAL_VERSION GREATER_EQUAL 310)
|
||||
set(VERSION_INCLUDES
|
||||
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
|
||||
else()
|
||||
set(VERSION_INCLUDES
|
||||
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_0)
|
||||
endif()
|
||||
add_custom_command(
|
||||
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR} -I${VERSION_INCLUDES} -o ${TARGET}.air
|
||||
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||
OUTPUT ${TARGET}.air
|
||||
COMMENT "Building ${TARGET}.air"
|
||||
@@ -44,7 +30,9 @@ build_kernel(layer_norm)
|
||||
build_kernel(random)
|
||||
build_kernel(rms_norm)
|
||||
build_kernel(rope)
|
||||
build_kernel(scaled_dot_product_attention sdpa_vector.h)
|
||||
build_kernel(
|
||||
scaled_dot_product_attention scaled_dot_product_attention_params.h
|
||||
sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)
|
||||
|
||||
set(STEEL_HEADERS
|
||||
steel/defines.h
|
||||
@@ -66,24 +54,6 @@ set(STEEL_HEADERS
|
||||
steel/utils/type_traits.h
|
||||
steel/utils/integral_constant.h)
|
||||
|
||||
set(STEEL_ATTN_HEADERS
|
||||
steel/defines.h
|
||||
steel/utils.h
|
||||
steel/gemm/gemm.h
|
||||
steel/gemm/mma.h
|
||||
steel/gemm/loader.h
|
||||
steel/gemm/transforms.h
|
||||
steel/utils/type_traits.h
|
||||
steel/utils/integral_constant.h
|
||||
steel/attn/attn.h
|
||||
steel/attn/loader.h
|
||||
steel/attn/mma.h
|
||||
steel/attn/params.h
|
||||
steel/attn/transforms.h
|
||||
steel/attn/kernels/steel_attention.h)
|
||||
|
||||
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
|
||||
|
||||
if(NOT MLX_METAL_JIT)
|
||||
build_kernel(arange arange.h)
|
||||
build_kernel(binary binary.h binary_ops.h)
|
||||
|
@@ -1,7 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/arange.h"
|
||||
|
||||
#define instantiate_arange(tname, type) \
|
||||
|
@@ -75,10 +75,10 @@ template <typename T, typename Op, int N_READS = 4>
|
||||
const device T* in [[buffer(0)]],
|
||||
device uint32_t* out [[buffer(1)]],
|
||||
const constant int* shape [[buffer(2)]],
|
||||
const constant int64_t* in_strides [[buffer(3)]],
|
||||
const constant int64_t* out_strides [[buffer(4)]],
|
||||
const constant size_t* in_strides [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& ndim [[buffer(5)]],
|
||||
const constant int64_t& axis_stride [[buffer(6)]],
|
||||
const constant size_t& axis_stride [[buffer(6)]],
|
||||
const constant size_t& axis_size [[buffer(7)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
|
@@ -6,6 +6,12 @@
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
|
||||
|
||||
typedef bfloat bfloat16_t;
|
||||
|
||||
#else
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Helpers
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
@@ -305,10 +311,7 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {
|
||||
} // namespace metal
|
||||
|
||||
#pragma METAL internals : disable
|
||||
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
|
||||
return x.bits_;
|
||||
}
|
||||
|
||||
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
|
||||
return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
#endif
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16_math.h"
|
@@ -2,6 +2,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Metal math for bfloat16
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -367,6 +369,18 @@ instantiate_metal_math_funcs(
|
||||
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
||||
}
|
||||
|
||||
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
|
||||
|
||||
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
||||
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
||||
|
||||
#else
|
||||
|
||||
#define bfloat16_to_uint16(x) x.bits_
|
||||
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
|
||||
|
||||
#endif
|
||||
|
||||
namespace metal {
|
||||
|
||||
instantiate_metal_simd_comm_funcs(
|
||||
|
@@ -43,7 +43,7 @@ template <typename T, typename U, typename Op>
|
||||
device U* c,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
c[offset] = Op()(a[0], b[offset]);
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ template <typename T, typename U, typename Op>
|
||||
device U* c,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
c[offset] = Op()(a[offset], b[0]);
|
||||
}
|
||||
|
||||
@@ -65,75 +65,72 @@ template <typename T, typename U, typename Op>
|
||||
device U* c,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
c[offset] = Op()(a[offset], b[offset]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_g_nd1(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const int64_t& a_stride,
|
||||
constant const int64_t& b_stride,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1<IdxT>(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1<IdxT>(index, b_stride);
|
||||
auto a_idx = elem_to_loc_1(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1(index, b_stride);
|
||||
c[index] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_g_nd2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const int64_t a_strides[2],
|
||||
constant const int64_t b_strides[2],
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
|
||||
auto a_idx = elem_to_loc_2(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2(index, b_strides);
|
||||
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_g_nd3(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const int64_t a_strides[3],
|
||||
constant const int64_t b_strides[3],
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx =
|
||||
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N = 1,
|
||||
typename IdxT = int64_t>
|
||||
template <typename T, typename U, typename Op, int N = 1>
|
||||
[[kernel]] void binary_g(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const int* shape,
|
||||
constant const int64_t* a_strides,
|
||||
constant const int64_t* b_strides,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd<IdxT>(
|
||||
auto idx = elem_to_loc_2_nd(
|
||||
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
|
||||
auto xshape = shape[ndim - 1];
|
||||
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
IdxT a_xstride = a_strides[ndim - 1];
|
||||
IdxT b_xstride = b_strides[ndim - 1];
|
||||
size_t out_idx =
|
||||
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto a_xstride = a_strides[ndim - 1];
|
||||
auto b_xstride = b_strides[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
c[out_idx++] = Op()(a[idx.x], b[idx.y]);
|
||||
idx.x += a_xstride;
|
||||
|
@@ -9,22 +9,18 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
|
||||
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
|
||||
|
||||
#define instantiate_binary_integer(op) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||
|
@@ -56,7 +56,7 @@ template <typename T, typename U, typename Op>
|
||||
device U* d,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
auto out = Op()(a[0], b[offset]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
@@ -70,7 +70,7 @@ template <typename T, typename U, typename Op>
|
||||
device U* d,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
auto out = Op()(a[offset], b[0]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
@@ -84,87 +84,84 @@ template <typename T, typename U, typename Op>
|
||||
device U* d,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
auto out = Op()(a[offset], b[offset]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_g_nd1(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant const int64_t& a_stride,
|
||||
constant const int64_t& b_stride,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1<IdxT>(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1<IdxT>(index, b_stride);
|
||||
auto a_idx = elem_to_loc_1(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1(index, b_stride);
|
||||
auto out = Op()(a[a_idx], b[b_idx]);
|
||||
c[index] = out[0];
|
||||
d[index] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_g_nd2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant const int64_t a_strides[2],
|
||||
constant const int64_t b_strides[2],
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
|
||||
auto a_idx = elem_to_loc_2(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2(index, b_strides);
|
||||
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
|
||||
auto out = Op()(a[a_idx], b[b_idx]);
|
||||
c[out_idx] = out[0];
|
||||
d[out_idx] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_g_nd3(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant const int64_t a_strides[3],
|
||||
constant const int64_t b_strides[3],
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx =
|
||||
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto out = Op()(a[a_idx], b[b_idx]);
|
||||
c[out_idx] = out[0];
|
||||
d[out_idx] = out[1];
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N = 1,
|
||||
typename IdxT = int64_t>
|
||||
template <typename T, typename U, typename Op, int N = 1>
|
||||
[[kernel]] void binary_g(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant const int* shape,
|
||||
constant const int64_t* a_strides,
|
||||
constant const int64_t* b_strides,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd<IdxT>(
|
||||
auto idx = elem_to_loc_2_nd(
|
||||
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
|
||||
auto xshape = shape[ndim - 1];
|
||||
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
IdxT a_xstride = a_strides[ndim - 1];
|
||||
IdxT b_xstride = b_strides[ndim - 1];
|
||||
size_t out_idx =
|
||||
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
|
||||
auto a_xstride = a_strides[ndim - 1];
|
||||
auto b_xstride = b_strides[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
auto out = Op()(a[idx.x], b[idx.y]);
|
||||
c[out_idx] = out[0];
|
||||
|
@@ -7,22 +7,18 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary_two.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
|
||||
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
|
||||
|
||||
#define instantiate_binary_float(op) \
|
||||
instantiate_binary_all(op, float16, half, half) \
|
||||
|
@@ -4,8 +4,8 @@
|
||||
#include <metal_simdgroup_matrix>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
|
||||
|
@@ -22,7 +22,7 @@ template <typename T, typename U>
|
||||
device U* dst [[buffer(1)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
dst[offset] = static_cast<U>(src[0]);
|
||||
}
|
||||
|
||||
@@ -32,46 +32,46 @@ template <typename T, typename U>
|
||||
device U* dst [[buffer(1)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
dst[offset] = static_cast<U>(src[offset]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd1(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
dst[index] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
|
||||
IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd3(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
|
||||
IdxT dst_idx =
|
||||
index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
int64_t dst_idx =
|
||||
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
template <typename T, typename U, int N = 1>
|
||||
[[kernel]] void copy_g(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
@@ -80,16 +80,17 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc<IdxT>(
|
||||
auto src_idx = elem_to_loc(
|
||||
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
|
||||
if (N == 1) {
|
||||
IdxT dst_idx =
|
||||
index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
int64_t dst_idx =
|
||||
index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
return;
|
||||
}
|
||||
auto xshape = src_shape[ndim - 1];
|
||||
IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
int64_t dst_idx =
|
||||
N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
|
||||
auto src_xstride = src_strides[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
dst[dst_idx + i] = static_cast<U>(src[src_idx]);
|
||||
@@ -97,43 +98,43 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd1(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
constant const int64_t& dst_stride [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1(index, dst_stride);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint2 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_2(index, dst_strides);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd3(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_3(index, dst_strides);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
template <typename T, typename U, int N = 1>
|
||||
[[kernel]] void copy_gg(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
@@ -142,7 +143,7 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd<IdxT>(
|
||||
auto idx = elem_to_loc_2_nd(
|
||||
{N * index.x, index.y, index.z},
|
||||
src_shape,
|
||||
src_strides,
|
||||
@@ -152,8 +153,8 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
dst[idx.y] = static_cast<U>(src[idx.x]);
|
||||
return;
|
||||
}
|
||||
IdxT src_xstride = src_strides[ndim - 1];
|
||||
IdxT dst_xstride = dst_strides[ndim - 1];
|
||||
auto src_xstride = src_strides[ndim - 1];
|
||||
auto dst_xstride = dst_strides[ndim - 1];
|
||||
auto xshape = src_shape[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
dst[idx.y] = static_cast<U>(src[idx.x]);
|
||||
|
@@ -2,29 +2,22 @@
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/copy.h"
|
||||
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
|
||||
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
|
||||
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
||||
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
||||
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
|
||||
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
|
||||
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
|
||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype, int) \
|
||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
|
||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
|
||||
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
|
||||
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
|
||||
instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
|
||||
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
|
||||
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
|
||||
instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
|
||||
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
|
||||
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
|
||||
instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
|
||||
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
|
||||
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
||||
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
||||
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
|
||||
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \
|
||||
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \
|
||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
|
||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
|
||||
instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
|
||||
instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
|
||||
|
||||
#define instantiate_copy_itype(itname, itype) \
|
||||
instantiate_copy_all(itname ##bool_, itype, bool) \
|
||||
|
@@ -4,30 +4,30 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/indexing.h"
|
||||
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
|
||||
METAL_FUNC void gather_impl(
|
||||
const device T* src [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
const constant int* src_shape [[buffer(2)]],
|
||||
const constant int64_t* src_strides [[buffer(3)]],
|
||||
const constant size_t* src_strides [[buffer(3)]],
|
||||
const constant size_t& src_ndim [[buffer(4)]],
|
||||
const constant int* slice_sizes [[buffer(5)]],
|
||||
const constant int* axes [[buffer(6)]],
|
||||
const thread Indices<IdxT, NIDX>& indices,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
LocT src_idx = 0;
|
||||
size_t src_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
LocT idx_loc;
|
||||
size_t idx_loc;
|
||||
if (IDX_NDIM == 0) {
|
||||
idx_loc = 0;
|
||||
} else if (IDX_NDIM == 1) {
|
||||
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
|
||||
idx_loc = index.x * indices.strides[indices.ndim * i];
|
||||
} else {
|
||||
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
|
||||
idx_loc = index.x * indices.strides[indices.ndim * i];
|
||||
idx_loc += indices.row_contiguous[i]
|
||||
? index.y
|
||||
: elem_to_loc<LocT>(
|
||||
: elem_to_loc(
|
||||
index.y,
|
||||
&indices.shapes[indices.ndim * i + 1],
|
||||
&indices.strides[indices.ndim * i + 1],
|
||||
@@ -35,17 +35,17 @@ METAL_FUNC void gather_impl(
|
||||
}
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
|
||||
src_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(src_strides[ax]);
|
||||
src_idx += idx_val * src_strides[ax];
|
||||
}
|
||||
|
||||
auto src_offset =
|
||||
elem_to_loc<LocT>(index.z, slice_sizes, src_strides, src_ndim);
|
||||
auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim);
|
||||
|
||||
LocT out_idx = index.z;
|
||||
size_t out_idx = index.z;
|
||||
if (IDX_NDIM == 1) {
|
||||
out_idx += static_cast<LocT>(grid_dim.z) * index.x;
|
||||
out_idx += static_cast<size_t>(grid_dim.z) * index.x;
|
||||
} else if (IDX_NDIM >= 2) {
|
||||
out_idx += grid_dim.z * (index.x * static_cast<LocT>(grid_dim.y) + index.y);
|
||||
out_idx +=
|
||||
grid_dim.z * (index.x * static_cast<size_t>(grid_dim.y) + index.y);
|
||||
}
|
||||
out[out_idx] = src[src_offset + src_idx];
|
||||
}
|
||||
|
@@ -3,6 +3,8 @@
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||
@@ -436,9 +438,9 @@ template <
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant int64_t* bias_batch_stride [[buffer(13)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]],
|
||||
const constant int& bias_stride [[buffer(14)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
@@ -486,21 +488,31 @@ template <
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
#define instantiate_gemv_helper( \
|
||||
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
|
||||
instantiate_kernel( \
|
||||
"gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
|
||||
"_tn" #tn "_nc" #nc "_axpby" #axpby, \
|
||||
gemv, \
|
||||
itype, \
|
||||
bm, \
|
||||
bn, \
|
||||
sm, \
|
||||
sn, \
|
||||
tm, \
|
||||
tn, \
|
||||
nc, \
|
||||
axpby)
|
||||
#define instantiate_gemv_helper( \
|
||||
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
|
||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
|
||||
"_tm" #tm "_tn" #tn "_nc" #nc \
|
||||
"_axpby" #axpby)]] [[kernel]] void \
|
||||
gemv<itype, bm, bn, sm, sn, tm, tn, nc, axpby>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]], \
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]], \
|
||||
const constant int& bias_stride [[buffer(14)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
|
||||
@@ -539,13 +551,13 @@ template <
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant int64_t* index_batch_strides [[buffer(11)]],
|
||||
const constant size_t* index_batch_strides [[buffer(11)]],
|
||||
const constant int& vector_batch_ndim [[buffer(12)]],
|
||||
const constant int* vector_batch_shape [[buffer(13)]],
|
||||
const constant int64_t* vector_batch_stride [[buffer(14)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(14)]],
|
||||
const constant int& matrix_batch_ndim [[buffer(15)]],
|
||||
const constant int* matrix_batch_shape [[buffer(16)]],
|
||||
const constant int64_t* matrix_batch_stride [[buffer(17)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(17)]],
|
||||
const constant uint32_t* vec_indices [[buffer(18)]],
|
||||
const constant uint32_t* mat_indices [[buffer(19)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -561,8 +573,8 @@ template <
|
||||
|
||||
// Update batch offsets
|
||||
if (batch_ndim > 1) {
|
||||
const constant auto* veci_bstrides = index_batch_strides;
|
||||
const constant auto* mati_bstrides = index_batch_strides + batch_ndim;
|
||||
const constant size_t* veci_bstrides = index_batch_strides;
|
||||
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
|
||||
@@ -609,14 +621,37 @@ template <
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
|
||||
instantiate_kernel( \
|
||||
"gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
|
||||
"_sn" #sn "_tm" #tm "_tn" #tn, \
|
||||
gemv_gather, itype, bm, bn, sm, sn, tm, tn)
|
||||
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
|
||||
template [[host_name("gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
|
||||
"_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \
|
||||
gemv_gather<itype, bm, bn, sm, sn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* index_batch_strides [[buffer(11)]], \
|
||||
const constant int& vector_batch_ndim [[buffer(12)]], \
|
||||
const constant int* vector_batch_shape [[buffer(13)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(14)]], \
|
||||
const constant int& matrix_batch_ndim [[buffer(15)]], \
|
||||
const constant int* matrix_batch_shape [[buffer(16)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(17)]], \
|
||||
const constant uint32_t* vec_indices [[buffer(18)]], \
|
||||
const constant uint32_t* mat_indices [[buffer(19)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv_bs_blocks(name, itype) \
|
||||
// clang-format off
|
||||
#define instantiate_gemv_bs_blocks(name, itype) \
|
||||
instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \
|
||||
instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \
|
||||
instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on
|
||||
@@ -651,9 +686,9 @@ template <
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant int64_t* bias_batch_stride [[buffer(13)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]],
|
||||
const constant int& bias_stride [[buffer(14)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
@@ -701,14 +736,33 @@ template <
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_t_helper( \
|
||||
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
|
||||
instantiate_kernel( \
|
||||
"gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
|
||||
"_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby, \
|
||||
gemv_t, itype, bm, bn, sm, sn, tm, tn, nc, axpby)
|
||||
#define instantiate_gemv_t_helper( \
|
||||
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
|
||||
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
|
||||
"_tm" #tm "_tn" #tn "_nc" #nc \
|
||||
"_axpby" #axpby)]] [[kernel]] void \
|
||||
gemv_t<itype, bm, bn, sm, sn, tm, tn, nc, axpby>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]], \
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]], \
|
||||
const constant int& bias_stride [[buffer(14)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
|
||||
@@ -748,13 +802,13 @@ template <
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant int64_t* index_batch_strides [[buffer(11)]],
|
||||
const constant size_t* index_batch_strides [[buffer(11)]],
|
||||
const constant int& vector_batch_ndim [[buffer(12)]],
|
||||
const constant int* vector_batch_shape [[buffer(13)]],
|
||||
const constant int64_t* vector_batch_stride [[buffer(14)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(14)]],
|
||||
const constant int& matrix_batch_ndim [[buffer(15)]],
|
||||
const constant int* matrix_batch_shape [[buffer(16)]],
|
||||
const constant int64_t* matrix_batch_stride [[buffer(17)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(17)]],
|
||||
const constant uint32_t* vec_indices [[buffer(18)]],
|
||||
const constant uint32_t* mat_indices [[buffer(19)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -770,8 +824,8 @@ template <
|
||||
|
||||
// Update batch offsets
|
||||
if (batch_ndim > 1) {
|
||||
const constant auto* veci_bstrides = index_batch_strides;
|
||||
const constant auto* mati_bstrides = index_batch_strides + batch_ndim;
|
||||
const constant size_t* veci_bstrides = index_batch_strides;
|
||||
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
|
||||
@@ -818,14 +872,36 @@ template <
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_t_bs_helper( \
|
||||
nm, itype, bm, bn, sm, sn, tm, tn) \
|
||||
instantiate_kernel( \
|
||||
"gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
|
||||
"_sn" #sn "_tm" #tm "_tn" #tn, \
|
||||
gemv_t_gather, itype, bm, bn, sm, sn, tm, tn)
|
||||
#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
|
||||
template [[host_name("gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
|
||||
"_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \
|
||||
gemv_t_gather<itype, bm, bn, sm, sn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* index_batch_strides [[buffer(11)]], \
|
||||
const constant int& vector_batch_ndim [[buffer(12)]], \
|
||||
const constant int* vector_batch_shape [[buffer(13)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(14)]], \
|
||||
const constant int& matrix_batch_ndim [[buffer(15)]], \
|
||||
const constant int* matrix_batch_shape [[buffer(16)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(17)]], \
|
||||
const constant uint32_t* vec_indices [[buffer(18)]], \
|
||||
const constant uint32_t* mat_indices [[buffer(19)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_t_bs_blocks(name, itype) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 1) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 4) \
|
||||
@@ -836,4 +912,4 @@ template <
|
||||
// clang-format off
|
||||
instantiate_gemv_t_bs_blocks(float32, float);
|
||||
instantiate_gemv_t_bs_blocks(float16, half);
|
||||
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
|
||||
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
|
@@ -642,13 +642,13 @@ template <
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]],
|
||||
const device out_mask_t* out_mask [[buffer(20)]],
|
||||
const device op_mask_t* mat_mask [[buffer(21)]],
|
||||
const device op_mask_t* vec_mask [[buffer(22)]],
|
||||
const constant int* mask_strides [[buffer(23)]],
|
||||
const constant int64_t* mask_batch_strides [[buffer(24)]],
|
||||
const constant size_t* mask_batch_strides [[buffer(24)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
@@ -673,8 +673,8 @@ template <
|
||||
}
|
||||
|
||||
if (has_operand_mask) {
|
||||
const constant auto* mask_strides_mat = mask_batch_strides;
|
||||
const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
|
||||
const constant size_t* mask_strides_mat = mask_batch_strides;
|
||||
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
|
||||
@@ -742,13 +742,13 @@ template <
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]],
|
||||
const device out_mask_t* out_mask [[buffer(20)]],
|
||||
const device op_mask_t* mat_mask [[buffer(21)]],
|
||||
const device op_mask_t* vec_mask [[buffer(22)]],
|
||||
const constant int* mask_strides [[buffer(23)]],
|
||||
const constant int64_t* mask_batch_strides [[buffer(24)]],
|
||||
const constant size_t* mask_batch_strides [[buffer(24)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
@@ -773,8 +773,8 @@ template <
|
||||
}
|
||||
|
||||
if (has_operand_mask) {
|
||||
const constant auto* mask_strides_mat = mask_batch_strides;
|
||||
const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
|
||||
const constant size_t* mask_strides_mat = mask_batch_strides;
|
||||
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
|
||||
|
@@ -4,17 +4,37 @@
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/gemv_masked.h"
|
||||
|
||||
#define instantiate_gemv_helper( \
|
||||
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
|
||||
instantiate_kernel( \
|
||||
"gemv_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
|
||||
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
|
||||
"_tn" #tn "_nc" #nc, \
|
||||
gemv_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc)
|
||||
template [[host_name("gemv_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
|
||||
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
|
||||
"_tn" #tn "_nc" #nc)]] [[kernel]] void \
|
||||
gemv_masked<itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]], \
|
||||
const device outm_t* out_mask [[buffer(20)]], \
|
||||
const device opm_t* mat_mask [[buffer(21)]], \
|
||||
const device opm_t* vec_mask [[buffer(22)]], \
|
||||
const constant int* mask_strides [[buffer(23)]], \
|
||||
const constant size_t* mask_batch_strides [[buffer(24)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
|
||||
instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
|
||||
@@ -43,11 +63,29 @@ instantiate_gemv_blocks(bfloat16, bfloat16_t);
|
||||
|
||||
#define instantiate_gemv_t_helper( \
|
||||
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
|
||||
instantiate_kernel( \
|
||||
"gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
|
||||
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
|
||||
"_tn" #tn "_nc" #nc, \
|
||||
gemv_t_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc)
|
||||
template [[host_name("gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
|
||||
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
|
||||
"_tn" #tn "_nc" #nc)]] [[kernel]] void \
|
||||
gemv_t_masked<itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]], \
|
||||
const device outm_t* out_mask [[buffer(20)]], \
|
||||
const device opm_t* mat_mask [[buffer(21)]], \
|
||||
const device opm_t* vec_mask [[buffer(22)]], \
|
||||
const constant int* mask_strides [[buffer(23)]], \
|
||||
const constant size_t* mask_batch_strides [[buffer(24)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
|
||||
instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
|
||||
|
@@ -8,13 +8,13 @@ template <typename IdxT, int NIDX>
|
||||
struct Indices {
|
||||
const array<const device IdxT*, NIDX> buffers;
|
||||
const constant int* shapes;
|
||||
const constant int64_t* strides;
|
||||
const constant size_t* strides;
|
||||
const constant bool* row_contiguous;
|
||||
const int ndim;
|
||||
};
|
||||
|
||||
template <typename IdxT>
|
||||
METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
|
||||
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
|
||||
if (is_unsigned_v<IdxT>) {
|
||||
return idx;
|
||||
} else {
|
||||
|
@@ -1,16 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#define jit_if #if
|
||||
#define jit_else #else
|
||||
#define jit_endif #endif
|
||||
|
||||
jit_if (__METAL_VERSION__ >= 310)
|
||||
|
||||
#include "mlx/backend/metal/kernels/metal_3_1/bf16.h"
|
||||
|
||||
jit_else
|
||||
|
||||
#include "mlx/backend/metal/kernels/metal_3_0/bf16.h"
|
||||
|
||||
jit_endif // clang-format on
|
@@ -3,6 +3,8 @@
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
@@ -1,16 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
typedef bfloat bfloat16_t;
|
||||
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
|
||||
return as_type<uint16_t>(x);
|
||||
}
|
||||
|
||||
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
|
||||
return as_type<bfloat16_t>(x);
|
||||
}
|
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user