Compare commits

..

12 Commits

Author SHA1 Message Date
Angelos Katharopoulos
d82699f0f1 Merge branch 'distributed-layers' into socket-distributed-layers 2024-11-05 11:36:16 -08:00
Angelos Katharopoulos
6fc00d2c10 Add rudimentary barrier 2024-11-05 11:34:55 -08:00
Angelos Katharopoulos
44f0de2854 Fix run without distributed 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
29ec3539ed TCP socket distributed 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
e94f0028c3 Change the send message size 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
e5354fcddb Make it work even for donated inputs 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
34dd079a64 Start a sockets based distributed backend 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
16975815e9 Fixes in distributed layers 2024-11-05 11:27:26 -08:00
Angelos Katharopoulos
a8b3da7946 Add distributed layers to nn top-level 2024-11-05 11:27:26 -08:00
Angelos Katharopoulos
060e1c9f92 Add quantized distributed layers 2024-11-05 11:27:26 -08:00
Angelos Katharopoulos
0b04742985 Add the distributed linear layers 2024-11-05 11:27:26 -08:00
Angelos Katharopoulos
c3ccd4919f Add MPI barrier 2024-11-05 11:26:53 -08:00
250 changed files with 8355 additions and 11728 deletions

3
.gitignore vendored
View File

@@ -76,9 +76,6 @@ build/
*.out
*.app
# Debug symbols
*.pdb
# VSCode
.vscode/
.DS_Store

View File

@@ -1,14 +1,13 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.4
rev: v18.1.8
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.10.0
rev: 24.8.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:

View File

@@ -20,12 +20,11 @@ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.21.1)
set(MLX_VERSION 0.19.3)
endif()
# --------------------- Processor tests -------------------------
@@ -35,6 +34,8 @@ message(
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
)
set(MLX_BUILD_ARM OFF)
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
if(NOT MLX_ENABLE_X64_MAC)
@@ -56,6 +57,10 @@ else()
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif()
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif()
# ----------------------------- Lib -----------------------------
include(FetchContent)
@@ -84,26 +89,25 @@ elseif(MLX_BUILD_METAL)
# Throw an error if xcrun not found
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
if(${MACOS_VERSION} LESS 14.0)
message(
FATAL_ERROR
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
endif()
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
set(METAL_CPP_URL
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
)
# Get the metal version
execute_process(
COMMAND
zsh "-c"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
@@ -111,54 +115,20 @@ elseif(MLX_BUILD_METAL)
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
$<INSTALL_INTERFACE:include/metal_cpp>)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
set(MLX_BUILD_GGUF OFF)
# There is no prebuilt OpenBLAS distribution for MSVC.
set(MLX_BUILD_BLAS_FROM_SOURCE ON)
endif()
# Windows implementation of dlfcn.h APIs.
FetchContent_Declare(
dlfcn-win32
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
GIT_TAG v1.4.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(dlfcn-win32)
target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
target_link_libraries(mlx PRIVATE dl)
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
endif()
if(MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate)
if(ACCELERATE_LIBRARY)
if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
endif()
if(MLX_BUILD_ACCELERATE)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
elseif(MLX_BUILD_BLAS_FROM_SOURCE)
# Download and build OpenBLAS from source code.
FetchContent_Declare(
openblas
GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
GIT_TAG v0.3.28
EXCLUDE_FROM_ALL)
set(BUILD_STATIC_LIBS ON) # link statically
set(NOFORTRAN ON) # msvc has no fortran compiler
FetchContent_MakeAvailable(openblas)
target_link_libraries(mlx PRIVATE openblas)
target_include_directories(
mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
"${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
else()
if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
@@ -176,7 +146,7 @@ if(MLX_BUILD_CPU)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old
# version of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
@@ -189,7 +159,7 @@ if(MLX_BUILD_CPU)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
endif()
else()
set(MLX_BUILD_ACCELERATE OFF)

View File

@@ -5,35 +5,35 @@
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
using namespace mlx::core;
void time_value_and_grad() {
auto x = mx::ones({200, 1000});
mx::eval(x);
auto fn = [](mx::array x) {
auto x = ones({200, 1000});
eval(x);
auto fn = [](array x) {
for (int i = 0; i < 20; ++i) {
x = mx::log(mx::exp(x));
x = log(exp(x));
}
return mx::sum(x);
return sum(x);
};
auto grad_fn = mx::grad(fn);
auto grad_fn = grad(fn);
auto independent_value_and_grad = [&]() {
auto value = fn(x);
auto dfdx = grad_fn(x);
return std::vector<mx::array>{value, dfdx};
return std::vector<array>{value, dfdx};
};
TIME(independent_value_and_grad);
auto value_and_grad_fn = mx::value_and_grad(fn);
auto value_and_grad_fn = value_and_grad(fn);
auto combined_value_and_grad = [&]() {
auto [value, dfdx] = value_and_grad_fn(x);
return std::vector<mx::array>{value, dfdx};
return std::vector<array>{value, dfdx};
};
TIME(combined_value_and_grad);
}
int main() {
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
std::cout << "Benchmarks for " << default_device() << std::endl;
time_value_and_grad();
}

View File

@@ -4,21 +4,21 @@
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
using namespace mlx::core;
void time_add_op() {
std::vector<int> sizes(1, 1);
for (int i = 0; i < 9; ++i) {
sizes.push_back(10 * sizes.back());
}
set_default_device(mx::Device::cpu);
set_default_device(Device::cpu);
for (auto size : sizes) {
auto a = mx::random::uniform({size});
auto b = mx::random::uniform({size});
mx::eval(a, b);
auto a = random::uniform({size});
auto b = random::uniform({size});
eval(a, b);
std::cout << "Size " << size << std::endl;
TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
TIMEM("cpu", add, a, b, Device::cpu);
TIMEM("gpu", add, a, b, Device::gpu);
}
}

View File

@@ -6,105 +6,105 @@
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
using namespace mlx::core;
void time_irregular_binary_ops_1D() {
auto device = mx::default_device();
auto device = default_device();
int size = 1000000;
int step = 2;
auto a = mx::random::uniform({size});
auto b = mx::random::uniform({size});
mx::eval(a, b);
auto a = random::uniform({size});
auto b = random::uniform({size});
eval(a, b);
a = slice(a, {0}, {size}, {step});
b = slice(b, {0}, {size}, {step});
TIMEM("1D strided", mx::add, a, b, device);
TIMEM("1D strided", add, a, b, device);
}
void time_irregular_binary_ops_2D() {
auto device = mx::default_device();
auto device = default_device();
int size = 2048;
auto a = mx::random::uniform({size, size});
auto b = mx::random::uniform({size, size});
mx::eval(a, b);
TIMEM("2D regular", mx::add, a, b, device);
auto a = random::uniform({size, size});
auto b = random::uniform({size, size});
eval(a, b);
TIMEM("2D regular", add, a, b, device);
b = mx::transpose(b);
mx::eval(b);
TIMEM("2D mx::transpose", mx::add, a, b, device);
b = transpose(b);
eval(b);
TIMEM("2D transpose", add, a, b, device);
b = mx::random::uniform({size});
mx::eval(b);
TIMEM("2D broadcast dim 0", mx::add, a, b, device);
b = random::uniform({size});
eval(b);
TIMEM("2D broadcast dim 0", add, a, b, device);
b = mx::reshape(b, {size, 1});
mx::eval(b);
TIMEM("2D broadcast dim 1", mx::add, a, b, device);
b = reshape(b, {size, 1});
eval(b);
TIMEM("2D broadcast dim 1", add, a, b, device);
}
void time_irregular_binary_ops_3D() {
auto device = mx::default_device();
auto device = default_device();
int d0 = 32;
int d1 = 512;
int d2 = 512;
auto a = mx::random::uniform({d0, d1, d2});
auto b = mx::random::uniform({d0, d1, d2});
TIMEM("3D regular", mx::add, a, b, device);
auto a = random::uniform({d0, d1, d2});
auto b = random::uniform({d0, d1, d2});
TIMEM("3D regular", add, a, b, device);
b = mx::transpose(b, {0, 2, 1});
TIMEM("3D mx::transpose", mx::add, a, b, device);
b = transpose(b, {0, 2, 1});
TIMEM("3D transpose", add, a, b, device);
b = mx::random::uniform({d1, d2});
TIMEM("3D broadcast dim 0", mx::add, a, b, device);
b = random::uniform({d1, d2});
TIMEM("3D broadcast dim 0", add, a, b, device);
b = mx::random::uniform({d0, 1, d2});
TIMEM("3D broadcast dim 1", mx::add, a, b, device);
b = random::uniform({d0, 1, d2});
TIMEM("3D broadcast dim 1", add, a, b, device);
b = mx::random::uniform({d0, d1, 1});
TIMEM("3D broadcast dim 2", mx::add, a, b, device);
b = random::uniform({d0, d1, 1});
TIMEM("3D broadcast dim 2", add, a, b, device);
b = mx::random::uniform({d2});
TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
b = random::uniform({d2});
TIMEM("3D broadcast dims 0, 1", add, a, b, device);
b = mx::random::uniform({d1, 1});
TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
b = random::uniform({d1, 1});
TIMEM("3D broadcast dims 0, 2", add, a, b, device);
b = mx::random::uniform({d0, 1, 1});
TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
b = random::uniform({d0, 1, 1});
TIMEM("3D broadcast dims 1, 2", add, a, b, device);
}
void time_irregular_binary_ops_4D() {
auto device = mx::default_device();
auto device = default_device();
std::vector<int> shape = {8, 8, 512, 512};
auto a = mx::random::uniform(shape);
auto b = mx::random::uniform(shape);
auto a = random::uniform(shape);
auto b = random::uniform(shape);
TIMEM("4D regular", mx::add, a, b, device);
TIMEM("4D regular", add, a, b, device);
b = mx::transpose(b, {0, 1, 3, 2});
TIMEM("4D mx::transpose", mx::add, a, b, device);
b = transpose(b, {0, 1, 3, 2});
TIMEM("4D transpose", add, a, b, device);
std::string om = "4D broadcast dims ";
for (int i = 0; i < shape.size(); ++i) {
shape[i] = 1;
b = mx::random::uniform(shape);
b = random::uniform(shape);
std::ostringstream msg;
msg << om << i;
TIMEM(msg.str(), mx::add, a, b, device);
TIMEM(msg.str(), add, a, b, device);
for (int j = i + 1; j < shape.size(); ++j) {
shape[j] = 1;
std::ostringstream msg;
msg << om << i << ", " << j;
b = mx::random::uniform(shape);
TIMEM(msg.str(), mx::add, a, b, device);
b = random::uniform(shape);
TIMEM(msg.str(), add, a, b, device);
shape[j] = a.shape(j);
for (int k = j + 1; k < shape.size(); ++k) {
shape[k] = 1;
std::ostringstream msg;
msg << om << i << ", " << j << ", " << k;
b = mx::random::uniform(shape);
TIMEM(msg.str(), mx::add, a, b, device);
b = random::uniform(shape);
TIMEM(msg.str(), add, a, b, device);
shape[k] = a.shape(k);
}
}
@@ -113,83 +113,83 @@ void time_irregular_binary_ops_4D() {
}
void time_irregular_reshape() {
auto device = mx::default_device();
auto device = default_device();
std::vector<int> shape;
auto reshape_fn = [&shape, device](const mx::array& a) {
return mx::reshape(a, shape, device);
auto reshape_fn = [&shape, device](const array& a) {
return reshape(a, shape, device);
};
int size = 64;
int d = 2 * size;
auto a = mx::random::uniform({d, d, d});
auto a = random::uniform({d, d, d});
shape = {8 * size, size, size};
TIMEM("3D contiguous", reshape_fn, a);
a = mx::transpose(a);
a = transpose(a);
shape = {8 * size, size, size};
TIMEM("3D mx::transpose", reshape_fn, a);
TIMEM("3D transpose", reshape_fn, a);
a = mx::transpose(a, {1, 2, 0});
a = transpose(a, {1, 2, 0});
shape = {8 * size, size, size};
TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
TIMEM("3D transpose dims 1 2", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
a = broadcast_to(random::uniform({d, d}), {d, d, d});
TIMEM("3D broadcast dim 0", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
TIMEM("3D broadcast dim 1", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
TIMEM("3D broadcast dim 2", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
a = broadcast_to(random::uniform({d}), {d, d, d});
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
a = broadcast_to(random::uniform({d, 1}), {d, d, d});
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
}
void time_irregular_astype_1D() {
auto device = mx::default_device();
auto device = default_device();
int size = 1000000;
int step = 2;
auto a = mx::random::uniform({size});
auto a = random::uniform({size});
a = slice(a, {0}, {size}, {step});
TIMEM("1D strided", mx::astype, a, mx::int32, device);
TIMEM("1D strided", astype, a, int32, device);
}
void time_irregular_astype_2D() {
auto device = mx::default_device();
auto device = default_device();
int size = 2048;
std::vector<int> shape = {size, size};
auto a = mx::random::uniform(shape);
TIMEM("2D regular", mx::astype, a, mx::int32, device);
auto a = random::uniform(shape);
TIMEM("2D regular", astype, a, int32, device);
a = mx::transpose(a);
TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
a = transpose(a);
TIMEM("2D transpose", astype, a, int32, device);
a = mx::broadcast_to(mx::random::uniform({size}), shape);
TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
a = broadcast_to(random::uniform({size}), shape);
TIMEM("2D broadcast dim 0", astype, a, int32, device);
a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
a = broadcast_to(random::uniform({size, 1}), shape);
TIMEM("2D broadcast dim 1", astype, a, int32, device);
}
int main(int argc, char** argv) {
if (argc > 1) {
bool use_gpu = !strcmp(argv[1], "gpu");
set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
set_default_device(use_gpu ? Device::gpu : Device::cpu);
}
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
std::cout << "Benchmarks for " << default_device() << std::endl;
time_irregular_binary_ops_1D();
time_irregular_binary_ops_2D();
time_irregular_binary_ops_3D();

View File

@@ -3,20 +3,20 @@
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
using namespace mlx::core;
void time_creation_ops() {
int M = 2000;
int N = 500;
auto shape = {M, N};
auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
auto full_fp32 = [&]() { return full(shape, 3.3f); };
TIME(full_fp32);
auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
TIME(zeros_fp32);
auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
auto ones_fp32 = [&]() { return ones(shape, float32); };
TIME(ones_fp32);
auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
TIME(arange_fp32);
}
@@ -24,196 +24,194 @@ void time_type_conversions() {
int M = 2000;
int N = 500;
auto shape = {M, N};
auto device = mx::default_device();
auto device = default_device();
auto a = mx::zeros(shape, mx::float32);
mx::eval(a);
TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
auto a = zeros(shape, float32);
eval(a);
TIMEM("float32 to int32", astype, a, int32, device);
TIMEM("float32 to uint32", astype, a, uint32, device);
a = mx::zeros(shape, mx::int32);
mx::eval(a);
TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
a = zeros(shape, int32);
eval(a);
TIMEM("int32 to float32", astype, a, float32, device);
a = mx::zeros(shape, mx::bool_);
mx::eval(a);
TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
a = zeros(shape, bool_);
eval(a);
TIMEM("bool to float32", astype, a, float32, device);
TIMEM("bool to int32", astype, a, int32, device);
TIMEM("bool to uint32", astype, a, uint32, device);
}
void time_random_generation() {
int M = 2000;
int N = 500;
auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
auto uniform = [&]() { return random::uniform({M, N}, float32); };
TIME(uniform);
auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
auto normal = [&]() { return random::normal({M, N}, float32); };
TIME(normal);
}
void time_unary_ops() {
int M = 2000;
int N = 500;
auto device = mx::default_device();
auto device = default_device();
auto a = mx::random::normal({M, N});
mx::eval(a);
auto a = random::normal({M, N});
eval(a);
TIME(mlx::core::abs, a, device);
TIME(mx::negative, a, device);
TIME(mx::sign, a, device);
TIME(mx::square, a, device);
TIME(negative, a, device);
TIME(sign, a, device);
TIME(square, a, device);
TIME(mlx::core::sqrt, a, device);
TIME(mx::rsqrt, a, device);
TIME(rsqrt, a, device);
TIME(mlx::core::exp, a, device);
a = mx::random::uniform({M, N});
a = random::uniform({M, N});
TIME(mlx::core::log, a, device);
}
void time_binary_ops() {
int M = 1000, N = 100, K = 10;
auto condition = mx::random::randint(0, 2, {M, N, K});
auto a = mx::random::uniform({M, N, K});
auto b = mx::random::uniform({M, N, K});
auto device = mx::default_device();
mx::eval(a, b);
auto condition = random::randint(0, 2, {M, N, K});
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
eval(a, b);
TIME(mx::add, a, b, device);
TIME(mx::subtract, a, b, device);
TIME(mx::multiply, a, b, device);
TIME(mx::divide, a, b, device);
TIME(mx::maximum, a, b, device);
TIME(mx::minimum, a, b, device);
TIME(mx::where, condition, a, b, device);
TIME(add, a, b, device);
TIME(subtract, a, b, device);
TIME(multiply, a, b, device);
TIME(divide, a, b, device);
TIME(maximum, a, b, device);
TIME(minimum, a, b, device);
TIME(where, condition, a, b, device);
condition = mx::array({true});
b = mx::random::uniform({1});
mx::eval(b);
TIMEM("scalar", mx::add, a, b, device);
TIMEM("vector-scalar", mx::subtract, a, b, device);
TIMEM("scalar-vector", mx::subtract, b, a, device);
TIMEM("scalar", mx::multiply, a, b, device);
TIMEM("vector-scalar", mx::divide, a, b, device);
TIMEM("scalar-vector", mx::divide, b, a, device);
TIMEM("scalar-vector", mx::where, condition, a, b, device);
condition = array({true});
b = random::uniform({1});
eval(b);
TIMEM("scalar", add, a, b, device);
TIMEM("vector-scalar", subtract, a, b, device);
TIMEM("scalar-vector", subtract, b, a, device);
TIMEM("scalar", multiply, a, b, device);
TIMEM("vector-scalar", divide, a, b, device);
TIMEM("scalar-vector", divide, b, a, device);
TIMEM("scalar-vector", where, condition, a, b, device);
condition = mx::broadcast_to(mx::array({true}), {1000, 100});
a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
mx::eval(a, b);
TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
condition = broadcast_to(array({true}), {1000, 100});
a = broadcast_to(random::uniform({1}), {1000, 100});
b = broadcast_to(random::uniform({1}), {1000, 100});
eval(a, b);
TIMEM("scalar-scalar broadcast", add, a, b, device);
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
TIMEM("scalar-scalar broadcast", divide, a, b, device);
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
}
void time_strided_ops() {
int M = 50, N = 50, O = 50, P = 50;
auto a = mx::random::uniform({M, N, O, P});
auto b = mx::random::uniform({M, N, O, P});
auto device = mx::default_device();
mx::eval(a, b);
TIMEM("non-strided", mx::add, a, b, device);
a = mx::transpose(a, {1, 0, 2, 3});
b = mx::transpose(b, {3, 2, 0, 1});
mx::eval(a, b);
TIMEM("strided", mx::add, a, b, device);
auto a = random::uniform({M, N, O, P});
auto b = random::uniform({M, N, O, P});
auto device = default_device();
eval(a, b);
TIMEM("non-strided", add, a, b, device);
a = transpose(a, {1, 0, 2, 3});
b = transpose(b, {3, 2, 0, 1});
eval(a, b);
TIMEM("strided", add, a, b, device);
}
void time_comparisons() {
int M = 1000, N = 100, K = 10;
auto a = mx::random::uniform({M, N, K});
auto b = mx::random::uniform({M, N, K});
auto device = mx::default_device();
mx::eval(a, b);
TIME(mx::equal, a, b, device);
TIME(mx::greater, a, b, device);
TIME(mx::greater_equal, a, b, device);
TIME(mx::less, a, b, device);
TIME(mx::less_equal, a, b, device);
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
eval(a, b);
TIME(equal, a, b, device);
TIME(greater, a, b, device);
TIME(greater_equal, a, b, device);
TIME(less, a, b, device);
TIME(less_equal, a, b, device);
}
void time_matvec() {
int M = 2000, N = 200;
auto a = mx::random::uniform({M, N});
auto b = mx::random::uniform({N});
auto c = mx::random::uniform({M});
mx::eval(a, b, c);
auto matvec = [&]() { return mx::matmul(a, b); };
auto a = random::uniform({M, N});
auto b = random::uniform({N});
auto c = random::uniform({M});
eval(a, b, c);
auto matvec = [&]() { return matmul(a, b); };
TIME(matvec);
auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
TIME(matvec_transpose);
}
void time_matmul() {
int M = 1000, N = 1000, K = 1000;
auto a = mx::random::uniform({M, K});
auto b = mx::random::uniform({K, N});
auto device = mx::default_device();
mx::eval(a, b);
TIME(mx::matmul, a, b, device);
auto a = random::uniform({M, K});
auto b = random::uniform({K, N});
auto device = default_device();
eval(a, b);
TIME(matmul, a, b, device);
auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
TIME(transpose_matmul);
}
void time_reductions() {
auto a = mx::random::normal({10000, 1000});
mx::eval(a);
auto sum_all = [&a]() { return mx::sum(a, false); };
auto a = random::normal({10000, 1000});
eval(a);
auto sum_all = [&a]() { return sum(a, false); };
TIME(sum_all);
auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
auto sum_along_0 = [&a]() { return sum(a, 0, false); };
TIME(sum_along_0);
auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
auto sum_along_1 = [&a]() { return sum(a, 1, false); };
TIME(sum_along_1);
auto prod_all = [&a]() { return mx::prod(a, false); };
auto prod_all = [&a]() { return prod(a, false); };
TIME(prod_all);
auto all_true = [&a]() { return mx::all(a, false); };
auto all_true = [&a]() { return all(a, false); };
TIME(all_true);
auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
auto all_along_0 = [&a]() { return all(a, 0, false); };
TIME(all_along_0);
auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
auto all_along_1 = [&a]() { return all(a, 1, false); };
TIME(all_along_1);
auto any_true = [&a]() { return mx::any(a, false); };
auto any_true = [&a]() { return any(a, false); };
TIME(any_true);
auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
TIME(argmin_along_0);
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
TIME(argmin_along_1);
}
void time_gather_scatter() {
auto a = mx::random::normal({1000, 768});
mx::eval(a);
auto indices = mx::random::randint(0, 1000, {256});
mx::eval(indices);
auto a = random::normal({1000, 768});
eval(a);
auto indices = random::randint(0, 1000, {256});
eval(indices);
auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
TIME(embedding_lookup);
indices = mx::random::randint(0, 768 * 1000, {256 * 768});
mx::eval(indices);
indices = random::randint(0, 768 * 1000, {256 * 768});
eval(indices);
auto single_element_lookup = [&a, &indices]() {
return mx::take(a, indices);
};
auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
TIME(single_element_lookup);
indices = mx::random::randint(0, 1000, {256});
auto updates = mx::random::normal({256, 1, 768});
mx::eval(indices, updates);
indices = random::randint(0, 1000, {256});
auto updates = random::normal({256, 1, 768});
eval(indices, updates);
auto embedding_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0);
@@ -225,10 +223,10 @@ void time_gather_scatter() {
};
TIME(embedding_add);
a = mx::reshape(a, {-1});
indices = mx::random::randint(0, 768 * 1000, {768 * 256});
updates = mx::random::normal({256 * 768, 1});
mx::eval(a, indices, updates);
a = reshape(a, {-1});
indices = random::randint(0, 768 * 1000, {768 * 256});
updates = random::normal({256 * 768, 1});
eval(a, indices, updates);
auto single_element_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0);
@@ -242,21 +240,21 @@ void time_gather_scatter() {
}
void time_divmod() {
auto a = mx::random::normal({1000});
auto b = mx::random::normal({1000});
mx::eval({a, b});
auto a = random::normal({1000});
auto b = random::normal({1000});
eval({a, b});
auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
auto divmod_fused = [&a, &b]() { return divmod(a, b); };
TIME(divmod_fused);
auto divmod_separate = [&a, &b]() {
return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
return std::vector<array>{floor_divide(a, b), remainder(a, b)};
};
TIME(divmod_separate);
}
int main() {
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
std::cout << "Benchmarks for " << default_device() << std::endl;
time_creation_ops();
time_type_conversions();
time_unary_ops();

View File

@@ -1,74 +0,0 @@
import argparse
import math
import mlx.core as mx
from time_utils import time_fn
B = 1024
D = 1024
M = 4 * D
group_size = 64
bits = 4
dtype = mx.float16
loops = 10
def qmm_(x, wq1, wq2, q_type):
for i in range(loops):
x = mx.quantized_matmul(
x,
*wq1,
group_size=group_size,
bits=bits,
quantization_type=q_type,
)
x = mx.quantized_matmul(
x,
*wq2,
group_size=group_size,
bits=bits,
quantization_type=q_type,
)
return x
def affine_qmm(x, wq1, wq2):
return qmm_(x, wq1, wq2, "affine")
def affine_packed_qmm(x, wq1, wq2):
return qmm_(x, wq1, wq2, "affine-packed")
def time_qmm():
mx.random.seed(3)
x = mx.random.normal(shape=(B, D)).astype(dtype)
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
mx.eval(x, wq1, wq2)
time_fn(affine_qmm, x, wq1, wq2)
def time_packed_qmm():
mx.random.seed(3)
x = mx.random.normal(shape=(B, D)).astype(dtype)
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
wq1 = mx.quantize(
w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
)
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
wq2 = mx.quantize(
w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
)
mx.eval(x, wq1, wq2)
time_fn(affine_packed_qmm, x, wq1, wq2)
if __name__ == "__main__":
for b in [2, 4, 8]:
bits = b
print(f"Bits {bits}:")
time_qmm()
time_packed_qmm()

View File

@@ -1,189 +1,62 @@
# Copyright © 2024 Apple Inc.
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
from time_utils import time_fn
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 5
N_iter_bench = 40
N_iter_func = 8
MAX_SEQ = 300
START_SEQ = 100
SEQ_INCREMENT = 50
def bench(f, *args):
for i in range(N_warmup):
f(*args)
def time_self_attention_primitives():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(*args)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def sdpa_primitives(qs, ks, vs, alpha):
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ vs
return o
time_fn(sdpa_primitives, q, k, v, scale)
def mlx_sdpa_fused_inner(q, k, v, scale):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
def time_self_attention_sdpa():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def sdpa_fused(qs, ks, vs, alpha):
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
return o
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
q_dtype = q.dtype
q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3]
n_kv_heads = k.shape[-3]
n_repeats = n_q_heads // n_kv_heads
B = q.shape[0]
L = q.shape[2]
if n_repeats > 1:
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
k = mx.expand_dims(k, 2)
v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2)
if f32softmax:
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
else:
scores = mx.softmax(scores, axis=-1)
out = scores @ v
if n_repeats > 1:
out = mx.reshape(out, [B, n_q_heads, L, -1])
return out
def mlx_spda_unfused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def mlx_spda_fused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
shape_q = (
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
)
shape_kv = (
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
)
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
scale = math.sqrt(1.0 / head_dim)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
if transpose:
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
print(
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
)
return time_mlx_fused, time_mlx_unfused
def get_gflop_count(B, M, N, K):
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
time_fn(sdpa_fused, q, k, v, scale)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
parser = argparse.ArgumentParser("MLX benchmarks.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if args.gpu:
mx.set_default_device(mx.gpu)
else:
mx.set_default_device(mx.cpu)
dtypes = ("float16", "float32")[:1]
transposes = (False,)
# fmt: off
shapes_64 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 32, 32, 64, 32, 32),
( 1, 64, 64, 64, 32, 32),
( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 32),
( 1, 2048, 2048, 64, 32, 32),
( 1, 4096, 4096, 64, 32, 32),
)
shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 32),
( 1, 2048, 2048, 80, 32, 32),
( 1, 4096, 4096, 80, 32, 32),
)
shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 32),
( 1, 2048, 2048, 128, 32, 32),
( 1, 4096, 4096, 128, 32, 32),
)
# fmt: on
shapes = shapes_64 + shapes_80 + shapes_128
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
for dtype in dtypes:
for transpose in transposes:
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
np_dtype = getattr(np, dtype)
time_mlx_fused, time_mlx_unfused = bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)
time_self_attention_sdpa()
time_self_attention_primitives()

View File

@@ -4,51 +4,42 @@ import math
import mlx.core as mx
from time_utils import time_fn
L = 16384
L = 1024
H = 32
H_k = H // 4
H_k = 32 // 4
D = 128
dtype = mx.float16
loops = 10
def attention(q, k, v):
def _sdpa(q, k, v):
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v
return o.reshape(B, Hq, L, D)
for i in range(loops):
q = _sdpa(q, k, v)
return q
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v
return o.reshape(B, Hq, L, D)
def sdpa(q, k, v):
for i in range(loops):
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
return q
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
def time_self_attention_primitives():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
time_fn(attention, q, k, v)
def time_self_attention_sdpa():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
time_fn(sdpa, q, k, v)

View File

@@ -420,8 +420,8 @@ element in the output.
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
constant const int* shape [[buffer(5)]],
constant const int64_t* x_strides [[buffer(6)]],
constant const int64_t* y_strides [[buffer(7)]],
constant const size_t* x_strides [[buffer(6)]],
constant const size_t* y_strides [[buffer(7)]],
constant const int& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) {
// Convert linear indices to offsets in array
@@ -438,10 +438,24 @@ each instantiation a unique host name so we can identify it.
.. code-block:: C++
instantiate_kernel("axpby_general_float32", axpby_general, float)
instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] \
[[kernel]] void axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
@@ -480,7 +494,7 @@ below.
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal
@@ -495,14 +509,14 @@ below.
compute_encoder.set_output_array(out, 2);
// Encode alpha and beta
compute_encoder.set_bytes(alpha_, 3);
compute_encoder.set_bytes(beta_, 4);
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4);
// Encode shape, strides and ndim
compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder.set_bytes(y.strides(), 7);
compute_encoder.set_bytes(ndim, 8);
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8);
// We launch 1 thread for each input and make sure that the number of
// threads in any given threadgroup is not higher than the max allowed
@@ -516,7 +530,7 @@ below.
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
We can now call the :meth:`axpby` operation on both the CPU and the GPU!

View File

@@ -209,7 +209,7 @@ Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots.
Metal kernel cache persists accross reboots.
Troubleshooting
^^^^^^^^^^^^^^^

View File

@@ -12,4 +12,5 @@ Fast
layer_norm
rope
scaled_dot_product_attention
affine_quantize
metal_kernel

View File

@@ -12,7 +12,6 @@ Layers
ALiBi
AvgPool1d
AvgPool2d
AvgPool3d
BatchNorm
CELU
Conv1d
@@ -42,7 +41,6 @@ Layers
LSTM
MaxPool1d
MaxPool2d
MaxPool3d
Mish
MultiHeadAttention
PReLU

View File

@@ -168,7 +168,6 @@ Operations
tri
tril
triu
unflatten
var
view
where

View File

@@ -184,8 +184,8 @@ Let's time these two different versions:
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
vectorized version takes only ``0.024`` seconds, more than 200 times faster.
On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
vectorized version takes only ``0.025`` seconds, more than ten times faster.
Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.

View File

@@ -4,19 +4,19 @@
#include "mlx/mlx.h"
namespace mx = mlx::core;
using namespace mlx::core;
int main() {
if (!mx::distributed::is_available()) {
if (!distributed::is_available()) {
std::cout << "No communication backend found" << std::endl;
return 1;
}
auto global_group = mx::distributed::init();
auto global_group = distributed::init();
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
mx::array x = mx::ones({10});
mx::array out = mx::distributed::all_sum(x, global_group);
array x = ones({10});
array out = distributed::all_sum(x, global_group);
std::cout << out << std::endl;
}

View File

@@ -10,7 +10,7 @@
/**
* An example of linear regression with MLX.
*/
namespace mx = mlx::core;
using namespace mlx::core;
int main() {
int num_features = 100;
@@ -19,35 +19,35 @@ int main() {
float learning_rate = 0.01;
// True parameters
auto w_star = mx::random::normal({num_features});
auto w_star = random::normal({num_features});
// The input examples (design matrix)
auto X = mx::random::normal({num_examples, num_features});
auto X = random::normal({num_examples, num_features});
// Noisy labels
auto eps = 1e-2 * mx::random::normal({num_examples});
auto y = mx::matmul(X, w_star) + eps;
auto eps = 1e-2 * random::normal({num_examples});
auto y = matmul(X, w_star) + eps;
// Initialize random parameters
mx::array w = 1e-2 * mx::random::normal({num_features});
array w = 1e-2 * random::normal({num_features});
auto loss_fn = [&](mx::array w) {
auto yhat = mx::matmul(X, w);
return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));
auto loss_fn = [&](array w) {
auto yhat = matmul(X, w);
return (0.5f / num_examples) * sum(square(yhat - y));
};
auto grad_fn = mx::grad(loss_fn);
auto grad_fn = grad(loss_fn);
auto tic = timer::time();
for (int it = 0; it < num_iters; ++it) {
auto grads = grad_fn(w);
w = w - learning_rate * grads;
mx::eval(w);
auto grad = grad_fn(w);
w = w - learning_rate * grad;
eval(w);
}
auto toc = timer::time();
auto loss = loss_fn(w);
auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());
auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
auto throughput = num_iters / timer::seconds(toc - tic);
std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
<< ", Throughput " << throughput << " (it/s)." << std::endl;

View File

@@ -10,7 +10,7 @@
/**
* An example of logistic regression with MLX.
*/
namespace mx = mlx::core;
using namespace mlx::core;
int main() {
int num_features = 100;
@@ -19,35 +19,35 @@ int main() {
float learning_rate = 0.1;
// True parameters
auto w_star = mx::random::normal({num_features});
auto w_star = random::normal({num_features});
// The input examples
auto X = mx::random::normal({num_examples, num_features});
auto X = random::normal({num_examples, num_features});
// Labels
auto y = mx::matmul(X, w_star) > 0;
auto y = matmul(X, w_star) > 0;
// Initialize random parameters
mx::array w = 1e-2 * mx::random::normal({num_features});
array w = 1e-2 * random::normal({num_features});
auto loss_fn = [&](mx::array w) {
auto logits = mx::matmul(X, w);
auto loss_fn = [&](array w) {
auto logits = matmul(X, w);
auto scale = (1.0f / num_examples);
return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
};
auto grad_fn = mx::grad(loss_fn);
auto grad_fn = grad(loss_fn);
auto tic = timer::time();
for (int it = 0; it < num_iters; ++it) {
auto grads = grad_fn(w);
w = w - learning_rate * grads;
mx::eval(w);
auto grad = grad_fn(w);
w = w - learning_rate * grad;
eval(w);
}
auto toc = timer::time();
auto loss = loss_fn(w);
auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
auto throughput = num_iters / timer::seconds(toc - tic);
std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
<< throughput << " (it/s)." << std::endl;

View File

@@ -5,27 +5,27 @@
#include "mlx/mlx.h"
namespace mx = mlx::core;
using namespace mlx::core;
int main() {
// To use Metal debugging and profiling:
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
// 2. Run with MTL_CAPTURE_ENABLED=1.
mx::metal::start_capture("mlx_trace.gputrace");
metal::start_capture("mlx_trace.gputrace");
// Start at index two because the default GPU and CPU streams have indices
// zero and one, respectively. This naming matches the label assigned to each
// stream's command queue.
auto s2 = new_stream(mx::Device::gpu);
auto s3 = new_stream(mx::Device::gpu);
auto s2 = new_stream(Device::gpu);
auto s3 = new_stream(Device::gpu);
auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
auto x = mx::add(a, a, s2);
auto y = mx::add(b, b, s3);
auto a = arange(1.f, 10.f, 1.f, float32, s2);
auto b = arange(1.f, 10.f, 1.f, float32, s3);
auto x = add(a, a, s2);
auto y = add(b, b, s3);
// The multiply will happen on the default stream.
std::cout << mx::multiply(x, y) << std::endl;
std::cout << multiply(x, y) << std::endl;
mx::metal::stop_capture();
metal::stop_capture();
}

View File

@@ -5,11 +5,11 @@
#include "mlx/mlx.h"
namespace mx = mlx::core;
using namespace mlx::core;
void array_basics() {
// Make a scalar array:
mx::array x(1.0);
array x(1.0);
// Get the value out of it:
auto s = x.item<float>();
@@ -29,31 +29,31 @@ void array_basics() {
// The datatype should be float32:
auto dtype = x.dtype();
assert(dtype == mx::float32);
assert(dtype == float32);
// Specify the dtype when constructing the array:
x = mx::array(1, mx::int32);
assert(x.dtype() == mx::int32);
x = array(1, int32);
assert(x.dtype() == int32);
x.item<int>(); // OK
// x.item<float>(); // Undefined!
// Make a multidimensional array:
x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
// mlx is row-major by default so the first row of this array
// is [1.0, 2.0] and the second row is [3.0, 4.0]
// Make an array of shape {2, 2} filled with ones:
auto y = mx::ones({2, 2});
auto y = ones({2, 2});
// Pointwise add x and y:
auto z = mx::add(x, y);
auto z = add(x, y);
// Same thing:
z = x + y;
// mlx is lazy by default. At this point `z` only
// has a shape and a type but no actual data:
assert(z.dtype() == mx::float32);
assert(z.dtype() == float32);
assert(z.shape(0) == 2);
assert(z.shape(1) == 2);
@@ -63,33 +63,33 @@ void array_basics() {
// and inputs. When `eval` is called on an array (or arrays), the array and
// all of its dependencies are recursively evaluated to produce the result.
// Once an array is evaluated, it has data and is detached from its inputs.
mx::eval(z);
eval(z);
// Of course the array can still be an input to other operations. You can
// even call eval on the array again, this will just be a no-op:
mx::eval(z); // no-op
// Of course the array can still be an input to other operations. You can even
// call eval on the array again, this will just be a no-op:
eval(z); // no-op
// Some functions or methods on arrays implicitly evaluate them. For example
// accessing a value in an array or printing the array implicitly evaluate it:
z = mx::ones({1});
z = ones({1});
z.item<float>(); // implicit evaluation
z = mx::ones({2, 2});
z = ones({2, 2});
std::cout << z << std::endl; // implicit evaluation
}
void automatic_differentiation() {
auto fn = [](mx::array x) { return mx::square(x); };
auto fn = [](array x) { return square(x); };
// Computing the derivative function of a function
auto grad_fn = mx::grad(fn);
auto grad_fn = grad(fn);
// Call grad_fn on the input to get the derivative
auto x = mx::array(1.5);
auto x = array(1.5);
auto dfdx = grad_fn(x);
// dfdx is 2 * x
// Get the second derivative by composing grad with grad
auto d2fdx2 = mx::grad(mx::grad(fn))(x);
auto d2fdx2 = grad(grad(fn))(x);
// d2fdx2 is 2
}

View File

@@ -19,7 +19,7 @@
#include "mlx/backend/metal/utils.h"
#endif
namespace my_ext {
namespace mlx::core {
///////////////////////////////////////////////////////////////////////////////
// Operation Implementation
@@ -32,24 +32,24 @@ namespace my_ext {
* Follow numpy style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
mx::array axpby(
const mx::array& x, // Input mx::array x
const mx::array& y, // Input mx::array y
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Promote dtypes between x and y as needed
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32)
auto out_dtype = issubdtype(promoted_dtype, float32)
? promoted_dtype
: promote_types(promoted_dtype, mx::float32);
: promote_types(promoted_dtype, float32);
// Cast x and y up to the determined dtype (on the same stream s)
auto x_casted = mx::astype(x, out_dtype, s);
auto y_casted = mx::astype(y, out_dtype, s);
auto x_casted = astype(x, out_dtype, s);
auto y_casted = astype(y, out_dtype, s);
// Broadcast the shapes of x and y (on the same stream s)
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
@@ -57,12 +57,12 @@ mx::array axpby(
// Construct the array as the output of the Axpby primitive
// with the broadcasted and upcasted arrays as inputs
return mx::array(
return array(
/* const std::vector<int>& shape = */ out_shape,
/* mx::Dtype dtype = */ out_dtype,
/* std::unique_ptr<mx::Primitive> primitive = */
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
///////////////////////////////////////////////////////////////////////////////
@@ -71,16 +71,16 @@ mx::array axpby(
template <typename T>
void axpby_impl(
const mx::array& x,
const mx::array& y,
mx::array& out,
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// We only allocate memory when we are ready to fill the output
// malloc_or_wait synchronously allocates available memory
// There may be a wait executed here if the allocation is requested
// under memory-pressured conditions
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Collect input and output data pointers
const T* x_ptr = x.data<T>();
@@ -94,8 +94,8 @@ void axpby_impl(
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides());
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
@@ -105,8 +105,8 @@ void axpby_impl(
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
@@ -114,14 +114,14 @@ void Axpby::eval(
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == mx::float32) {
if (out.dtype() == float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::float16) {
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::bfloat16) {
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::complex64) {
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == float16) {
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == bfloat16) {
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == complex64) {
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
@@ -136,9 +136,9 @@ void Axpby::eval(
template <typename T>
void axpby_impl_accelerate(
const mx::array& x,
const mx::array& y,
mx::array& out,
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
@@ -150,10 +150,10 @@ void axpby_impl_accelerate(
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, mx::CopyType::Vector);
copy_inplace(y, out, CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
@@ -175,15 +175,15 @@ void axpby_impl_accelerate(
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == mx::float32 &&
if (out.dtype() == float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
@@ -198,8 +198,8 @@ void Axpby::eval_cpu(
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
eval(inputs, outputs);
}
@@ -213,8 +213,8 @@ void Axpby::eval_cpu(
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
@@ -225,7 +225,7 @@ void Axpby::eval_gpu(
// and each stream carries its device identifiers
auto& s = stream();
// We get the needed metal device using the stream
auto& d = mx::metal::device(s.device);
auto& d = metal::device(s.device);
// Prepare to specialize based on contiguity
bool contiguous_kernel =
@@ -235,12 +235,12 @@ void Axpby::eval_gpu(
// Allocate output memory with strides based on specialization
if (contiguous_kernel) {
out.set_data(
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
allocator::malloc_or_wait(x.data_size() * out.itemsize()),
x.data_size(),
x.strides(),
x.flags());
} else {
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
// Resolve name of kernel (corresponds to axpby.metal)
@@ -257,7 +257,7 @@ void Axpby::eval_gpu(
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal
@@ -272,15 +272,15 @@ void Axpby::eval_gpu(
compute_encoder.set_output_array(out, 2);
// Encode alpha and beta
compute_encoder.set_bytes(alpha_, 3);
compute_encoder.set_bytes(beta_, 4);
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4);
// Encode shape, strides and ndim if needed
if (!contiguous_kernel) {
compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder.set_bytes(y.strides(), 7);
compute_encoder.set_bytes(ndim, 8);
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8);
}
// We launch 1 thread for each input and make sure that the number of
@@ -295,15 +295,15 @@ void Axpby::eval_gpu(
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
#else // Metal is not available
/** Fail evaluation on GPU */
void Axpby::eval_gpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& out) {
const std::vector<array>& inputs,
std::vector<array>& out) {
throw std::runtime_error("Axpby has no GPU implementation.");
}
@@ -314,9 +314,9 @@ void Axpby::eval_gpu(
///////////////////////////////////////////////////////////////////////////////
/** The Jacobian-vector product. */
std::vector<mx::array> Axpby::jvp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& tangents,
std::vector<array> Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can built with ops
@@ -328,8 +328,8 @@ std::vector<mx::array> Axpby::jvp(
// scaled by beta
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = mx::array(scale, tangents[0].dtype());
return {mx::multiply(scale_arr, tangents[0], stream())};
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
}
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
@@ -339,24 +339,24 @@ std::vector<mx::array> Axpby::jvp(
}
/** The vector-Jacobian product. */
std::vector<mx::array> Axpby::vjp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& cotangents,
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<mx::array>&) {
const std::vector<array>&) {
// Reverse mode diff
std::vector<mx::array> vjps;
std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = mx::array(scale, cotangents[0].dtype());
vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream()));
auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
}
return vjps;
}
/** Vectorize primitive along given axis */
std::pair<std::vector<mx::array>, std::vector<int>> Axpby::vmap(
const std::vector<mx::array>& inputs,
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("Axpby has no vmap implementation.");
}
@@ -367,4 +367,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
}
} // namespace my_ext
} // namespace mlx::core

View File

@@ -5,9 +5,7 @@
#include "mlx/ops.h"
#include "mlx/primitives.h"
namespace mx = mlx::core;
namespace my_ext {
namespace mlx::core {
///////////////////////////////////////////////////////////////////////////////
// Operation
@@ -20,22 +18,22 @@ namespace my_ext {
* Follow numpy style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
mx::array axpby(
const mx::array& x, // Input array x
const mx::array& y, // Input array y
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
mx::StreamOrDevice s = {} // Stream on which to schedule the operation
StreamOrDevice s = {} // Stream on which to schedule the operation
);
///////////////////////////////////////////////////////////////////////////////
// Primitive
///////////////////////////////////////////////////////////////////////////////
class Axpby : public mx::Primitive {
class Axpby : public Primitive {
public:
explicit Axpby(mx::Stream stream, float alpha, float beta)
: mx::Primitive(stream), alpha_(alpha), beta_(beta) {};
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta) {};
/**
* A primitive must know how to evaluate itself on the CPU/GPU
@@ -44,25 +42,23 @@ class Axpby : public mx::Primitive {
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) override;
void eval_gpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) override;
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
/** The Jacobian-vector product. */
std::vector<mx::array> jvp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& tangents,
std::vector<array> jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
/** The vector-Jacobian product. */
std::vector<mx::array> vjp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& cotangents,
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<mx::array>& outputs) override;
const std::vector<array>& outputs) override;
/**
* The primitive must know how to vectorize itself across
@@ -70,8 +66,8 @@ class Axpby : public mx::Primitive {
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<std::vector<mx::array>, std::vector<int>> vmap(
const std::vector<mx::array>& inputs,
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
/** Print the primitive. */
@@ -80,16 +76,14 @@ class Axpby : public mx::Primitive {
}
/** Equivalence check **/
bool is_equivalent(const mx::Primitive& other) const override;
bool is_equivalent(const Primitive& other) const override;
private:
float alpha_;
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs);
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};
} // namespace my_ext
} // namespace mlx::core

View File

@@ -2,6 +2,7 @@
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T>
@@ -12,8 +13,8 @@ template <typename T>
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
constant const int* shape [[buffer(5)]],
constant const int64_t* x_strides [[buffer(6)]],
constant const int64_t* y_strides [[buffer(7)]],
constant const size_t* x_strides [[buffer(6)]],
constant const size_t* y_strides [[buffer(7)]],
constant const int& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) {
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
@@ -34,14 +35,29 @@ template <typename T>
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
}
// clang-format off
#define instantiate_axpby(type_name, type) \
instantiate_kernel("axpby_general_" #type_name, axpby_general, type) \
instantiate_kernel( \
"axpby_contiguous_" #type_name, axpby_contiguous, type)
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
// clang-format on
instantiate_axpby(complex64, complex64_t);

View File

@@ -8,12 +8,14 @@
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
m.def(
"axpby",
&my_ext::axpby,
&axpby,
"x"_a,
"y"_a,
"alpha"_a,

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.24
mlx>=0.21.0
mlx>=0.18.1
nanobind==2.2.0

View File

@@ -28,19 +28,10 @@ endif()
if (@MLX_BUILD_METAL@)
set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
set(MLX_INCLUDE_DIRS
"${MLX_INCLUDE_DIRS};"
set_and_check(MLX_INCLUDE_DIRS
${MLX_INCLUDE_DIRS}
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
)
if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
set(MLX_INCLUDE_DIRS
"${MLX_INCLUDE_DIRS};"
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
else()
set(MLX_INCLUDE_DIRS
"${MLX_INCLUDE_DIRS};"
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
endif()
endif()
set_target_properties(mlx PROPERTIES
@@ -49,4 +40,4 @@ set_target_properties(mlx PROPERTIES
)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)

View File

@@ -18,16 +18,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
if(MSVC)
# Disable some MSVC warnings to speed up compilation.
target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
endif()
if(WIN32)
# Export symbols by default to behave like macOS/linux.
set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
endif()
if(MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else()

View File

@@ -19,7 +19,7 @@ Buffer malloc(size_t size) {
}
void free(Buffer buffer) {
allocator().free(buffer);
return allocator().free(buffer);
}
Buffer CommonAllocator::malloc(size_t size, bool) {

View File

@@ -1,6 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <functional>
#include <unordered_map>
#include "mlx/array.h"
#include "mlx/ops.h"
@@ -31,7 +30,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
}
array::array(
Shape shape,
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)
@@ -42,7 +41,7 @@ array::array(
std::move(inputs))) {}
std::vector<array> array::make_arrays(
std::vector<Shape> shapes,
std::vector<std::vector<int>> shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs) {
@@ -74,7 +73,11 @@ array::array(std::initializer_list<int> data, Dtype dtype)
}
/* Build an array from a shared buffer */
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
array::array(
allocator::Buffer data,
std::vector<int> shape,
Dtype dtype,
deleter_t deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
set_data(data, deleter);
}
@@ -122,7 +125,7 @@ bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing() || retain_graph();
}
void array::set_data(allocator::Buffer buffer, Deleter d) {
void array::set_data(allocator::Buffer buffer, deleter_t d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = size();
@@ -135,9 +138,9 @@ void array::set_data(allocator::Buffer buffer, Deleter d) {
void array::set_data(
allocator::Buffer buffer,
size_t data_size,
Strides strides,
std::vector<size_t> strides,
Flags flags,
Deleter d) {
deleter_t d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = data_size;
@@ -147,7 +150,7 @@ void array::set_data(
void array::copy_shared_buffer(
const array& other,
const Strides& strides,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
@@ -166,7 +169,7 @@ void array::copy_shared_buffer(const array& other) {
void array::move_shared_buffer(
array other,
const Strides& strides,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
@@ -211,8 +214,6 @@ array::~array() {
if (do_detach) {
for (auto& s : siblings()) {
for (auto& ss : s.siblings()) {
// Set to null here to avoid descending into array destructor
// for siblings
ss.array_desc_ = nullptr;
}
s.array_desc_->siblings.clear();
@@ -233,13 +234,13 @@ void array::ArrayDesc::init() {
}
}
array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
init();
}
array::ArrayDesc::ArrayDesc(
Shape shape,
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)
@@ -291,14 +292,6 @@ array::ArrayDesc::~ArrayDesc() {
auto top = std::move(for_deletion.back());
for_deletion.pop_back();
append_deletable_inputs(*top);
// Clear out possible siblings to break circular references
for (auto& s : top->siblings) {
// Set to null here to avoid descending into top-level
// array destructor for siblings
s.array_desc_ = nullptr;
}
top->siblings.clear();
}
}

View File

@@ -15,10 +15,7 @@ namespace mlx::core {
// Forward declaration
class Primitive;
using Deleter = std::function<void(allocator::Buffer)>;
using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;
using deleter_t = std::function<void(allocator::Buffer)>;
class array {
/* An array is really a node in a graph. It contains a shared ArrayDesc
@@ -36,7 +33,7 @@ class array {
template <typename It>
array(
It data,
Shape shape,
std::vector<int> shape,
Dtype dtype =
TypeToDtype<typename std::iterator_traits<It>::value_type>());
@@ -52,15 +49,15 @@ class array {
template <typename T>
array(
std::initializer_list<T> data,
Shape shape,
std::vector<int> shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a buffer */
array(
allocator::Buffer data,
Shape shape,
std::vector<int> shape,
Dtype dtype,
Deleter deleter = allocator::free);
deleter_t deleter = allocator::free);
/** Assignment to rvalue does not compile. */
array& operator=(const array& other) && = delete;
@@ -99,7 +96,7 @@ class array {
}
/** The shape of the array as a vector of integers. */
const Shape& shape() const {
const std::vector<int>& shape() const {
return array_desc_->shape;
}
@@ -108,12 +105,12 @@ class array {
*
* This function supports negative indexing and provides
* bounds checking. */
auto shape(int dim) const {
int shape(int dim) const {
return shape().at(dim < 0 ? dim + ndim() : dim);
}
/** The strides of the array. */
const Strides& strides() const {
const std::vector<size_t>& strides() const {
return array_desc_->strides;
}
@@ -122,7 +119,7 @@ class array {
*
* This function supports negative indexing and provides
* bounds checking. */
auto strides(int dim) const {
size_t strides(int dim) const {
return strides().at(dim < 0 ? dim + ndim() : dim);
}
@@ -187,13 +184,13 @@ class array {
*/
array(
Shape shape,
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
static std::vector<array> make_arrays(
std::vector<Shape> shapes,
std::vector<std::vector<int>> shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs);
@@ -210,8 +207,8 @@ class array {
struct Data {
allocator::Buffer buffer;
Deleter d;
Data(allocator::Buffer buffer, Deleter d = allocator::free)
deleter_t d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
: buffer(buffer), d(d) {}
// Not copyable
Data(const Data& d) = delete;
@@ -400,18 +397,18 @@ class array {
// Check if the array is a tracer array
bool is_tracer() const;
void set_data(allocator::Buffer buffer, Deleter d = allocator::free);
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
void set_data(
allocator::Buffer buffer,
size_t data_size,
Strides strides,
std::vector<size_t> strides,
Flags flags,
Deleter d = allocator::free);
deleter_t d = allocator::free);
void copy_shared_buffer(
const array& other,
const Strides& strides,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
@@ -420,7 +417,7 @@ class array {
void move_shared_buffer(
array other,
const Strides& strides,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
@@ -439,8 +436,8 @@ class array {
void init(const It src);
struct ArrayDesc {
Shape shape;
Strides strides;
std::vector<int> shape;
std::vector<size_t> strides;
size_t size;
Dtype dtype;
std::shared_ptr<Primitive> primitive;
@@ -474,10 +471,10 @@ class array {
// The arrays position in the output list
uint32_t position{0};
explicit ArrayDesc(Shape shape, Dtype dtype);
explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
explicit ArrayDesc(
Shape shape,
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
@@ -505,7 +502,7 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
template <typename It>
array::array(
It data,
Shape shape,
std::vector<int> shape,
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
init(data);
@@ -524,7 +521,7 @@ array::array(
template <typename T>
array::array(
std::initializer_list<T> data,
Shape shape,
std::vector<int> shape,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
if (data.size() != size()) {

View File

@@ -43,7 +43,6 @@ DEFAULT(NumberOfElements)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(ExpandDims)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
@@ -66,6 +65,7 @@ DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT_MULTI(QRF)
DEFAULT(RandomBits)
DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)
@@ -76,7 +76,6 @@ DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(Squeeze)
DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)

View File

@@ -5,21 +5,13 @@ else()
set(COMPILER ${CMAKE_CXX_COMPILER})
endif()
if(MSVC)
set(SHELL_EXT ps1)
set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)
else()
set(SHELL_EXT sh)
set(SHELL_CMD /bin/bash)
endif()
add_custom_command(
OUTPUT compiled_preamble.cpp
COMMAND
${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
DEPENDS make_compiled_preamble.${SHELL_EXT}
${PROJECT_SOURCE_DIR} ${CLANG}
DEPENDS make_compiled_preamble.sh
compiled_preamble.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h

View File

@@ -13,8 +13,8 @@ template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis];
Strides strides = in.strides();
Shape shape = in.shape();
std::vector<size_t> strides = in.strides();
std::vector<int> shape = in.shape();
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
for (uint32_t i = 0; i < out.size(); ++i) {

View File

@@ -178,10 +178,10 @@ void binary_op_dims(
const T* b,
U* out,
Op op,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
@@ -212,10 +212,10 @@ void binary_op_dispatch_dims(
array& out,
Op op,
int dim,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides) {
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& out_strides) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_ptr = out.data<U>();
@@ -258,10 +258,10 @@ void binary_op_dispatch_dims(
return;
}
ContiguousIterator a_it(shape, a_strides, dim - 3);
ContiguousIterator b_it(shape, b_strides, dim - 3);
auto stride = out_strides[dim - 4];
for (int64_t elem = 0; elem < a.size(); elem += stride) {
ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
size_t stride = out_strides[dim - 4];
for (size_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
@@ -327,7 +327,7 @@ void binary_op(
const auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
}
@@ -337,7 +337,7 @@ void binary_op(
auto b_rc_dim = leftmost_rc_dim(b_strides);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const auto& arr_strides) {
auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == 0; d--) {
}

View File

@@ -16,10 +16,10 @@ void binary_op_dims(
U* out_a,
U* out_b,
Op op,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
@@ -96,9 +96,9 @@ void binary_op_dispatch_dims(
return;
}
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
size_t stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 2>(
a_ptr + a_it.loc,

View File

@@ -39,7 +39,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
// rely on data_size anyway.
size_t data_size = out.size();
return move_or_copy(in, out, strides_, flags, data_size, offset_);
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
@@ -49,7 +49,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
std::vector<size_t> strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
@@ -58,12 +58,12 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
move_or_copy(in, out, strides, flags, in.data_size());
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
move_or_copy(inputs[0], out);
out.copy_shared_buffer(inputs[0]);
}
void CustomTransforms::eval(
@@ -72,7 +72,7 @@ void CustomTransforms::eval(
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
move_or_copy(inputs[j], outputs[i]);
outputs[i].copy_shared_buffer(inputs[j]);
}
}
@@ -81,20 +81,10 @@ void Depends::eval(
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
move_or_copy(inputs[i], outputs[i]);
outputs[i].copy_shared_buffer(inputs[i]);
}
}
void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto strides = in.strides();
for (auto ax : axes_) {
strides.insert(strides.begin() + ax, 1);
}
move_or_copy(in, out, strides, in.flags(), in.data_size());
}
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -151,7 +141,9 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
}
}
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
const array& in,
const array& out) {
// Special case for empty arrays or row contiguous arrays
if (in.size() == 0 || in.flags().row_contiguous) {
return {false, out.strides()};
@@ -159,7 +151,8 @@ std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
// Special case for scalars
if (in.ndim() == 0) {
return {false, Strides(out.ndim(), 0)};
std::vector<size_t> out_strides(out.ndim(), 0);
return {false, out_strides};
}
// Firstly let's collapse all the contiguous dimensions of the input
@@ -167,7 +160,7 @@ std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
// If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check.
Strides out_strides;
std::vector<size_t> out_strides;
bool copy_necessary = false;
int j = 0;
for (int i = 0; i < out.ndim(); i++) {
@@ -188,9 +181,9 @@ std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
return {copy_necessary, out_strides};
}
void shared_buffer_reshape(
void Reshape::shared_buffer_reshape(
const array& in,
const Strides& out_strides,
const std::vector<size_t>& out_strides,
array& out) {
auto flags = in.flags();
if (flags.row_contiguous) {
@@ -201,7 +194,7 @@ void shared_buffer_reshape(
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
}
move_or_copy(in, out, out_strides, flags, in.data_size());
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
void Split::eval(
@@ -256,28 +249,26 @@ void Split::eval(
}
}
void Squeeze::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
Strides strides;
for (int i = 0, j = 0; i < in.ndim(); ++i) {
if (j < axes_.size() && i == axes_[j]) {
j++;
} else {
strides.push_back(in.strides(i));
}
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
}
move_or_copy(in, out, strides, in.flags(), in.data_size());
return std::make_tuple(data_offset, inp_strides);
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
move_or_copy(inputs[0], out);
out.copy_shared_buffer(inputs[0]);
}
void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
Strides out_strides(out.ndim());
std::vector<size_t> out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
@@ -294,8 +285,8 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
// true, they stay true)
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
int64_t f_stride = 1;
int64_t b_stride = 1;
size_t f_stride = 1;
size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
@@ -306,7 +297,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
b_stride *= out.shape(ri);
}
}
move_or_copy(in, out, out_strides, flags, in.data_size());
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
} // namespace mlx::core

View File

@@ -165,7 +165,7 @@ void compiled_allocate_outputs(
bool move_buffers /* = false */) {
if (contiguous) {
int o = 0;
Strides strides;
std::vector<size_t> strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {

View File

@@ -47,7 +47,7 @@ bool compile_available_for_device(const Device& device) {
} // namespace detail
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name).string();
return std::filesystem::temp_directory_path().append(name);
}
// Return a pointer to a compiled function
@@ -279,7 +279,7 @@ void Compiled::eval_cpu(
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
auto contiguous = compiled_check_contiguity(inputs, shape);
bool contiguous = compiled_check_contiguity(inputs, shape);
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;

View File

@@ -746,9 +746,9 @@ void explicit_gemm_conv_1D_cpu(
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view
Shape strided_shape = {N, oH, wH, C};
std::vector<int> strided_shape = {N, oH, wH, C};
Strides strided_strides = {
std::vector<size_t> strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[1],
@@ -865,9 +865,9 @@ void explicit_gemm_conv_2D_cpu(
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view
Shape strided_shape = {N, oH, oW, wH, wW, C};
std::vector<int> strided_shape = {N, oH, oW, wH, wW, C};
Strides strided_strides = {
std::vector<size_t> strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[2] * wt_strides[1],
@@ -974,7 +974,7 @@ void explicit_gemm_conv_ND_cpu(
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view
Shape strided_shape(oDim.size() + wDim.size() + 2);
std::vector<int> strided_shape(oDim.size() + wDim.size() + 2);
strided_shape.front() = N;
for (size_t i = 0; i < oDim.size(); i++) {
strided_shape[i + 1] = oDim[i];
@@ -984,7 +984,7 @@ void explicit_gemm_conv_ND_cpu(
}
strided_shape.back() = C;
Strides strided_strides(in.shape().size() * 2 - 2);
std::vector<size_t> strided_strides(in.shape().size() * 2 - 2);
strided_strides[0] = in_padded.strides()[0];
for (size_t i = 0; i < wt_strides.size(); i++) {
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
@@ -1000,7 +1000,7 @@ void explicit_gemm_conv_ND_cpu(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
Shape strided_reshape = {N, C};
std::vector<int> strided_reshape = {N, C};
for (const auto& o : oDim) {
strided_reshape[0] *= o;
}

View File

@@ -26,13 +26,13 @@ void copy_vector(const array& src, array& dst) {
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
}
template <typename SrcT, typename DstT, int D>
template <typename SrcT, typename DstT, typename StrideT, int D>
inline void copy_dims(
const SrcT* src,
DstT* dst,
const Shape& shape,
const Strides& i_strides,
const Strides& o_strides,
const std::vector<int>& shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>& o_strides,
int axis) {
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
@@ -40,7 +40,7 @@ inline void copy_dims(
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
copy_dims<SrcT, DstT, D - 1>(
copy_dims<SrcT, DstT, StrideT, D - 1>(
src, dst, shape, i_strides, o_strides, axis + 1);
} else {
*dst = static_cast<DstT>(*src);
@@ -50,13 +50,13 @@ inline void copy_dims(
}
}
template <typename SrcT, typename DstT>
template <typename SrcT, typename DstT, typename StrideT>
void copy_general_general(
const array& src,
array& dst,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
const std::vector<int>& data_shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>& o_strides,
int64_t i_offset,
int64_t o_offset) {
if (data_shape.empty()) {
@@ -65,30 +65,30 @@ void copy_general_general(
*dst_ptr = val;
return;
}
auto [shape, strides] =
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
auto [shape, strides] = collapse_contiguous_dims(
data_shape, std::vector<std::vector<StrideT>>{i_strides, o_strides});
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>() + o_offset;
int ndim = shape.size();
if (ndim == 1) {
copy_dims<SrcT, DstT, 1>(
copy_dims<SrcT, DstT, StrideT, 1>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 2) {
copy_dims<SrcT, DstT, 2>(
copy_dims<SrcT, DstT, StrideT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 3) {
copy_dims<SrcT, DstT, 3>(
copy_dims<SrcT, DstT, StrideT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
}
ContiguousIterator in(shape, strides[0], ndim - 3);
ContiguousIterator out(shape, strides[1], ndim - 3);
auto stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
for (int64_t elem = 0; elem < src.size(); elem += stride) {
copy_dims<SrcT, DstT, 3>(
ContiguousIterator<StrideT> in(shape, strides[0], ndim - 3);
ContiguousIterator<StrideT> out(shape, strides[1], ndim - 3);
StrideT stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<StrideT>());
for (StrideT elem = 0; elem < src.size(); elem += stride) {
copy_dims<SrcT, DstT, StrideT, 3>(
src_ptr + in.loc,
dst_ptr + out.loc,
shape,
@@ -102,37 +102,37 @@ void copy_general_general(
template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT>(
copy_general_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
}
template <typename SrcT, typename DstT>
template <typename SrcT, typename DstT, typename StrideT>
void copy_general(
const array& src,
array& dst,
const Shape& data_shape,
const Strides& i_strides,
const Strides&,
const std::vector<int>& data_shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>&,
int64_t i_offset,
int64_t o_offset) {
copy_general_general<SrcT, DstT>(
copy_general_general<SrcT, DstT, StrideT>(
src,
dst,
data_shape,
i_strides,
make_contiguous_strides(data_shape),
make_contiguous_strides<StrideT>(data_shape),
i_offset,
o_offset);
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT>(
copy_general_general<SrcT, DstT, size_t>(
src,
dst,
src.shape(),
src.strides(),
make_contiguous_strides(src.shape()),
make_contiguous_strides<size_t>(src.shape()),
0,
0);
}
@@ -282,12 +282,13 @@ void copy(const array& src, array& dst, CopyType ctype) {
copy_inplace(src, dst, ctype);
}
template <typename StrideT>
void copy_inplace(
const array& src,
array& dst,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
const std::vector<int>& data_shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
@@ -310,4 +311,24 @@ void copy_inplace(
}
}
template void copy_inplace<size_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<size_t>& i_strides,
const std::vector<size_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
template void copy_inplace<int64_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<int64_t>& i_strides,
const std::vector<int64_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
} // namespace mlx::core

View File

@@ -26,12 +26,13 @@ enum class CopyType {
void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);
template <typename stride_t>
void copy_inplace(
const array& src,
array& dst,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);

View File

@@ -57,7 +57,6 @@ DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(Exp)
DEFAULT(ExpandDims)
DEFAULT(Expm1)
DEFAULT(FFT)
DEFAULT(Floor)
@@ -87,6 +86,7 @@ DEFAULT_MULTI(QRF)
DEFAULT(QuantizedMatmul)
DEFAULT(RandomBits)
DEFAULT(Reduce)
DEFAULT(Reshape)
DEFAULT(Round)
DEFAULT(Scan)
DEFAULT(Scatter)
@@ -101,7 +101,6 @@ DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT_MULTI(Split)
DEFAULT(Square)
DEFAULT(Squeeze)
DEFAULT(Sqrt)
DEFAULT(StopGradient)
DEFAULT(Subtract)
@@ -131,7 +130,7 @@ inline void matmul_common_general(
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
stx = arr.shape(-1);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};

View File

@@ -32,7 +32,7 @@ void gather(
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const Shape& slice_sizes) {
const std::vector<int>& slice_sizes) {
// If the array is row contiguous then we can do a contiguous copy given
// two conditions on the slice size:
// - Any number of leading ones in the slice sizes are allowed
@@ -80,10 +80,11 @@ void gather(
T* dst_ptr = out.data<T>();
size_t out_idx = 0;
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
ContiguousIterator src_it;
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> src_it;
if (!can_copy && src.ndim() > 0) {
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
src_it = std::move(
ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
}
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
@@ -118,7 +119,7 @@ void dispatch_gather(
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const Shape& size) {
const std::vector<int>& size) {
switch (out.dtype()) {
case bool_:
gather<bool, IdxT>(src, inds, out, axes, size);
@@ -222,16 +223,16 @@ void scatter(
auto inds_ndim = updates.ndim() - out.ndim();
size_t n_updates = nind ? inds[0].size() : 1;
Shape update_shape(
std::vector<int> update_shape(
updates.shape().begin() + inds_ndim, updates.shape().end());
size_t update_size = 1;
for (auto us : update_shape) {
update_size *= us;
}
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
ContiguousIterator update_it(updates);
ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> update_it(updates);
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;

View File

@@ -2,15 +2,6 @@
#pragma once
// Required for Visual Studio.
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
#ifdef _MSC_VER
#include <complex>
#define LAPACK_COMPLEX_CUSTOM
#define lapack_complex_float std::complex<float>
#define lapack_complex_double std::complex<double>
#endif
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else

View File

@@ -1,38 +0,0 @@
# This script generates a C++ function that provides the CPU
# code for use with kernel generation.
#
# Copyright © 2024 Apple Inc.
$OUTPUT_FILE = $args[0]
$CL = $args[1]
$SRCDIR = $args[2]
# Get command result as array.
$CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/compiled_preamble.h"
# Remove empty lines.
# Otherwise there will be too much empty lines making the result unreadable.
$CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }
# Concatenate to string.
$CONTENT = $CONTENT -join '`n'
# Append extra content.
$CONTENT = @"
$($CONTENT)
using namespace mlx::core;
using namespace mlx::core::detail;
"@
# Convert each char to ASCII code.
# Unlike the unix script that outputs string literal directly, the output from
# MSVC is way too large to be embedded as string and compilation will fail, so
# we store it as static array instead.
$CHARCODES = ([System.Text.Encoding]::ASCII.GetBytes($CONTENT) -join ', ') + ', 0'
$OUTPUT = @"
const char* get_kernel_preamble() {
static char preamble[] = { $CHARCODES };
return preamble;
}
"@
Set-Content -Path $OUTPUT_FILE -Value $OUTPUT

View File

@@ -10,16 +10,15 @@ OUTPUT_FILE=$1
GCC=$2
SRCDIR=$3
CLANG=$4
ARCH=$5
if [ "$CLANG" = "TRUE" ]; then
read -r -d '' INCLUDES <<- EOM
#include <cmath>
#include <complex>
#include <cstdint>
#include <vector>
#include <cmath>
#include <complex>
#include <cstdint>
#include <vector>
EOM
CC_FLAGS="-arch ${ARCH}"
CC_FLAGS=""
else
CC_FLAGS="-std=c++17"
fi

View File

@@ -19,10 +19,10 @@ inline void mask_matrix(
int block_size,
const int X,
const int Y,
const int64_t X_data_str,
const int64_t Y_data_str,
const int64_t X_mask_str,
const int64_t Y_mask_str,
const size_t X_data_str,
const size_t Y_data_str,
const size_t X_mask_str,
const size_t Y_mask_str,
const size_t mask_offset) {
int tX = (X + block_size - 1) / block_size;
int tY = (Y + block_size - 1) / block_size;
@@ -84,7 +84,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
int64_t stx = arr.shape(-1);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
@@ -117,13 +117,13 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
int Y,
size_t X_data_str,
size_t Y_data_str) {
auto mask_offset = elem_to_loc(
size_t mask_offset = elem_to_loc(
mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(),
mask.strides());
auto X_mask_str = mask.strides()[mask.ndim() - 2];
auto Y_mask_str = mask.strides()[mask.ndim() - 1];
size_t X_mask_str = mask.strides()[mask.ndim() - 2];
size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
if (mask.dtype() == bool_) {
return mask_matrix(
@@ -230,7 +230,7 @@ void GatherMM::eval(const std::vector<array>& inputs, array& out) {
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
int64_t stx = arr.shape(-1);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
@@ -262,13 +262,13 @@ void GatherMM::eval(const std::vector<array>& inputs, array& out) {
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
auto batch_shape = get_batch_dims(out.shape());
std::vector<int> batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
auto batch_shape_A = get_batch_dims(a.shape());
auto batch_strides_A = get_batch_dims(a.strides());
auto batch_shape_B = get_batch_dims(b.shape());
auto batch_strides_B = get_batch_dims(b.strides());
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();

View File

@@ -500,12 +500,7 @@ struct Equal {
struct NaNEqual {
template <typename T>
bool operator()(T x, T y) {
if constexpr (std::is_integral_v<T>) {
// isnan always returns false for integers, and MSVC refuses to compile.
return x == y;
} else {
return x == y || (std::isnan(x) && std::isnan(y));
}
return x == y || (std::isnan(x) && std::isnan(y));
}
};

View File

@@ -19,16 +19,6 @@
namespace mlx::core {
void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
void Abs::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -169,17 +159,6 @@ void Conjugate::eval(const std::vector<array>& inputs, array& out) {
}
}
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General);
}
}
void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -268,14 +247,6 @@ void Expm1::eval(const std::vector<array>& inputs, array& out) {
}
}
void Flatten::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Unflatten::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Floor::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -435,8 +406,18 @@ void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
void Round::eval(const std::vector<array>& inputs, array& out) {
@@ -506,17 +487,34 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
auto [copy_needed, data_offset, inp_strides] =
prepare_slice(in, start_indices_, strides_);
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
copy_inplace<int64_t>(
/* const array& src = */ in,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
/* const std::vector<stride_t>& i_strides = */ inp_strides,
/* const std::vector<stride_t>& o_strides = */ ostrides,
/* int64_t i_offset = */ data_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General);
} else {
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
size_t data_size = data_end - data_offset;
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
size_t data_size = data_end - data_offset;
Strides ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
@@ -541,11 +539,11 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
auto [data_offset, out_strides] = prepare_slice(out);
// Do copy
Strides upd_strides{upd.strides().begin(), upd.strides().end()};
copy_inplace(
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
copy_inplace<int64_t>(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
@@ -608,7 +606,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
for (int i = 0; i < strides.size() - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}

View File

@@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r) {
// Copy the input to be column contiguous
flags.col_contiguous = num_matrices == 1;
flags.row_contiguous = false;
auto strides = in.strides();
std::vector<size_t> strides = in.strides();
strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M;
in.set_data(

View File

@@ -2,38 +2,13 @@
#include <cassert>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
template <typename T, int bits>
void extract_bits(const uint8_t* w_in, T* w_out) {
assert(bits == 3 || bits == 6);
if (bits == 3) {
w_out[0] = static_cast<T>(w_in[0] & 0x7);
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
w_out[2] = static_cast<T>(((w_in[0] & 0xc0) >> 6) + ((w_in[1] & 0x1) << 2));
w_out[3] = static_cast<T>((w_in[1] & 0xe) >> 1);
w_out[4] = static_cast<T>((w_in[1] & 0x70) >> 4);
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
} else if (bits == 6) {
w_out[0] = static_cast<T>(w_in[0] & 0x3f);
w_out[1] =
static_cast<T>(((w_in[0] >> 6) & 0x03) + ((w_in[1] & 0x0f) << 2));
w_out[2] =
static_cast<T>(((w_in[1] >> 4) & 0x0f) + ((w_in[2] & 0x03) << 4));
w_out[3] = static_cast<T>((w_in[2] >> 2) & 0x3f);
}
}
template <typename T, int bits, int group_size>
void _qmm(
T* result,
@@ -45,12 +20,13 @@ void _qmm(
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
const int Ng = N / group_size;
const int Nw = N / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint32_t* w_local = w;
const T* scales_local = scales;
const T* biases_local = biases;
@@ -64,25 +40,13 @@ void _qmm(
T scale = *scales_local++;
T bias = *biases_local++;
for (int ng = 0; ng < packs_in_group; ng++) {
if (bits == 3 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) += xi * (scale * wl[p] + bias);
}
w_local += bytes_per_pack;
uint32_t wi = *w_local++;
} else {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * (scale * static_cast<T>(wi & bitmask) + bias);
if (bits != 8) {
wi >>= bits;
}
}
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * (scale * static_cast<T>(wi & bitmask) + bias);
wi >>= bits;
}
}
}
@@ -103,12 +67,13 @@ void _qmm_t(
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
const int Kg = K / group_size;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint32_t* w_local = w;
const T* scales_local = scales;
const T* biases_local = biases;
@@ -120,26 +85,12 @@ void _qmm_t(
T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw++) {
if (bits == 3 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum += x_local[p] * (scale * wl[p] + bias);
}
w_local += bytes_per_pack;
x_local += pack_factor;
uint32_t wi = *w_local++;
} else {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum +=
(*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
if (bits != 8) {
wi >>= bits;
}
}
for (int p = 0; p < pack_factor; p++) {
sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
wi >>= bits;
}
}
}
@@ -151,55 +102,6 @@ void _qmm_t(
}
}
template <typename T, int bits, int group_size>
void _qmm_dispatch_transpose(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
bool transposed_w) {
if (transposed_w) {
return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
}
}
template <typename T, int bits>
void _qmm_dispatch_group(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
int group_size,
bool transposed_w) {
switch (group_size) {
case 32:
_qmm_dispatch_transpose<T, bits, 32>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
case 64:
_qmm_dispatch_transpose<T, bits, 64>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
case 128:
_qmm_dispatch_transpose<T, bits, 128>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
default:
throw std::invalid_argument(
"Quantization group size must be 32, 64 or 128.");
}
}
template <typename T>
void _qmm_dispatch_typed(
T* result,
@@ -214,29 +116,79 @@ void _qmm_dispatch_typed(
int bits,
bool transposed_w) {
switch (bits) {
case 2:
_qmm_dispatch_group<T, 2>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 3:
_qmm_dispatch_group<T, 3>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 4:
_qmm_dispatch_group<T, 4>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 6:
_qmm_dispatch_group<T, 6>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 8:
_qmm_dispatch_group<T, 8>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
default:
throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8.");
case 2: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
case 4: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
case 8: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
}
std::ostringstream msg;
msg << "Quantization type not supported. Provided bits=" << bits
<< " and group_size=" << group_size
<< ". The supported options are bits in "
<< "{2, 4, 8} and group_size in {64, 128}.";
throw std::invalid_argument(msg.str());
}
void _qmm_dispatch(
@@ -452,114 +404,4 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
transpose_);
}
template <typename T, typename U>
void quantize(
const array& w_,
array& out_,
array& scales_,
array& biases_,
int bits,
int group_size) {
const T* w = w_.data<T>();
auto out = out_.data<U>();
T* scales = scales_.data<T>();
T* biases = biases_.data<T>();
T n_bins = (1 << bits) - 1;
T eps = 1e-7;
bool power_of_2_bits = is_power_of_2(bits);
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
int bytes_per_pack = power_of_2_bits ? 1 : 3;
int int_per_group = group_size * bytes_per_pack / el_per_int;
size_t n_groups = w_.size() / group_size;
for (size_t i = 0; i < n_groups; ++i) {
size_t w_idx = i * group_size;
T w_min = std::numeric_limits<float>::infinity();
T w_max = -w_min;
for (int j = 0; j < group_size; ++j) {
w_max = std::max(w_max, w[w_idx + j]);
w_min = std::min(w_min, w[w_idx + j]);
}
bool mask = std::abs(w_min) > std::abs(w_max);
T scale = std::max(T((w_max - w_min) / n_bins), eps);
scale = mask ? scale : -scale;
auto edge = mask ? w_min : w_max;
auto q0 = std::rint(edge / scale);
if (q0 == 0) {
scales[i] = scale;
biases[i] = 0;
} else {
scales[i] = edge / q0;
biases[i] = edge;
}
size_t out_idx = i * int_per_group;
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
uint32_t out_el = 0;
for (int k = 0; k < el_per_int; ++k) {
T w_el = w[w_idx + j * el_per_int + k];
w_el = std::rint((w_el - biases[i]) / scales[i]);
w_el = std::min(std::max(w_el, T(0)), n_bins);
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
}
if (power_of_2_bits) {
out[out_idx + j] = out_el;
} else {
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
}
}
}
}
void fast::AffineQuantize::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto ensure_row_contiguous = [](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
auto w = ensure_row_contiguous(inputs[0]);
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
if (w.dtype() == float16) {
if (is_power_of_2(bits_)) {
quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) {
quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) {
quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
}
} // namespace mlx::core

View File

@@ -120,53 +120,45 @@ struct MinReduce {
};
template <typename InT>
void reduce_dispatch_and_or(
void reduce_dispatch_out(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::And) {
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
} else {
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
}
}
template <typename InT>
void reduce_dispatch_sum_prod(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Sum) {
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 0, op);
} else {
reduction_op<InT, InT>(in, out, axes, 0, op);
switch (rtype) {
case Reduce::And: {
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
break;
}
} else {
auto op = [](auto y, auto x) { (*y) *= x; };
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 1, op);
} else {
case Reduce::Or: {
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
break;
}
case Reduce::Sum: {
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
if (out.dtype() == int32) {
// special case since the input type can be bool
reduction_op<InT, int32_t>(in, out, axes, 0, op);
} else {
reduction_op<InT, InT>(in, out, axes, 0, op);
}
break;
}
case Reduce::Prod: {
auto op = [](auto y, auto x) { (*y) *= x; };
reduction_op<InT, InT>(in, out, axes, 1, op);
break;
}
case Reduce::Max: {
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
break;
}
case Reduce::Min: {
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
break;
}
}
}
template <typename InT>
void reduce_dispatch_min_max(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Max) {
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
} else {
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
}
}
@@ -174,19 +166,19 @@ void reduce_dispatch_min_max(
void nd_loop(
std::function<void(int)> callback,
const Shape& shape,
const Strides& strides) {
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
std::function<void(int, int)> loop_inner;
loop_inner = [&](int dim, int offset) {
if (dim < shape.size() - 1) {
auto size = shape[dim];
auto stride = strides[dim];
int size = shape[dim];
size_t stride = strides[dim];
for (int i = 0; i < size; i++) {
loop_inner(dim + 1, offset + i * stride);
}
} else {
auto size = shape[dim];
auto stride = strides[dim];
int size = shape[dim];
size_t stride = strides[dim];
for (int i = 0; i < size; i++) {
callback(offset + i * stride);
}
@@ -198,114 +190,46 @@ void nd_loop(
void Reduce::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (reduce_type_) {
case Reduce::And:
case Reduce::Or: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
case float16:
case bfloat16:
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
case int32:
case float32:
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
case int64:
case complex64:
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
break;
}
switch (in.dtype()) {
case bool_:
reduce_dispatch_out<bool>(in, out, reduce_type_, axes_);
break;
}
case Reduce::Sum:
case Reduce::Prod: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
break;
}
case uint8:
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
break;
}
case Reduce::Max:
case Reduce::Min: {
switch (in.dtype()) {
case bool_:
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
break;
case uint8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
break;
}
case uint16:
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_out<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_out<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_out<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
reduce_dispatch_out<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_out<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_out<float>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_out<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_out<complex64_t>(in, out, reduce_type_, axes_);
break;
}
}
}

View File

@@ -38,10 +38,13 @@ enum ReductionOpType {
struct ReductionPlan {
ReductionOpType type;
Shape shape;
Strides strides;
std::vector<int> shape;
std::vector<size_t> strides;
ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)
ReductionPlan(
ReductionOpType type_,
std::vector<int> shape_,
std::vector<size_t> strides_)
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
@@ -52,10 +55,10 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
// Should this be in utils?
void nd_loop(
std::function<void(int)> callback,
const Shape& shape,
const Strides& strides);
const std::vector<int>& shape,
const std::vector<size_t>& strides);
std::pair<Shape, Strides> shapes_without_reduction_axes(
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes);
@@ -110,6 +113,9 @@ void reduction_op(
return;
}
std::vector<int> shape;
std::vector<size_t> strides;
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape[0];
const T* x_ptr = x.data<T>();
@@ -129,7 +135,7 @@ void reduction_op(
U* out_ptr = out.data<U>();
// Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should hold extra performance boost.
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
@@ -175,7 +181,7 @@ void reduction_op(
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
@@ -205,7 +211,7 @@ void reduction_op(
if (plan.type == GeneralReduce) {
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
U val = init;

View File

@@ -4,11 +4,11 @@
namespace mlx::core {
std::pair<Shape, Strides> shapes_without_reduction_axes(
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
std::vector<int> shape = x.shape();
std::vector<size_t> strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
@@ -29,8 +29,8 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// Row contiguous input so the output is row contiguous
if (x.flags().row_contiguous) {
// Merge consecutive axes
Shape shape = {x.shape(axes[0])};
Strides strides = {x.strides()[axes[0]]};
std::vector<int> shape = {x.shape(axes[0])};
std::vector<size_t> strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
shape.back() *= x.shape(axes[i]);
@@ -69,7 +69,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// Sort reduction axes by stride in order to merge them and figure out if we
// have a contiguous reduction.
std::vector<std::pair<int, int64_t>> reductions;
std::vector<std::pair<int, size_t>> reductions;
for (auto a : axes) {
if (x.shape(a) > 1) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
@@ -93,8 +93,8 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
}
}
Shape shape;
Strides strides;
std::vector<int> shape;
std::vector<size_t> strides;
for (auto r : reductions) {
shape.push_back(r.first);
strides.push_back(r.second);
@@ -109,15 +109,15 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// Delegate to the general strided reduction op if the axes after
// strides.back() are contiguous.
if (strides.back() > 1) {
int64_t size = 1;
int size = 1;
bool have_expand = false;
for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) {
continue;
}
auto stride_i = x.strides()[i];
auto shape_i = x.shape(i);
size_t stride_i = x.strides()[i];
int shape_i = x.shape(i);
if (stride_i == 0) {
if (shape_i == 1) {
continue;

View File

@@ -4,22 +4,24 @@
namespace mlx::core {
std::tuple<int64_t, Strides> prepare_slice(
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in,
const Shape& start_indices,
const Shape& strides) {
const std::vector<int>& start_indices,
const std::vector<int>& strides) {
int64_t data_offset = 0;
Strides inp_strides(in.ndim(), 0);
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i];
copy_needed |= strides[i] < 0;
}
return std::make_tuple(data_offset, inp_strides);
return std::make_tuple(copy_needed, data_offset, inp_strides);
}
void shared_buffer_slice(
const array& in,
const Strides& out_strides,
const std::vector<size_t>& out_strides,
size_t data_offset,
size_t data_size,
array& out) {
@@ -32,7 +34,7 @@ void shared_buffer_slice(
flags.col_contiguous = is_col_contiguous;
flags.contiguous = (no_bsx_size == data_size);
move_or_copy(in, out, out_strides, flags, data_size, data_offset);
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
} // namespace mlx::core

View File

@@ -6,14 +6,14 @@
namespace mlx::core {
std::tuple<int64_t, Strides> prepare_slice(
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in,
const Shape& start_indices,
const Shape& strides);
const std::vector<int>& start_indices,
const std::vector<int>& strides);
void shared_buffer_slice(
const array& in,
const Strides& out_strides,
const std::vector<size_t>& out_strides,
size_t data_offset,
size_t data_size,
array& out);

View File

@@ -25,7 +25,7 @@ struct StridedIterator {
// Constructors
StridedIterator() = default;
explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
explicit StridedIterator(T* ptr, size_t stride, difference_type offset = 0)
: ptr_(ptr + offset * stride), stride_(stride) {}
explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
@@ -99,7 +99,7 @@ struct StridedIterator {
}
private:
int64_t stride_;
size_t stride_;
T* ptr_;
};
@@ -120,11 +120,11 @@ void sort(const array& in, array& out, int axis) {
auto remaining_strides = out.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
auto axis_stride = out.strides()[axis];
auto axis_size = out.shape(axis);
size_t axis_stride = out.strides()[axis];
int axis_size = out.shape(axis);
// Perform sorting in place
ContiguousIterator src_it(
ContiguousIterator<size_t> src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc;
@@ -158,14 +158,14 @@ void argsort(const array& in, array& out, int axis) {
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
auto in_stride = in.strides()[axis];
auto out_stride = out.strides()[axis];
auto axis_size = in.shape(axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
int axis_size = in.shape(axis);
// Perform sorting
ContiguousIterator in_it(
ContiguousIterator<size_t> in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator out_it(
ContiguousIterator<size_t> out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc;
@@ -208,13 +208,13 @@ void partition(const array& in, array& out, int axis, int kth) {
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
auto axis_stride = in.strides()[axis];
size_t axis_stride = in.strides()[axis];
int axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition in place
ContiguousIterator src_it(
ContiguousIterator<size_t> src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc;
@@ -249,16 +249,16 @@ void argpartition(const array& in, array& out, int axis, int kth) {
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
auto in_stride = in.strides()[axis];
auto out_stride = out.strides()[axis];
auto axis_size = in.shape(axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
int axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition
ContiguousIterator in_it(
ContiguousIterator<size_t> in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator out_it(
ContiguousIterator<size_t> out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc;

View File

@@ -78,11 +78,11 @@ void ternary_op_dims(
const T3* c,
U* out,
Op op,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& c_strides,
const Strides& out_strides,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& c_strides,
const std::vector<size_t>& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
@@ -164,10 +164,10 @@ void ternary_op_dispatch_dims(
return;
}
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
ContiguousIterator c_it(shape, c_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
ContiguousIterator<size_t> c_it(shape, c_strides, ndim - 2);
size_t stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr + a_it.loc,

View File

@@ -4,35 +4,15 @@
namespace mlx::core {
void move_or_copy(const array& in, array& out) {
if (in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.copy_shared_buffer(in);
}
}
void move_or_copy(
const array& in,
array& out,
const Strides& strides,
array::Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
if (in.is_donatable()) {
out.move_shared_buffer(in, strides, flags, data_size, offset);
} else {
out.copy_shared_buffer(in, strides, flags, data_size, offset);
}
}
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
const Shape& shape,
const std::vector<Strides>& strides,
int64_t size_cap) {
template <typename StrideT>
std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>>
collapse_contiguous_dims_impl(
const std::vector<int>& shape,
const std::vector<std::vector<StrideT>>& strides,
StrideT size_cap) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
Shape to_collapse;
std::vector<int> to_collapse;
if (shape.size() > 0) {
if (shape[0] != 1) {
to_collapse.push_back(0);
@@ -41,7 +21,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
size *= shape[i];
for (const auto& st : strides) {
for (const std::vector<StrideT>& st : strides) {
if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
contiguous = false;
size = shape[i];
@@ -58,8 +38,8 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
to_collapse.push_back(-1);
}
Shape out_shape;
std::vector<Strides> out_strides(strides.size());
std::vector<int> out_shape;
std::vector<std::vector<StrideT>> out_strides(strides.size());
for (int i = 0;;) {
while (i < to_collapse.size() && to_collapse[i] == -1) {
++i;
@@ -74,7 +54,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const auto& st = strides[j];
const std::vector<StrideT>& st = strides[j];
out_strides[j].push_back(st[to_collapse[k - 1]]);
}
i = k + 1;
@@ -89,12 +69,29 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
return std::make_tuple(out_shape, out_strides);
}
std::pair<Shape, Strides> collapse_contiguous_dims(
const Shape& shape,
const Strides& strides,
int64_t size_cap) {
Shape collapsed_shape;
Strides collapsed_strides;
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<int64_t>>& strides,
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl(shape, strides, size_cap);
}
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>>& strides,
size_t size_cap /* = std::numeric_limits<int32>::max() */) {
return collapse_contiguous_dims_impl(shape, strides, size_cap);
}
template <typename StrideT>
std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
const std::vector<int>& shape,
const std::vector<StrideT>& strides,
StrideT size_cap) {
std::vector<int> collapsed_shape;
std::vector<StrideT> collapsed_strides;
if (shape.size() > 0) {
collapsed_shape.push_back(shape[0]);
@@ -104,7 +101,7 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
continue;
} else if (
strides[i] * shape[i] != collapsed_strides.back() ||
collapsed_shape.back() * static_cast<int64_t>(shape[i]) > size_cap) {
collapsed_shape.back() * static_cast<StrideT>(shape[i]) > size_cap) {
collapsed_shape.push_back(shape[i]);
collapsed_strides.push_back(strides[i]);
} else {
@@ -117,10 +114,25 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
return std::make_pair(collapsed_shape, collapsed_strides);
}
std::pair<Shape, Strides> collapse_contiguous_dims(
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<int64_t>& strides,
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl<int64_t>(shape, strides, size_cap);
}
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl<size_t>(shape, strides, size_cap);
}
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const array& a,
int64_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
size_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
return collapse_contiguous_dims_impl<size_t>(
a.shape(), a.strides(), size_cap);
}
} // namespace mlx::core

View File

@@ -8,9 +8,12 @@
namespace mlx::core {
inline int64_t
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
int64_t loc = 0;
template <typename StrideT>
inline StrideT elem_to_loc(
int elem,
const std::vector<int>& shape,
const std::vector<StrideT>& strides) {
StrideT loc = 0;
for (int i = shape.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(elem, shape[i]);
loc += q_and_r.rem * strides[i];
@@ -19,15 +22,16 @@ elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
return loc;
}
inline int64_t elem_to_loc(int elem, const array& a) {
inline size_t elem_to_loc(int elem, const array& a) {
if (a.flags().row_contiguous) {
return elem;
}
return elem_to_loc(elem, a.shape(), a.strides());
}
inline Strides make_contiguous_strides(const Shape& shape) {
Strides strides(shape.size(), 1);
template <typename StrideT>
std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) {
std::vector<StrideT> strides(shape.size(), 1);
for (int i = shape.size() - 1; i > 0; i--) {
strides[i - 1] = strides[i] * shape[i];
}
@@ -40,15 +44,22 @@ inline Strides make_contiguous_strides(const Shape& shape) {
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
const Shape& shape,
const std::vector<Strides>& strides,
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<int64_t>>& strides,
int64_t size_cap = std::numeric_limits<int32_t>::max());
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>>& strides,
size_t size_cap = std::numeric_limits<int32_t>::max());
inline std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<array>& xs,
size_t size_cap = std::numeric_limits<int32_t>::max()) {
std::vector<Strides> strides;
std::vector<std::vector<size_t>> strides;
for (auto& x : xs) {
strides.emplace_back(x.strides());
}
@@ -62,14 +73,19 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
}
// The single array version of the above.
std::pair<Shape, Strides> collapse_contiguous_dims(
const Shape& shape,
const Strides& strides,
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<int64_t>& strides,
int64_t size_cap = std::numeric_limits<int32_t>::max());
std::pair<Shape, Strides> collapse_contiguous_dims(
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t size_cap = std::numeric_limits<int32_t>::max());
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const array& a,
int64_t size_cap = std::numeric_limits<int32_t>::max());
size_t size_cap = std::numeric_limits<int32_t>::max());
template <typename StrideT>
struct ContiguousIterator {
inline void step() {
int dims = shape_.size();
@@ -86,7 +102,7 @@ struct ContiguousIterator {
loc += strides_[i];
}
void seek(int64_t n) {
void seek(StrideT n) {
loc = 0;
for (int i = shape_.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(n, shape_[i]);
@@ -112,29 +128,32 @@ struct ContiguousIterator {
}
explicit ContiguousIterator(
const Shape& shape,
const Strides& strides,
const std::vector<int>& shape,
const std::vector<StrideT>& strides,
int dims)
: shape_(shape.begin(), shape.begin() + dims),
strides_(strides.begin(), strides.begin() + dims) {
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = Shape(shape_.size(), 0);
pos_ = std::vector<int>(shape_.size(), 0);
}
}
int64_t loc{0};
StrideT loc{0};
private:
Shape shape_;
Strides strides_;
Shape pos_;
std::vector<int> shape_;
std::vector<StrideT> strides_;
std::vector<int> pos_;
};
inline auto check_contiguity(const Shape& shape, const Strides& strides) {
template <typename StrideT>
inline auto check_contiguity(
const std::vector<int>& shape,
const std::vector<StrideT>& strides) {
size_t no_broadcast_data_size = 1;
int64_t f_stride = 1;
int64_t b_stride = 1;
size_t f_stride = 1;
size_t b_stride = 1;
bool is_row_contiguous = true;
bool is_col_contiguous = true;
@@ -159,19 +178,4 @@ inline bool is_donatable(const array& in, const array& out) {
in.buffer_size() <= out.nbytes() + donation_extra;
}
void move_or_copy(const array& in, array& out);
void move_or_copy(
const array& in,
array& out,
const Strides& strides,
array::Flags flags,
size_t data_size,
size_t offset = 0);
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
void shared_buffer_reshape(
const array& in,
const Strides& out_strides,
array& out);
} // namespace mlx::core

View File

@@ -14,21 +14,14 @@ function(make_jit_source SRC_FILE)
COMMAND
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}
${SRC_FILE}
${SRC_FILE} "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN})
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
add_dependencies(mlx ${SRC_NAME})
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
endfunction(make_jit_source)
make_jit_source(
utils
kernels/jit/bf16.h
kernels/metal_3_0/bf16.h
kernels/metal_3_1/bf16.h
kernels/bf16_math.h
kernels/complex.h
kernels/defines.h)
make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h)
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)

View File

@@ -30,7 +30,7 @@ BufferCache::BufferCache(MTL::Device* device)
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
BufferCache::~BufferCache() {
auto pool = metal::new_scoped_memory_pool();
auto thread_pool = metal::new_scoped_memory_pool();
clear();
}
@@ -155,13 +155,11 @@ MetalAllocator::MetalAllocator()
}
size_t MetalAllocator::set_cache_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(limit, max_pool_size_);
return limit;
};
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
std::unique_lock lk(mutex_);
std::swap(limit, block_limit_);
relaxed_ = relaxed;
gc_limit_ = std::min(
@@ -171,7 +169,6 @@ size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
};
size_t MetalAllocator::set_wired_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(limit, wired_limit_);
residency_set_.resize(wired_limit_);
return limit;
@@ -208,7 +205,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
return Buffer{nullptr};
}
auto pool = metal::new_scoped_memory_pool();
auto thread_pool = metal::new_scoped_memory_pool();
// If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache
@@ -229,7 +226,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Maintain the cache below the requested limit
if (get_cache_memory() >= max_pool_size_) {
auto pool = metal::new_scoped_memory_pool();
auto thread_pool = metal::new_scoped_memory_pool();
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
@@ -240,15 +237,11 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
void MetalAllocator::clear_cache() {
std::unique_lock lk(mutex_);
auto pool = metal::new_scoped_memory_pool();
buffer_cache_.clear();
}
void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
if (buf == nullptr) {
return;
}
std::unique_lock lk(mutex_);
residency_set_.erase(buf);
active_memory_ -= buf->length();
@@ -256,7 +249,7 @@ void MetalAllocator::free(Buffer buffer) {
buffer_cache_.recycle_to_cache(buf);
} else {
lk.unlock();
auto pool = metal::new_scoped_memory_pool();
auto thread_pool = metal::new_scoped_memory_pool();
buf->release();
}
}

View File

@@ -22,37 +22,37 @@ std::string get_kernel_name(
BinaryOpType bopt,
const std::string& op,
const array& a,
bool large,
bool use_2d,
int ndim,
int work_per_thread) {
std::string kname;
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname = "ss";
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname = (large ? "sv2" : "sv");
kname << (use_2d ? "sv2" : "sv");
break;
case BinaryOpType::VectorScalar:
kname = (large ? "vs2" : "vs");
kname << (use_2d ? "vs2" : "vs");
break;
case BinaryOpType::VectorVector:
kname = (large ? "vv2" : "vv");
kname << (use_2d ? "vv2" : "vv");
break;
case BinaryOpType::General:
kname = "g";
kname << "g";
if (ndim <= 3) {
kname += std::to_string(ndim);
kname << ndim;
} else {
concatenate(kname, "n", std::to_string(work_per_thread));
}
if (large) {
kname += "large";
kname << "n";
if (work_per_thread > 1) {
kname << work_per_thread;
}
}
break;
}
concatenate(kname, "_", op, type_to_name(a));
return kname;
kname << "_" << op << type_to_name(a);
return kname.str();
}
void binary_op_gpu_inplace(
@@ -75,32 +75,24 @@ void binary_op_gpu_inplace(
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
return std::make_tuple(shape, strides[0], strides[1], strides[2]);
} else {
decltype(a.strides()) e{};
return std::make_tuple(decltype(a.shape()){}, e, e, e);
std::vector<size_t> e;
return std::make_tuple(std::vector<int>{}, e, e, e);
}
};
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
bool large;
bool use_2d = out.data_size() > UINT32_MAX;
auto ndim = shape.size();
int work_per_thread;
if (bopt == BinaryOpType::General) {
large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
out.size() > INT32_MAX;
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > UINT32_MAX;
work_per_thread = 1;
}
int work_per_thread = (bopt == BinaryOpType::General) ? 4 : 1;
std::string kernel_name =
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
auto& d = metal::device(s.device);
auto kernel = outputs.size() == 2
? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
: get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated
@@ -125,15 +117,19 @@ void binary_op_gpu_inplace(
size_t rest = out.size() / (dim0 * dim1);
if (ndim > 3) {
compute_encoder.set_vector_bytes(shape, arg_idx++);
compute_encoder.set_vector_bytes(strides_a, arg_idx++);
compute_encoder.set_vector_bytes(strides_b, arg_idx++);
compute_encoder.set_bytes<int>(ndim, arg_idx++);
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
compute_encoder->setBytes(
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder.set_vector_bytes(strides_a, arg_idx++);
compute_encoder.set_vector_bytes(strides_b, arg_idx++);
compute_encoder->setBytes(
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
}
if (thread_group_size != 1024) {
@@ -141,7 +137,7 @@ void binary_op_gpu_inplace(
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
@@ -149,9 +145,9 @@ void binary_op_gpu_inplace(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}

View File

@@ -1,6 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <fmt/format.h>
#include <iostream> //TODO
#include <sstream>
#include "mlx/backend/common/compiled.h"
@@ -12,12 +11,12 @@
#include "mlx/primitives.h"
#include "mlx/utils.h"
using namespace fmt::literals;
namespace mlx::core {
constexpr int WORK_PER_THREAD = 4;
inline void build_kernel(
std::string& os,
std::ostream& os,
const std::string& kernel_name,
const std::vector<array>& inputs,
const std::vector<array>& outputs,
@@ -42,8 +41,8 @@ inline void build_kernel(
int cnt = 0;
// Start the kernel
os += fmt::format(
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
os << "[[host_name(\"" << kernel_name << "\")]]\n"
<< "[[kernel]] void " << kernel_name << "(\n";
// Add the input arguments
for (auto& x : inputs) {
@@ -55,61 +54,51 @@ inline void build_kernel(
}
// Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) {
if (is_scalar(x) || contiguous) {
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]],\n";
} else {
add_indices = true;
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]],\n";
}
os += fmt::format(
" device const {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()),
xname,
cnt++);
}
if (add_indices) {
os += fmt::format(
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
os << " constant const size_t* in_strides [[buffer(" << cnt++
<< ")]],\n";
}
// Add the output arguments
for (auto& x : outputs) {
os += fmt::format(
" device {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()),
namer.get_name(x),
cnt++);
os << " device " << get_type_string(x.dtype()) << "* "
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]],\n";
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os += fmt::format(
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
os += fmt::format(
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
os << " constant const size_t* output_strides [[buffer(" << cnt++
<< ")]],\n"
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],\n";
}
if (dynamic_dims) {
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],\n";
}
// The thread index in the whole grid
os += " uint3 pos [[thread_position_in_grid]],\n";
os += " uint3 grid [[threads_per_grid]]) {\n";
os << " uint3 pos [[thread_position_in_grid]],\n"
<< " uint3 grid [[threads_per_grid]]) {\n";
std::string idx_type = use_big_index ? "int64_t" : "uint";
if (contiguous && use_big_index) {
if (use_big_index) {
// This is only used for contiguous kernels which don't have
// a third grid dimension
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
} else if (work_per_thread > 1) {
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
os += fmt::format(
" int xshape = output_shape[{0}];\n",
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
os += fmt::format(
" {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n"
<< " int xshape = output_shape["
<< (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
<< " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
} else {
os += fmt::format(
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
}
// Read constant / contiguous inputs in tmps
@@ -120,19 +109,16 @@ inline void build_kernel(
if (is_constant(x)) {
auto type_str = get_type_string(x.dtype());
std::ostringstream ss;
print_constant(ss, x);
os += fmt::format(
" auto tmp_{0} = static_cast<{1}>({2});\n",
xname,
get_type_string(x.dtype()),
ss.str());
os << " auto tmp_" << xname << " = static_cast<"
<< get_type_string(x.dtype()) << ">(";
print_constant(os, x);
os << ");\n";
} else if (is_scalar(x)) {
os += fmt::format(
" {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];\n";
} else if (contiguous) {
os += fmt::format(
" {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index];\n";
} else {
nc_inputs.push_back(x);
}
@@ -141,96 +127,83 @@ inline void build_kernel(
// Initialize the indices for non-contiguous inputs
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
if (ndim == 1) {
int offset = i * ndim;
os +=
fmt::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
os << " size_t index_" << xname << " = elem_to_loc_1(pos.x, "
<< "in_strides[" << offset << "]);\n";
} else if (ndim == 2) {
int offset = i * ndim;
os += fmt::format(
"elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n",
idx_type,
offset);
os << " size_t index_" << xname << " = elem_to_loc_2({pos.x, pos.y}, "
<< "in_strides + " << offset << ");\n";
} else if (ndim == 3) {
int offset = i * ndim;
os += fmt::format(
"elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset);
os << " size_t index_" << xname << " = elem_to_loc_3(pos, "
<< "in_strides + " << offset << ");\n";
} else if (!dynamic_dims) {
int offset = (i + 1) * ndim;
os += fmt::format(
"N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
idx_type,
offset - 1,
offset - 2);
int offset = i * ndim;
os << " size_t index_" << xname << " = N_ * pos.x * in_strides["
<< offset + ndim - 1 << "]"
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
} else {
os += fmt::format(
"N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
idx_type,
i);
os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * "
<< i << " + ndim - 1]"
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
}
}
if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
os += " uint zpos = pos.z;\n";
os << " uint zpos = pos.z;\n";
if (dynamic_dims) {
os += " for (int d = ndim - 3; d >= 0; --d) {\n";
os << " for (int d = ndim - 3; d >= 0; --d) {\n";
} else {
os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
os << " for (int d = " << ndim - 3 << "; d >= 0; --d) {\n";
}
os += " uint l = zpos % output_shape[d];\n";
os << " uint l = zpos % output_shape[d];\n";
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" index_{0} += ", xname);
os << " index_" << xname << " += ";
if (dynamic_dims) {
os +=
fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
os << "l * in_strides[" << i << " * ndim + d];\n";
} else {
os +=
fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
os << "l * in_strides[" << i * ndim << " + d];\n";
}
}
os += " zpos /= output_shape[d];\n }\n";
os << " zpos /= output_shape[d];\n }\n";
}
// Open per-thread loop
if (work_per_thread > 1) {
os +=
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
}
// Read non-contiguous inputs into tmps
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
os += fmt::format(
" {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index_" << xname << "];\n";
}
// Actually write the computation
for (auto& x : tape) {
os += fmt::format(
" {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
<< " = ";
if (is_static_cast(x.primitive())) {
os += fmt::format(
"static_cast<{0}>(tmp_{1});\n",
get_type_string(x.dtype()),
namer.get_name(x.inputs()[0]));
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");\n";
} else {
std::ostringstream ss;
x.primitive().print(ss);
os += ss.str();
os += "()(";
x.primitive().print(os);
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
}
os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
os << "tmp_" << namer.get_name(x.inputs().back()) << ");\n";
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
<< ";\n";
}
// Increment indices and close per thread loop
if (work_per_thread > 1) {
@@ -238,18 +211,18 @@ inline void build_kernel(
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
if (!dynamic_dims) {
os += fmt::format(
" index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
os << " index_" << xname << " += "
<< "in_strides[" << i * ndim + ndim - 1 << "];\n";
} else {
os += fmt::format(
" index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
os << " index_" << xname << " += "
<< "in_strides[" << i << " * ndim + ndim - 1];\n";
}
}
os += " index++;\n }\n";
os << " index++;\n }\n";
}
// Finish the kernel
os += "}\n";
os << "}\n";
if (cnt > 31) {
std::ostringstream msg;
@@ -273,9 +246,9 @@ void Compiled::eval_gpu(
auto& s = stream();
auto& d = metal::device(s.device);
auto lib = d.get_library(kernel_lib_, [&]() {
std::string kernel = metal::utils();
concatenate(
kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
std::ostringstream kernel;
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
<< metal::ternary_ops();
build_kernel(
kernel,
kernel_lib_ + "_contiguous",
@@ -288,7 +261,7 @@ void Compiled::eval_gpu(
/* dynamic_dims = */ false);
build_kernel(
kernel,
kernel_lib_ + "_contiguous_large",
kernel_lib_ + "_contiguous_big",
inputs_,
outputs_,
tape_,
@@ -309,21 +282,7 @@ void Compiled::eval_gpu(
/* ndim = */ i,
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ i > 3 ? 2 : 1);
if (i > 1) {
build_kernel(
kernel,
kernel_lib_ + "_strided_" + std::to_string(i) + "_large",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ i,
/* dynamic_dims = */ false,
/* use_big_index = */ true,
/* work_per_thread = */ i > 3 ? 4 : 1);
}
/* work_per_thread = */ i > 3 ? WORK_PER_THREAD : 1);
}
build_kernel(
kernel,
@@ -336,32 +295,20 @@ void Compiled::eval_gpu(
/* ndim = */ 0,
/* dynamic_dims = */ true,
/* use_big_index = */ false,
/* work_per_thread = */ 2);
build_kernel(
kernel,
kernel_lib_ + "_strided_dynamic_large",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ 0,
/* dynamic_dims = */ true,
/* use_big_index = */ true,
/* work_per_thread = */ 4);
return kernel;
/* work_per_thread = */ WORK_PER_THREAD);
return kernel.str();
});
// Figure out which kernel we are using
auto& output_shape = outputs[0].shape();
auto contiguous = compiled_check_contiguity(inputs, output_shape);
bool contiguous = compiled_check_contiguity(inputs, output_shape);
// Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
std::vector<Strides> initial_strides;
std::vector<std::vector<size_t>> initial_strides;
initial_strides.push_back(outputs[0].strides());
Shape shape;
std::vector<Strides> strides;
std::vector<int> shape;
std::vector<std::vector<size_t>> strides;
if (!contiguous) {
for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
@@ -376,7 +323,7 @@ void Compiled::eval_gpu(
}
// Broadcast the inputs to the output shape.
Strides xstrides;
std::vector<size_t> xstrides;
int j = 0;
for (; j < output_shape.size() - x.ndim(); j++) {
if (output_shape[j] == 1) {
@@ -402,19 +349,13 @@ void Compiled::eval_gpu(
collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
}
bool large;
bool use_2d = false;
if (contiguous) {
size_t max_size = 0;
for (auto& in : inputs) {
max_size = std::max(max_size, in.data_size());
}
large = (max_size > UINT32_MAX);
} else {
size_t max_size = 0;
for (auto& o : outputs) {
max_size = std::max(max_size, o.size());
}
large = (max_size > UINT32_MAX);
use_2d = (max_size > UINT32_MAX);
}
// Get the kernel from the lib
@@ -427,18 +368,17 @@ void Compiled::eval_gpu(
} else {
kernel_name += std::to_string(shape.size());
}
}
if (large) {
kernel_name += "_large";
} else if (use_2d) {
kernel_name += "_big";
}
auto kernel = d.get_kernel(kernel_name, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Put the inputs in
int cnt = 0;
int stride_idx = 1; // idx 0 is the output strides
Strides in_strides;
std::vector<size_t> in_strides;
for (int i = 0; i < inputs.size(); i++) {
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
@@ -454,7 +394,8 @@ void Compiled::eval_gpu(
}
}
if (!in_strides.empty()) {
compute_encoder.set_vector_bytes(in_strides, cnt++);
compute_encoder->setBytes(
in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
}
compiled_allocate_outputs(
@@ -467,13 +408,14 @@ void Compiled::eval_gpu(
// Put the output shape and strides in
if (!contiguous) {
compute_encoder.set_vector_bytes(strides[0], cnt++);
compute_encoder.set_vector_bytes(shape, cnt++);
compute_encoder->setBytes(
strides[0].data(), strides[0].size() * sizeof(size_t), cnt++);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++);
}
// Put the number of dims in if it is dynamic
if (dynamic) {
compute_encoder.set_bytes(ndim, cnt++);
compute_encoder->setBytes(&ndim, sizeof(int), cnt++);
}
// Launch the kernel
@@ -482,15 +424,15 @@ void Compiled::eval_gpu(
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
MTL::Size grid_dims = large
MTL::Size grid_dims = use_2d
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = outputs[0].size() / (dim0 * dim1);
int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
int pow2;
@@ -503,7 +445,7 @@ void Compiled::eval_gpu(
}
auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}

View File

@@ -44,28 +44,27 @@ void explicit_gemm_conv_ND_gpu(
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder.set_bytes(conv_params, 2);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
// Launch unfolding kernel
size_t tgp_x = std::min(conv_params.C, 64);
int tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
size_t tgp_y = 256 / tgp_x;
int tgp_y = 256 / tgp_x;
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
MTL::Size group_dims = MTL::Size(
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Reshape weight
Shape wt_reshape{implicit_K, implicit_N};
Strides wt_restride{1, implicit_K};
std::vector<int> wt_reshape{implicit_K, implicit_N};
std::vector<size_t> wt_restride{1, static_cast<size_t>(implicit_K)};
array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
auto wt_flags = wt.flags();
wt_flags.row_contiguous = false;
@@ -123,31 +122,33 @@ void explicit_gemm_conv_group_ND_gpu(
<< N;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder.set_bytes(conv_params, 2);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
// Launch unfolding kernel
size_t tgp_x = std::min(conv_params.C, 64);
int tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
size_t tgp_y = 256 / tgp_x;
int tgp_y = 256 / tgp_x;
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
MTL::Size group_dims = MTL::Size(
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Transpose kernel weights so that we can slice them by contiguous chunks
// of channel groups.
array wt_view(
{wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
wt_view.copy_shared_buffer(
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
wt,
{wt.strides(0), 1, static_cast<size_t>(C_per_group)},
wt.flags(),
wt.size());
// Materialize
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
@@ -236,7 +237,7 @@ void slow_conv_2D_gpu(
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
@@ -251,8 +252,8 @@ void slow_conv_2D_gpu(
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(conv_params, 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_gpu(
@@ -351,7 +352,7 @@ void implicit_gemm_conv_2D_gpu(
wn,
n_channel_specialization,
small_filter);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions
int tile = 1 << swizzle_log;
@@ -367,11 +368,11 @@ void implicit_gemm_conv_2D_gpu(
compute_encoder.set_output_array(out, 2);
// Encode params
compute_encoder.set_bytes(conv_params, 3);
compute_encoder.set_bytes(gemm_params, 4);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
// Launch kernel
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_general_gpu(
@@ -505,7 +506,7 @@ void implicit_gemm_conv_2D_general_gpu(
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel =
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions
int tile = 1 << swizzle_log;
@@ -522,15 +523,17 @@ void implicit_gemm_conv_2D_general_gpu(
compute_encoder.set_output_array(out, 2);
// Encode params
compute_encoder.set_bytes(conv_params, 3);
compute_encoder.set_bytes(gemm_params, 4);
compute_encoder.set_bytes(jump_params, 5);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5);
compute_encoder.set_vector_bytes(base_h, 6);
compute_encoder.set_vector_bytes(base_w, 7);
compute_encoder->setBytes(
base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6);
compute_encoder->setBytes(
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
// Launch kernel
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void winograd_conv_2D_gpu(
@@ -619,18 +622,18 @@ void winograd_conv_2D_gpu(
<< bc;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(wt, 0);
compute_encoder.set_output_array(filt_wg, 1);
compute_encoder.set_bytes(C_c, 2);
compute_encoder.set_bytes(O_c, 3);
compute_encoder->setBytes(&C_c, sizeof(int), 2);
compute_encoder->setBytes(&O_c, sizeof(int), 3);
MTL::Size group_dims = MTL::Size(32, bo, 1);
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do input transform
@@ -647,17 +650,18 @@ void winograd_conv_2D_gpu(
<< bc;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in_padded, 0);
compute_encoder.set_output_array(inp_wg, 1);
compute_encoder.set_bytes(conv_params_updated, 2);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do batched gemm
@@ -694,17 +698,18 @@ void winograd_conv_2D_gpu(
<< bc;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(out_wg, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(conv_params_updated, 2);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}

View File

@@ -43,12 +43,13 @@ void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}
template <typename stride_t>
void copy_gpu_inplace(
const array& in,
array& out,
const Shape& data_shape,
const Strides& strides_in_pre,
const Strides& strides_out_pre,
const std::vector<int>& data_shape,
const std::vector<stride_t>& strides_in_pre,
const std::vector<stride_t>& strides_out_pre,
int64_t inp_offset,
int64_t out_offset,
CopyType ctype,
@@ -67,52 +68,50 @@ void copy_gpu_inplace(
/* size_cap = */ INT32_MAX);
return std::make_tuple(shape, strides[0], strides[1]);
} else {
Strides e{};
return std::make_tuple(Shape{}, e, e);
std::vector<stride_t> e;
return std::make_tuple(std::vector<int>{}, e, e);
}
};
auto [shape, strides_in_, strides_out_] = maybe_collapse();
int ndim = shape.size();
bool large;
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
// Allow for negative strides
large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
} else {
large = out.data_size() > UINT32_MAX;
}
bool use_2d = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);
int work_per_thread = 1;
std::string kernel_name;
switch (ctype) {
case CopyType::Scalar:
kernel_name = (large ? "s2" : "s");
break;
case CopyType::Vector:
kernel_name = (large ? "v2" : "v");
break;
case CopyType::General:
kernel_name = "g";
break;
case CopyType::GeneralGeneral:
kernel_name = "gg";
break;
}
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kernel_name += std::to_string(shape.size());
} else {
work_per_thread = large ? 4 : 2;
concatenate(kernel_name, "n", std::to_string(work_per_thread));
{
std::ostringstream kname;
switch (ctype) {
case CopyType::Scalar:
kname << (use_2d ? "s2" : "s");
break;
case CopyType::Vector:
kname << (use_2d ? "v2" : "v");
break;
case CopyType::General:
kname << "g";
break;
case CopyType::GeneralGeneral:
kname << "gg";
break;
}
if (large) {
kernel_name += "large";
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
work_per_thread = 4;
kname << "n4";
}
}
kname << "_copy";
kname << type_to_name(in) << type_to_name(out);
kernel_name = kname.str();
}
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
auto kernel = get_copy_kernel(d, kernel_name, in, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
bool donate_in = in.data_shared_ptr() == nullptr;
inp_offset *= size_of(in.dtype());
@@ -123,26 +122,26 @@ void copy_gpu_inplace(
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
Strides strides_in{strides_in_.begin(), strides_in_.end()};
Strides strides_out{strides_out_.begin(), strides_out_.end()};
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
if (ndim > 3) {
compute_encoder.set_vector_bytes(shape, ndim, 2);
set_vector_bytes(compute_encoder, shape, ndim, 2);
}
compute_encoder.set_vector_bytes(strides_in, ndim, 3);
set_vector_bytes(compute_encoder, strides_in, ndim, 3);
if (ctype == CopyType::GeneralGeneral) {
compute_encoder.set_vector_bytes(strides_out, ndim, 4);
set_vector_bytes(compute_encoder, strides_out, ndim, 4);
}
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t data_size = 1;
for (auto& s : shape)
data_size *= s;
size_t rest = data_size / (dim0 * dim1);
int rest = data_size / (dim0 * dim1);
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
compute_encoder.set_bytes(ndim, 5);
compute_encoder->setBytes(&ndim, sizeof(int), 5);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
}
@@ -153,16 +152,16 @@ void copy_gpu_inplace(
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t nthreads = out.data_size();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
@@ -179,13 +178,14 @@ void copy_gpu_inplace(
void copy_gpu_inplace(
const array& in,
array& out,
const Strides& istride,
const std::vector<int64_t>& istride,
int64_t ioffset,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
return copy_gpu_inplace(
in, out, in.shape(), istride, out.strides(), ioffset, 0, ctype, s);
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
}
void fill_gpu(const array& val, array& out, const Stream& s) {
@@ -193,13 +193,13 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
bool large = out.data_size() > UINT32_MAX;
bool use_2d = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
std::string kernel_name = std::string(use_2d ? "s2" : "s") + "_copy" +
type_to_name(val) + type_to_name(out);
auto kernel = get_copy_kernel(d, kernel_name, val, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(val, 0);
compute_encoder.set_output_array(out, 1);
@@ -210,9 +210,9 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core

View File

@@ -8,12 +8,13 @@
namespace mlx::core {
// Generic copy inplace
template <typename stride_t>
void copy_gpu_inplace(
const array& in,
array& out,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype,
@@ -31,7 +32,7 @@ void copy_gpu_inplace(
void copy_gpu_inplace(
const array& in,
array& out,
const Strides& istride,
const std::vector<int64_t>& istride,
int64_t ioffset,
CopyType ctype,
const Stream& s);

View File

@@ -43,7 +43,7 @@ void CustomKernel::eval_gpu(
d.get_library(lib_name, [this] { return metal::utils() + source_; });
auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
int index = 0;
for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i];
@@ -53,15 +53,15 @@ void CustomKernel::eval_gpu(
if (in.ndim() > 0) {
int ndim = in.ndim();
if (shape_info.shape) {
compute_encoder.set_vector_bytes(in.shape(), ndim, index);
set_vector_bytes(compute_encoder, in.shape(), ndim, index);
index++;
}
if (shape_info.strides) {
compute_encoder.set_vector_bytes(in.strides(), ndim, index);
set_vector_bytes(compute_encoder, in.strides(), ndim, index);
index++;
}
if (shape_info.ndim) {
compute_encoder.set_bytes(ndim, index);
compute_encoder->setBytes(&ndim, sizeof(int), index);
index++;
}
}
@@ -72,11 +72,10 @@ void CustomKernel::eval_gpu(
}
const auto [tx, ty, tz] = threadgroup_;
MTL::Size group_dims = MTL::Size(tx, ty, tz);
const auto [gx, gy, gz] = grid_;
MTL::Size group_dims =
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
}

View File

@@ -23,18 +23,14 @@ constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr const char* default_mtllib_path = METAL_PATH;
auto get_metal_version() {
auto get_metal_version_ = []() {
if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {
return MTL::LanguageVersion3_2;
} else if (__builtin_available(macOS 14, iOS 17, tvOS 17, visionOS 1, *)) {
return MTL::LanguageVersion3_1;
} else {
return MTL::LanguageVersion3_0;
}
};
static auto metal_version_ = get_metal_version_();
return metal_version_;
constexpr auto get_metal_version() {
#if (MLX_METAL_VERSION >= 320)
return MTL::LanguageVersion3_2;
#elif (MLX_METAL_VERSION >= 310)
return MTL::LanguageVersion3_1;
#else
return MTL::LanguageVersion3_0;
#endif
}
auto load_device() {
@@ -175,14 +171,14 @@ void CommandEncoder::maybeInsertBarrier() {
next_outputs_.clear();
}
void CommandEncoder::dispatch_threadgroups(
void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
enc_->dispatchThreadgroups(grid_dims, group_dims);
}
void CommandEncoder::dispatch_threads(
void CommandEncoder::dispatchThreads(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
@@ -280,7 +276,7 @@ void Device::end_encoding(int index) {
// - Update the map of outputs to include this command encoder's outputs.
// - Always signal this command encoders fence.
// - Add a completion handler for this command encoder that removes outputs
// from the map to limit the growth of the map and avoid unnecessary waits
// from the map to limit the growth of the map and avoid unecessary waits
// - Temporaries are a special case as they do not cross command encoder
// boundaries. These can be removed early from the encoders inputs and
// outputs since they don't need synchronization.
@@ -302,7 +298,7 @@ void Device::end_encoding(int index) {
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
// If we've already waited on a fence, don't wait on it again.
if (waiting_on.find(it->second) == waiting_on.end()) {
enc.wait_for_fence(it->second->fence);
enc->waitForFence(it->second->fence);
waiting_on.insert(it->second);
}
}
@@ -311,7 +307,7 @@ void Device::end_encoding(int index) {
stream.outputs[out] = stream.fence;
}
}
enc.update_fence(stream.fence->fence);
enc->updateFence(stream.fence->fence);
stream.buffer->addCompletedHandler(
[&stream,
waiting_on = std::move(waiting_on),
@@ -645,27 +641,21 @@ void new_stream(Stream stream) {
std::unordered_map<std::string, std::variant<std::string, size_t>>
device_info() {
auto init_device_info = []()
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
auto pool = new_scoped_memory_pool();
auto raw_device = device(default_device()).mtl_device();
auto arch = std::string(raw_device->architecture()->name()->utf8String());
auto raw_device = device(default_device()).mtl_device();
auto arch = std::string(raw_device->architecture()->name()->utf8String());
int mib[] = {CTL_HW, HW_MEMSIZE};
size_t memsize = 0;
size_t length = sizeof(memsize);
int mib[] = {CTL_HW, HW_MEMSIZE};
size_t memsize = 0;
size_t length = sizeof(memsize);
sysctl(mib, 2, &memsize, &length, NULL, 0);
sysctl(mib, 2, &memsize, &length, NULL, 0);
return {
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize}};
};
static auto device_info_ = init_device_info();
return device_info_;
return {
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize}};
}
} // namespace mlx::core::metal

View File

@@ -58,43 +58,16 @@ struct CommandEncoder {
CommandEncoder& enc;
};
MTL::ComputeCommandEncoder* operator->() {
return enc_;
}
void set_input_array(const array& a, int idx, int64_t offset = 0);
void set_output_array(array& a, int idx, int64_t offset = 0);
void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
void maybeInsertBarrier();
void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
enc_->setComputePipelineState(kernel);
}
void wait_for_fence(MTL::Fence* fence) {
enc_->waitForFence(fence);
}
void update_fence(MTL::Fence* fence) {
enc_->updateFence(fence);
}
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, int idx) {
return set_vector_bytes(vec, vec.size(), idx);
}
template <typename T>
void set_bytes(const T* v, int n, int idx) {
return enc_->setBytes(v, n * sizeof(T), idx);
}
template <typename T>
void set_bytes(const T& v, int idx) {
return enc_->setBytes(&v, sizeof(T), idx);
}
ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
}

View File

@@ -363,7 +363,7 @@ void multi_upload_bluestein_fft(
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
// Broadcast w_q and w_k to the batch size
Strides b_strides(in.ndim(), 0);
std::vector<size_t> b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast({}, complex64, nullptr, {});
array w_q_broadcast({}, complex64, nullptr, {});
@@ -386,8 +386,8 @@ void multi_upload_bluestein_fft(
copies.push_back(slice_temp);
copies.push_back(conj_temp);
Shape rstarts(in.ndim(), 0);
Shape rstrides(in.ndim(), 1);
std::vector<int> rstarts(in.ndim(), 0);
std::vector<int> rstrides(in.ndim(), 1);
rstarts[axis] = in.shape(axis) - back_offset;
rstrides[axis] = -1;
unary_op_gpu({in}, conj_temp, "Conjugate", s);
@@ -431,19 +431,19 @@ void multi_upload_bluestein_fft(
s);
int offset = plan.bluestein_n - (2 * n - 1);
Shape starts(in.ndim(), 0);
Shape strides(in.ndim(), 1);
std::vector<int> starts(in.ndim(), 0);
std::vector<int> strides(in.ndim(), 1);
starts[axis] = plan.bluestein_n - offset - n;
slice_gpu(pad_temp1, temp, starts, strides, s);
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
if (real && !inverse) {
Shape rstarts(in.ndim(), 0);
Shape rstrides(in.ndim(), 1);
std::vector<int> rstarts(in.ndim(), 0);
std::vector<int> rstrides(in.ndim(), 1);
slice_gpu(temp1, out, rstarts, strides, s);
} else if (real && inverse) {
Strides b_strides(in.ndim(), 0);
std::vector<size_t> b_strides(in.ndim(), 0);
auto inv_n = array({1.0f / n}, {1}, float32);
array temp_float(out.shape(), out.dtype(), nullptr, {});
copies.push_back(temp_float);
@@ -531,8 +531,8 @@ void fft_op(
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
Strides strides;
int64_t cur_stride = x.shape(axis);
std::vector<size_t> strides;
size_t cur_stride = x.shape(axis);
for (int a = 0; a < x.ndim(); a++) {
if (a == axis) {
strides.push_back(1);
@@ -699,7 +699,7 @@ void fft_op(
auto kernel =
get_fft_kernel(d, base_name, hash_name, func_consts, template_def);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in_contiguous, 0);
compute_encoder.set_output_array(out, 1);
@@ -711,9 +711,9 @@ void fft_op(
compute_encoder.set_input_array(w_q, 2); // w_q
compute_encoder.set_input_array(w_k, 3); // w_k
compute_encoder.set_bytes(n, 4);
compute_encoder.set_bytes(plan.bluestein_n, 5);
compute_encoder.set_bytes(total_batch_size, 6);
compute_encoder->setBytes(&n, sizeof(int), 4);
compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
} else if (plan.rader_n > 1) {
auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
copies.push_back(b_q);
@@ -723,22 +723,22 @@ void fft_op(
compute_encoder.set_input_array(b_q, 2);
compute_encoder.set_input_array(g_q, 3);
compute_encoder.set_input_array(g_minus_q, 4);
compute_encoder.set_bytes(n, 5);
compute_encoder.set_bytes(total_batch_size, 6);
compute_encoder.set_bytes(plan.rader_n, 7);
compute_encoder->setBytes(&n, sizeof(int), 5);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
} else if (four_step_params.required) {
compute_encoder.set_bytes(four_step_params.n1, 2);
compute_encoder.set_bytes(four_step_params.n2, 3);
compute_encoder.set_bytes(total_batch_size, 4);
compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
} else {
compute_encoder.set_bytes(n, 2);
compute_encoder.set_bytes(total_batch_size, 3);
compute_encoder->setBytes(&n, sizeof(int), 2);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
}
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
auto grid_dims =
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);
@@ -777,7 +777,7 @@ void nd_fft_op(
// Mirror np.fft.(i)rfftn and perform a real transform
// only on the final axis.
bool step_real = (real && index == axes.size() - 1);
auto step_shape = inverse ? out.shape(axis) : in.shape(axis);
int step_shape = inverse ? out.shape(axis) : in.shape(axis);
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);

View File

@@ -137,14 +137,14 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(scale, 2);
compute_encoder->setBytes(&scale, sizeof(float), 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
};
if (m > 1) {

View File

@@ -53,31 +53,27 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
bool large_index = nidx && inputs[1].size() > INT32_MAX;
bool large_src = src.size() > INT32_MAX;
bool large_out = out.size() > INT32_MAX;
bool large = large_index || large_src || large_out;
std::string lib_name;
std::string kernel_name;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
std::string kernel_name = fmt::format(
"gather{0}{1}_{2}_{3}_{4}",
type_to_name(out),
idx_type_name,
nidx,
idx_ndim,
large ? "int64_t" : "int");
std::string lib_name = kernel_name;
{
std::ostringstream kname;
kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
<< "_" << idx_ndim;
lib_name = kname.str();
kernel_name = lib_name;
}
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
kernel_source += metal::gather();
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gather();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str =
nidx ? get_type_string(inputs[1].dtype()) : "bool";
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
// Index dimension specializations
kernel_source += fmt::format(
kernel_source << fmt::format(
gather_kernels,
type_to_name(out) + idx_type_name,
out_type_str,
@@ -85,14 +81,13 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
nidx,
idx_args,
idx_arr,
idx_ndim,
large ? "int64_t" : "int");
return kernel_source;
idx_ndim);
return kernel_source.str();
});
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
size_t slice_size = 1;
for (auto s : slice_sizes_) {
@@ -136,20 +131,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 1);
// Set source info
compute_encoder.set_vector_bytes(src.shape(), 2);
compute_encoder.set_vector_bytes(src.strides(), 3);
compute_encoder.set_bytes(ndim, 4);
compute_encoder.set_vector_bytes(slice_sizes_, 5);
compute_encoder.set_vector_bytes(axes_, 6);
set_vector_bytes(compute_encoder, src.shape(), 2);
set_vector_bytes(compute_encoder, src.strides(), 3);
compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
set_vector_bytes(compute_encoder, slice_sizes_, 5);
set_vector_bytes(compute_encoder, axes_, 6);
// Set index info
//
// We don't need to check for empty idx_shapes because gather has a
// idx_ndim == 0 specialization
compute_encoder.set_vector_bytes(idx_shapes, 7);
compute_encoder.set_vector_bytes(idx_strides, 8);
compute_encoder.set_vector_bytes(idx_contigs, 9);
compute_encoder.set_bytes(idx_ndim, 10);
set_vector_bytes(compute_encoder, idx_shapes, 7);
set_vector_bytes(compute_encoder, idx_strides, 8);
set_vector_bytes(compute_encoder, idx_contigs, 9);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 10);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
@@ -157,7 +152,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
}
// Launch grid
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -214,6 +209,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
nwork = 32;
}
std::string lib_name;
std::string kernel_name;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
std::string op_name;
switch (reduce_type_) {
@@ -234,24 +231,18 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
break;
}
auto upd_contig = upd.flags().row_contiguous;
bool large_out = out.size() > INT32_MAX;
bool large_idx = nidx && (inputs[1].size() > INT32_MAX);
bool large_upd = upd.size() > INT32_MAX;
bool large = large_out || large_idx || large_upd;
std::string kernel_name = fmt::format(
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
type_to_name(out),
idx_type_name,
op_name,
nidx,
upd_contig ? "updc_true" : "updc_false",
nwork,
large ? "int64_t" : "int");
std::string lib_name = kernel_name;
{
std::ostringstream kname;
kname << "scatter" << type_to_name(out) << idx_type_name;
kname << "_" << op_name << "_" << nidx << "_"
<< (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork;
lib_name = kname.str();
kernel_name = kname.str();
}
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::reduce_utils(), metal::scatter());
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
<< metal::scatter();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str =
@@ -279,7 +270,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
}
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
kernel_source += fmt::format(
kernel_source << fmt::format(
scatter_kernels,
type_to_name(out) + idx_type_name + "_" + op_name,
out_type_str,
@@ -289,9 +280,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_args,
idx_arr,
upd_contig,
nwork,
large ? "int64_t" : "int");
return kernel_source;
nwork);
return kernel_source.str();
});
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -299,7 +289,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
size_t nthreads = upd.size();
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder->setComputePipelineState(kernel);
// Set all the buffers
compute_encoder.set_input_array(upd, 1);
@@ -312,8 +302,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
upd_size *= upd.shape(i);
}
// Collect all idx shapes and strides into one place
Shape idx_shapes;
Strides idx_strides;
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
// To access .data() use char instead of bool
// bool is 1 byte in Metal so this is safe
std::vector<char> idx_contigs;
@@ -332,30 +322,30 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
if (upd_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
int64_t stride_ = 0;
compute_encoder.set_bytes(shape_, 3);
compute_encoder.set_bytes(stride_, 4);
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else {
compute_encoder.set_vector_bytes(upd.shape(), 3);
compute_encoder.set_vector_bytes(upd.strides(), 4);
set_vector_bytes(compute_encoder, upd.shape(), 3);
set_vector_bytes(compute_encoder, upd.strides(), 4);
}
compute_encoder.set_bytes(upd_ndim, 5);
compute_encoder.set_bytes(upd_size, 6);
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
// Set output info
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
int64_t stride_ = 0;
compute_encoder.set_bytes(shape_, 7);
compute_encoder.set_bytes(stride_, 8);
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 7);
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
} else {
compute_encoder.set_vector_bytes(out.shape(), 7);
compute_encoder.set_vector_bytes(out.strides(), 8);
set_vector_bytes(compute_encoder, out.shape(), 7);
set_vector_bytes(compute_encoder, out.strides(), 8);
}
compute_encoder.set_bytes(out_ndim, 9);
compute_encoder.set_vector_bytes(axes_, 10);
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
// Set index info
if (idx_ndim == 0) {
@@ -365,11 +355,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_strides.push_back(0);
idx_contigs.push_back(false);
}
compute_encoder.set_vector_bytes(idx_shapes, 11);
compute_encoder.set_vector_bytes(idx_strides, 12);
compute_encoder.set_vector_bytes(idx_contigs, 13);
compute_encoder.set_bytes(idx_ndim, 14);
compute_encoder.set_bytes(idx_size, 15);
set_vector_bytes(compute_encoder, idx_shapes, 11);
set_vector_bytes(compute_encoder, idx_strides, 12);
set_vector_bytes(compute_encoder, idx_contigs, 13);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 14);
compute_encoder->setBytes(&idx_size, sizeof(size_t), 15);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
@@ -385,7 +375,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
}
MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core

View File

@@ -11,13 +11,13 @@ gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device {outm_t}* out_mask [[buffer(20)]],
const device {opm_t}* mat_mask [[buffer(21)]],
const device {opm_t}* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant int64_t* mask_batch_strides [[buffer(24)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],

View File

@@ -1,16 +1,16 @@
// Copyright © 2023-2024 Apple Inc.
constexpr std::string_view gather_kernels = R"(
[[kernel]] void gather{0}_{3}_{6}_{7}(
[[kernel]] void gather{0}_{3}_{6}(
const device {1}* src [[buffer(0)]],
device {1}* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant int64_t* src_strides [[buffer(3)]],
const constant size_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const constant int* idx_shapes [[buffer(7)]],
const constant int64_t* idx_strides [[buffer(8)]],
const constant size_t* idx_strides [[buffer(8)]],
const constant bool* idx_contigs [[buffer(9)]],
const constant int& idx_ndim [[buffer(10)]],
{4}
@@ -19,7 +19,7 @@ constexpr std::string_view gather_kernels = R"(
Indices<{2}, {3}> idxs{{
{{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
return gather_impl<{1}, {2}, {3}, {6}, {7}>(
return gather_impl<{1}, {2}, {3}, {6}>(
src,
out,
src_shape,
@@ -34,19 +34,19 @@ constexpr std::string_view gather_kernels = R"(
)";
constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}(
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
const constant int64_t* upd_strides [[buffer(4)]],
const constant size_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int* out_shape [[buffer(7)]],
const constant int64_t* out_strides [[buffer(8)]],
const constant size_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const constant int* idx_shapes [[buffer(11)]],
const constant int64_t* idx_strides [[buffer(12)]],
const constant size_t* idx_strides [[buffer(12)]],
const constant bool* idx_contigs [[buffer(13)]],
const constant int& idx_ndim [[buffer(14)]],
const constant size_t& idx_size [[buffer(15)]],
@@ -54,7 +54,7 @@ constexpr std::string_view scatter_kernels = R"(
uint2 gid [[thread_position_in_grid]]) {{
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>(
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}>(
updates,
out,
upd_shape,

View File

@@ -10,12 +10,12 @@ template [[host_name("{name}")]]
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant int64_t* batch_strides [[buffer(7)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
@@ -43,7 +43,7 @@ block_masked_gemm<
device {itype}* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant int64_t* batch_strides [[buffer(7)]],
const constant size_t* batch_strides [[buffer(7)]],
const device {outmasktype}* out_mask [[buffer(10)]],
const device {opmasktype}* lhs_mask [[buffer(11)]],
const device {opmasktype}* rhs_mask [[buffer(12)]],

View File

@@ -1,4 +1,5 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
@@ -45,27 +46,25 @@ MTL::ComputePipelineState* get_unary_kernel(
auto lib = d.get_library(lib_name, [&]() {
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::unary_ops(), metal::unary());
kernel_source +=
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
kernel_source +=
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source += get_template_definition(
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "int");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
return kernel_source;
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
kernel_source << get_template_definition(
"v_" + lib_name, "unary_v", in_t, out_t, op);
kernel_source << get_template_definition(
"v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source << get_template_definition(
"gn4_" + lib_name, "unary_g", in_t, out_t, op, 4);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
void append_binary_kernels(
void add_binary_kernels(
const std::string lib_name,
Dtype in_type,
Dtype out_type,
const std::string op,
std::string& kernel_source) {
std::ostringstream& kernel_source) {
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
{"ss", "binary_ss"},
{"vs", "binary_vs"},
@@ -74,27 +73,27 @@ void append_binary_kernels(
{"vs2", "binary_vs2"},
{"sv2", "binary_sv2"},
{"vv2", "binary_vv2"},
{"g1large", "binary_g_nd1"},
{"g2large", "binary_g_nd2"},
{"g3large", "binary_g_nd3"},
{"g1", "binary_g_nd1"},
{"g2", "binary_g_nd2"},
{"g3", "binary_g_nd3"},
}};
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
for (auto& [name, func] : kernel_types) {
kernel_source +=
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
std::string template_def;
template_def = get_template_definition(
name + "_" + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
kernel_source << template_def;
}
kernel_source += get_template_definition(
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
kernel_source += get_template_definition(
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "int");
kernel_source += get_template_definition(
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "int");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
kernel_source << get_template_definition(
"gn4_" + lib_name,
"binary_g",
get_type_string(in_type),
get_type_string(out_type),
op,
4);
}
MTL::ComputePipelineState* get_binary_kernel(
@@ -105,11 +104,10 @@ MTL::ComputePipelineState* get_binary_kernel(
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
kernel_source = metal::utils();
concatenate(kernel_source, metal::binary_ops(), metal::binary());
append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source;
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -122,10 +120,11 @@ MTL::ComputePipelineState* get_binary_two_kernel(
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::binary_ops(), metal::binary_two());
append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source;
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops()
<< metal::binary_two();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -137,31 +136,24 @@ MTL::ComputePipelineState* get_ternary_kernel(
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
auto t_str = get_type_string(type);
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
std::ostringstream kernel_source;
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
{"v", "ternary_v"},
{"v2", "ternary_v2"},
{"g1large", "ternary_g_nd1"},
{"g2large", "ternary_g_nd2"},
{"g3large", "ternary_g_nd3"},
{"g1", "ternary_g_nd1"},
{"g2", "ternary_g_nd2"},
{"g3", "ternary_g_nd3"},
}};
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
for (auto& [name, func] : kernel_types) {
kernel_source +=
get_template_definition(name + "_" + lib_name, func, t_str, op);
std::string template_def;
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op);
kernel_source << template_def;
}
kernel_source += get_template_definition(
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
kernel_source += get_template_definition(
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "int");
kernel_source += get_template_definition(
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "int");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "ternary_g", t_str, op, 4);
return kernel_source;
kernel_source << get_template_definition(
"gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -173,47 +165,31 @@ MTL::ComputePipelineState* get_copy_kernel(
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
kernel_source += metal::copy();
std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source +=
get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
kernel_source +=
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
kernel_source += get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
kernel_source += get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int");
kernel_source += get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int");
kernel_source += get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type, "int");
kernel_source += get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int");
kernel_source += get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int");
kernel_source += get_template_definition(
"ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int");
kernel_source += get_template_definition(
"g1large_" + lib_name, "copy_g_nd1", in_type, out_type);
kernel_source += get_template_definition(
"g2large_" + lib_name, "copy_g_nd2", in_type, out_type);
kernel_source += get_template_definition(
"g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
kernel_source += get_template_definition(
"gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
kernel_source += get_template_definition(
"gg1large_" + lib_name, "copy_gg_nd1", in_type, out_type);
kernel_source += get_template_definition(
"gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
kernel_source += get_template_definition(
"gg3large_" + lib_name, "copy_gg_nd3", in_type, out_type);
kernel_source += get_template_definition(
"ggn4large_" + lib_name, "copy_gg", in_type, out_type, 4);
return kernel_source;
kernel_source << metal::utils() << metal::copy()
<< get_template_definition(
"s_" + lib_name, "copy_s", in_type, out_type)
<< get_template_definition(
"v_" + lib_name, "copy_v", in_type, out_type)
<< get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
<< get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
<< get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
<< get_template_definition(
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
<< get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
<< get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
<< get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
<< get_template_definition(
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -345,17 +321,17 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const Dtype& out_type) {
const array& out) {
auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source;
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
auto out_t = get_type_string(out_type);
std::string op = op_type + "<" + out_t + ">";
std::string kernel_source = metal::utils();
kernel_source += metal::reduce_utils();
kernel_source += metal::reduce();
kernel_source += get_template_definition(kernel_name, func_name, out_t, op);
return kernel_source;
auto out_type = get_type_string(out.dtype());
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
kernel_source << get_template_definition(
kernel_name, func_name, out_type, op);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -365,31 +341,30 @@ MTL::ComputePipelineState* get_reduce_kernel(
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const Dtype& in_type,
const Dtype& out_type,
const std::string& idx_t,
const array& in,
const array& out,
int ndim /* = -1 */,
int bm /* = -1 */,
int bn /* = -1 */) {
auto lib = d.get_library(kernel_name, [&]() {
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
std::string op = op_type + "<" + out_t + ">";
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::reduce_utils(), metal::reduce());
std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
if (bm >= 0) {
kernel_source += get_template_definition(
kernel_name, func_name, in_t, out_t, op, idx_t, ndim, bm, bn);
kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op, ndim, bm, bn);
} else if (ndim >= 0) {
kernel_source += get_template_definition(
kernel_name, func_name, in_t, out_t, op, idx_t, ndim);
kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op, ndim);
} else {
kernel_source += get_template_definition(
kernel_name, func_name, in_t, out_t, op, idx_t);
kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op);
}
return kernel_source;
return kernel_source.str();
});
auto st = d.get_kernel(kernel_name, lib);
return st;

View File

@@ -81,16 +81,15 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const Dtype& out_type);
const array& out);
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const Dtype& in_type,
const Dtype& out_type,
const std::string& idx_t,
const array& in,
const array& out,
int ndim = -1,
int bm = -1,
int bn = -1);

View File

@@ -1,27 +1,13 @@
set(BASE_HEADERS
metal_3_1/bf16.h
metal_3_0/bf16.h
bf16_math.h
complex.h
defines.h
expm1f.h
utils.h)
set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif()
if(MLX_METAL_VERSION GREATER_EQUAL 310)
set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
else()
set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_0)
endif()
add_custom_command(
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
-I${PROJECT_SOURCE_DIR} -I${VERSION_INCLUDES} -o ${TARGET}.air
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
OUTPUT ${TARGET}.air
COMMENT "Building ${TARGET}.air"
@@ -44,7 +30,9 @@ build_kernel(layer_norm)
build_kernel(random)
build_kernel(rms_norm)
build_kernel(rope)
build_kernel(scaled_dot_product_attention sdpa_vector.h)
build_kernel(
scaled_dot_product_attention scaled_dot_product_attention_params.h
sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)
set(STEEL_HEADERS
steel/defines.h
@@ -66,24 +54,6 @@ set(STEEL_HEADERS
steel/utils/type_traits.h
steel/utils/integral_constant.h)
set(STEEL_ATTN_HEADERS
steel/defines.h
steel/utils.h
steel/gemm/gemm.h
steel/gemm/mma.h
steel/gemm/loader.h
steel/gemm/transforms.h
steel/utils/type_traits.h
steel/utils/integral_constant.h
steel/attn/attn.h
steel/attn/loader.h
steel/attn/mma.h
steel/attn/params.h
steel/attn/transforms.h
steel/attn/kernels/steel_attention.h)
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
if(NOT MLX_METAL_JIT)
build_kernel(arange arange.h)
build_kernel(binary binary.h binary_ops.h)

View File

@@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/arange.h"
#define instantiate_arange(tname, type) \

View File

@@ -75,10 +75,10 @@ template <typename T, typename Op, int N_READS = 4>
const device T* in [[buffer(0)]],
device uint32_t* out [[buffer(1)]],
const constant int* shape [[buffer(2)]],
const constant int64_t* in_strides [[buffer(3)]],
const constant int64_t* out_strides [[buffer(4)]],
const constant size_t* in_strides [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& ndim [[buffer(5)]],
const constant int64_t& axis_stride [[buffer(6)]],
const constant size_t& axis_stride [[buffer(6)]],
const constant size_t& axis_size [[buffer(7)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],

View File

@@ -6,6 +6,12 @@
using namespace metal;
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
typedef bfloat bfloat16_t;
#else
/////////////////////////////////////////////////////////////////////////////
// Helpers
/////////////////////////////////////////////////////////////////////////////
@@ -305,10 +311,7 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {
} // namespace metal
#pragma METAL internals : disable
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
return x.bits_;
}
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat());
}
#endif
#include "mlx/backend/metal/kernels/bf16_math.h"

View File

@@ -2,6 +2,8 @@
#pragma once
#include "mlx/backend/metal/kernels/bf16.h"
///////////////////////////////////////////////////////////////////////////////
// Metal math for bfloat16
///////////////////////////////////////////////////////////////////////////////
@@ -367,6 +369,18 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
#else
#define bfloat16_to_uint16(x) x.bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
#endif
namespace metal {
instantiate_metal_simd_comm_funcs(

View File

@@ -43,7 +43,7 @@ template <typename T, typename U, typename Op>
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
size_t offset = index.x + grid_dim.x * size_t(index.y);
c[offset] = Op()(a[0], b[offset]);
}
@@ -54,7 +54,7 @@ template <typename T, typename U, typename Op>
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
size_t offset = index.x + grid_dim.x * size_t(index.y);
c[offset] = Op()(a[offset], b[0]);
}
@@ -65,75 +65,72 @@ template <typename T, typename U, typename Op>
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
size_t offset = index.x + grid_dim.x * size_t(index.y);
c[offset] = Op()(a[offset], b[offset]);
}
template <typename T, typename U, typename Op, typename IdxT = int64_t>
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
device U* c,
constant const int64_t& a_stride,
constant const int64_t& b_stride,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<IdxT>(index, a_stride);
auto b_idx = elem_to_loc_1<IdxT>(index, b_stride);
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, typename IdxT = int64_t>
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
device U* c,
constant const int64_t a_strides[2],
constant const int64_t b_strides[2],
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, typename IdxT = int64_t>
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
device U* c,
constant const int64_t a_strides[3],
constant const int64_t b_strides[3],
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = int64_t>
template <typename T, typename U, typename Op, int N = 1>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
device U* c,
constant const int* shape,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<IdxT>(
auto idx = elem_to_loc_2_nd(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
IdxT a_xstride = a_strides[ndim - 1];
IdxT b_xstride = b_strides[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
c[out_idx++] = Op()(a[idx.x], b[idx.y]);
idx.x += a_xstride;

View File

@@ -9,22 +9,18 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
#define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \

View File

@@ -56,7 +56,7 @@ template <typename T, typename U, typename Op>
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto out = Op()(a[0], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
@@ -70,7 +70,7 @@ template <typename T, typename U, typename Op>
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto out = Op()(a[offset], b[0]);
c[offset] = out[0];
d[offset] = out[1];
@@ -84,87 +84,84 @@ template <typename T, typename U, typename Op>
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto out = Op()(a[offset], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op, typename IdxT = int64_t>
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int64_t& a_stride,
constant const int64_t& b_stride,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<IdxT>(index, a_stride);
auto b_idx = elem_to_loc_1<IdxT>(index, b_stride);
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
auto out = Op()(a[a_idx], b[b_idx]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op, typename IdxT = int64_t>
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int64_t a_strides[2],
constant const int64_t b_strides[2],
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op, typename IdxT = int64_t>
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int64_t a_strides[3],
constant const int64_t b_strides[3],
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = int64_t>
template <typename T, typename U, typename Op, int N = 1>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int* shape,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<IdxT>(
auto idx = elem_to_loc_2_nd(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
IdxT a_xstride = a_strides[ndim - 1];
IdxT b_xstride = b_strides[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];

View File

@@ -7,22 +7,18 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \

View File

@@ -4,8 +4,8 @@
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/utils.h"
#define MLX_MTL_CONST static constant constexpr const

View File

@@ -22,7 +22,7 @@ template <typename T, typename U>
device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
size_t offset = index.x + grid_dim.x * size_t(index.y);
dst[offset] = static_cast<U>(src[0]);
}
@@ -32,46 +32,46 @@ template <typename T, typename U>
device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
size_t offset = index.x + grid_dim.x * size_t(index.y);
dst[offset] = static_cast<U>(src[offset]);
}
template <typename T, typename U, typename IdxT = int64_t>
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
auto src_idx = elem_to_loc_1(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, typename IdxT = int64_t>
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
auto src_idx = elem_to_loc_2(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, typename IdxT = int64_t>
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
IdxT dst_idx =
index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
auto src_idx = elem_to_loc_3(index, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
template <typename T, typename U, int N = 1>
[[kernel]] void copy_g(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
@@ -80,16 +80,17 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc<IdxT>(
auto src_idx = elem_to_loc(
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
if (N == 1) {
IdxT dst_idx =
index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
int64_t dst_idx =
index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
return;
}
auto xshape = src_shape[ndim - 1];
IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
int64_t dst_idx =
N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
auto src_xstride = src_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[dst_idx + i] = static_cast<U>(src[src_idx]);
@@ -97,43 +98,43 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
}
}
template <typename T, typename U, typename IdxT = int64_t>
template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);
auto src_idx = elem_to_loc_1(index, src_stride);
auto dst_idx = elem_to_loc_1(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, typename IdxT = int64_t>
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);
auto src_idx = elem_to_loc_2(index, src_strides);
auto dst_idx = elem_to_loc_2(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, typename IdxT = int64_t>
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);
auto src_idx = elem_to_loc_3(index, src_strides);
auto dst_idx = elem_to_loc_3(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
template <typename T, typename U, int N = 1>
[[kernel]] void copy_gg(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
@@ -142,7 +143,7 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto idx = elem_to_loc_2_nd<IdxT>(
auto idx = elem_to_loc_2_nd(
{N * index.x, index.y, index.z},
src_shape,
src_strides,
@@ -152,8 +153,8 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
dst[idx.y] = static_cast<U>(src[idx.x]);
return;
}
IdxT src_xstride = src_strides[ndim - 1];
IdxT dst_xstride = dst_strides[ndim - 1];
auto src_xstride = src_strides[ndim - 1];
auto dst_xstride = dst_strides[ndim - 1];
auto xshape = src_shape[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[idx.y] = static_cast<U>(src[idx.x]);

View File

@@ -2,29 +2,22 @@
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/copy.h"
#define instantiate_copy_all(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype, int) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
#define instantiate_copy_all(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \

View File

@@ -4,30 +4,30 @@
#include "mlx/backend/metal/kernels/indexing.h"
template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
const device T* src [[buffer(0)]],
device T* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant int64_t* src_strides [[buffer(3)]],
const constant size_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const thread Indices<IdxT, NIDX>& indices,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
LocT src_idx = 0;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
LocT idx_loc;
size_t idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
idx_loc = index.x * indices.strides[indices.ndim * i];
} else {
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
idx_loc = index.x * indices.strides[indices.ndim * i];
idx_loc += indices.row_contiguous[i]
? index.y
: elem_to_loc<LocT>(
: elem_to_loc(
index.y,
&indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1],
@@ -35,17 +35,17 @@ METAL_FUNC void gather_impl(
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(src_strides[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset =
elem_to_loc<LocT>(index.z, slice_sizes, src_strides, src_ndim);
auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim);
LocT out_idx = index.z;
size_t out_idx = index.z;
if (IDX_NDIM == 1) {
out_idx += static_cast<LocT>(grid_dim.z) * index.x;
out_idx += static_cast<size_t>(grid_dim.z) * index.x;
} else if (IDX_NDIM >= 2) {
out_idx += grid_dim.z * (index.x * static_cast<LocT>(grid_dim.y) + index.y);
out_idx +=
grid_dim.z * (index.x * static_cast<size_t>(grid_dim.y) + index.y);
}
out[out_idx] = src[src_offset + src_idx];
}

View File

@@ -3,6 +3,8 @@
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
@@ -436,9 +438,9 @@ template <
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const constant int64_t* bias_batch_stride [[buffer(13)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const constant size_t* bias_batch_stride [[buffer(13)]],
const constant int& bias_stride [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
@@ -486,21 +488,31 @@ template <
simd_lid);
}
#define instantiate_gemv_helper( \
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
instantiate_kernel( \
"gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc "_axpby" #axpby, \
gemv, \
itype, \
bm, \
bn, \
sm, \
sn, \
tm, \
tn, \
nc, \
axpby)
#define instantiate_gemv_helper( \
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
"_tm" #tm "_tn" #tn "_nc" #nc \
"_axpby" #axpby)]] [[kernel]] void \
gemv<itype, bm, bn, sm, sn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
@@ -539,13 +551,13 @@ template <
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant int64_t* index_batch_strides [[buffer(11)]],
const constant size_t* index_batch_strides [[buffer(11)]],
const constant int& vector_batch_ndim [[buffer(12)]],
const constant int* vector_batch_shape [[buffer(13)]],
const constant int64_t* vector_batch_stride [[buffer(14)]],
const constant size_t* vector_batch_stride [[buffer(14)]],
const constant int& matrix_batch_ndim [[buffer(15)]],
const constant int* matrix_batch_shape [[buffer(16)]],
const constant int64_t* matrix_batch_stride [[buffer(17)]],
const constant size_t* matrix_batch_stride [[buffer(17)]],
const constant uint32_t* vec_indices [[buffer(18)]],
const constant uint32_t* mat_indices [[buffer(19)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -561,8 +573,8 @@ template <
// Update batch offsets
if (batch_ndim > 1) {
const constant auto* veci_bstrides = index_batch_strides;
const constant auto* mati_bstrides = index_batch_strides + batch_ndim;
const constant size_t* veci_bstrides = index_batch_strides;
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
@@ -609,14 +621,37 @@ template <
simd_lid);
}
// clang-format off
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
instantiate_kernel( \
"gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
"_sn" #sn "_tm" #tm "_tn" #tn, \
gemv_gather, itype, bm, bn, sm, sn, tm, tn)
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
template [[host_name("gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
"_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \
gemv_gather<itype, bm, bn, sm, sn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* index_batch_strides [[buffer(11)]], \
const constant int& vector_batch_ndim [[buffer(12)]], \
const constant int* vector_batch_shape [[buffer(13)]], \
const constant size_t* vector_batch_stride [[buffer(14)]], \
const constant int& matrix_batch_ndim [[buffer(15)]], \
const constant int* matrix_batch_shape [[buffer(16)]], \
const constant size_t* matrix_batch_stride [[buffer(17)]], \
const constant uint32_t* vec_indices [[buffer(18)]], \
const constant uint32_t* mat_indices [[buffer(19)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_bs_blocks(name, itype) \
// clang-format off
#define instantiate_gemv_bs_blocks(name, itype) \
instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \
instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \
instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on
@@ -651,9 +686,9 @@ template <
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const constant int64_t* bias_batch_stride [[buffer(13)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const constant size_t* bias_batch_stride [[buffer(13)]],
const constant int& bias_stride [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
@@ -701,14 +736,33 @@ template <
simd_lid);
}
// clang-format off
#define instantiate_gemv_t_helper( \
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
instantiate_kernel( \
"gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
"_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby, \
gemv_t, itype, bm, bn, sm, sn, tm, tn, nc, axpby)
#define instantiate_gemv_t_helper( \
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
"_tm" #tm "_tn" #tn "_nc" #nc \
"_axpby" #axpby)]] [[kernel]] void \
gemv_t<itype, bm, bn, sm, sn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
@@ -748,13 +802,13 @@ template <
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant int64_t* index_batch_strides [[buffer(11)]],
const constant size_t* index_batch_strides [[buffer(11)]],
const constant int& vector_batch_ndim [[buffer(12)]],
const constant int* vector_batch_shape [[buffer(13)]],
const constant int64_t* vector_batch_stride [[buffer(14)]],
const constant size_t* vector_batch_stride [[buffer(14)]],
const constant int& matrix_batch_ndim [[buffer(15)]],
const constant int* matrix_batch_shape [[buffer(16)]],
const constant int64_t* matrix_batch_stride [[buffer(17)]],
const constant size_t* matrix_batch_stride [[buffer(17)]],
const constant uint32_t* vec_indices [[buffer(18)]],
const constant uint32_t* mat_indices [[buffer(19)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -770,8 +824,8 @@ template <
// Update batch offsets
if (batch_ndim > 1) {
const constant auto* veci_bstrides = index_batch_strides;
const constant auto* mati_bstrides = index_batch_strides + batch_ndim;
const constant size_t* veci_bstrides = index_batch_strides;
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
@@ -818,14 +872,36 @@ template <
simd_lid);
}
// clang-format off
#define instantiate_gemv_t_bs_helper( \
nm, itype, bm, bn, sm, sn, tm, tn) \
instantiate_kernel( \
"gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
"_sn" #sn "_tm" #tm "_tn" #tn, \
gemv_t_gather, itype, bm, bn, sm, sn, tm, tn)
#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
template [[host_name("gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
"_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \
gemv_t_gather<itype, bm, bn, sm, sn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* index_batch_strides [[buffer(11)]], \
const constant int& vector_batch_ndim [[buffer(12)]], \
const constant int* vector_batch_shape [[buffer(13)]], \
const constant size_t* vector_batch_stride [[buffer(14)]], \
const constant int& matrix_batch_ndim [[buffer(15)]], \
const constant int* matrix_batch_shape [[buffer(16)]], \
const constant size_t* matrix_batch_stride [[buffer(17)]], \
const constant uint32_t* vec_indices [[buffer(18)]], \
const constant uint32_t* mat_indices [[buffer(19)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_gemv_t_bs_blocks(name, itype) \
instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 1) \
instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 4) \
@@ -836,4 +912,4 @@ template <
// clang-format off
instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half);
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on

View File

@@ -642,13 +642,13 @@ template <
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant int64_t* mask_batch_strides [[buffer(24)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -673,8 +673,8 @@ template <
}
if (has_operand_mask) {
const constant auto* mask_strides_mat = mask_batch_strides;
const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
const constant size_t* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
@@ -742,13 +742,13 @@ template <
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant int64_t* mask_batch_strides [[buffer(24)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -773,8 +773,8 @@ template <
}
if (has_operand_mask) {
const constant auto* mask_strides_mat = mask_batch_strides;
const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
const constant size_t* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

View File

@@ -4,17 +4,37 @@
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/gemv_masked.h"
#define instantiate_gemv_helper( \
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_kernel( \
"gemv_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc, \
gemv_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc)
template [[host_name("gemv_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc)]] [[kernel]] void \
gemv_masked<itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const device outm_t* out_mask [[buffer(20)]], \
const device opm_t* mat_mask [[buffer(21)]], \
const device opm_t* vec_mask [[buffer(22)]], \
const constant int* mask_strides [[buffer(23)]], \
const constant size_t* mask_batch_strides [[buffer(24)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
@@ -43,11 +63,29 @@ instantiate_gemv_blocks(bfloat16, bfloat16_t);
#define instantiate_gemv_t_helper( \
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_kernel( \
"gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc, \
gemv_t_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc)
template [[host_name("gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc)]] [[kernel]] void \
gemv_t_masked<itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const device outm_t* out_mask [[buffer(20)]], \
const device opm_t* mat_mask [[buffer(21)]], \
const device opm_t* vec_mask [[buffer(22)]], \
const constant int* mask_strides [[buffer(23)]], \
const constant size_t* mask_batch_strides [[buffer(24)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \

View File

@@ -8,13 +8,13 @@ template <typename IdxT, int NIDX>
struct Indices {
const array<const device IdxT*, NIDX> buffers;
const constant int* shapes;
const constant int64_t* strides;
const constant size_t* strides;
const constant bool* row_contiguous;
const int ndim;
};
template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
if (is_unsigned_v<IdxT>) {
return idx;
} else {

View File

@@ -1,16 +0,0 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#define jit_if #if
#define jit_else #else
#define jit_endif #endif
jit_if (__METAL_VERSION__ >= 310)
#include "mlx/backend/metal/kernels/metal_3_1/bf16.h"
jit_else
#include "mlx/backend/metal/kernels/metal_3_0/bf16.h"
jit_endif // clang-format on

View File

@@ -3,6 +3,8 @@
#include <metal_common>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;

View File

@@ -1,16 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_stdlib>
using namespace metal;
typedef bfloat bfloat16_t;
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
return as_type<uint16_t>(x);
}
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
return as_type<bfloat16_t>(x);
}

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More