Compare commits


56 Commits

Author SHA1 Message Date
Angelos Katharopoulos
c02e14c264 Add the 3bit packed qmm_t 2024-12-17 22:16:30 -08:00
Angelos Katharopoulos
d75a509234 Add 3bit packed quants 2024-12-17 10:49:13 -08:00
Angelos Katharopoulos
14420949d2 Fix the optional in gather_qmm python binding 2024-12-16 22:14:19 -08:00
Angelos Katharopoulos
4847199ec6 Add the quantization type option to quantizable layers 2024-12-16 22:11:23 -08:00
Angelos Katharopoulos
fb7be036af Add packed_affine_qmm_t 2024-12-16 21:49:14 -08:00
Angelos Katharopoulos
410ccdbed5 Change the argument name to quantization_type 2024-12-16 13:32:20 -08:00
Angelos Katharopoulos
f5da489a3c Add some error reporting 2024-12-16 13:22:05 -08:00
Angelos Katharopoulos
c2e6d58441 Revert the change in packing order 2024-12-16 13:20:01 -08:00
Angelos Katharopoulos
17a1fa2f0b Improve the benchmark 2024-12-14 23:04:29 -08:00
Angelos Katharopoulos
fd161aa31f Change order in weight packing 2024-12-14 22:51:41 -08:00
Angelos Katharopoulos
bf6dc54110 Add the 2 bit vectorized reads 2024-12-14 21:19:02 -08:00
Angelos Katharopoulos
d7ed624502 Vectorized reads 2024-12-14 15:36:34 -08:00
Angelos Katharopoulos
05cb54ae3f Another packing 2024-12-13 23:48:25 -08:00
Angelos Katharopoulos
cb358dbdda Revert "Attempt different packing"
This reverts commit e4b587819c.
2024-12-13 23:23:41 -08:00
Angelos Katharopoulos
e4b587819c Attempt different packing 2024-12-13 18:36:36 -08:00
Angelos Katharopoulos
a06c968f4d Add a small benchmark 2024-12-13 16:29:11 -08:00
Angelos Katharopoulos
651c510940 Working packed qmv 2024-12-13 16:29:11 -08:00
Angelos Katharopoulos
11ec07ff9d Initial python binding 2024-12-13 16:29:11 -08:00
Angelos Katharopoulos
bdd68bd893 Add a quantization type in the ops 2024-12-13 16:29:11 -08:00
Awni Hannun
50f3535693 Use expand_dims / unflatten / etc in more places (#1696)
* use expand_dims / unflatten in a couple more places

* few more

* few more

* fix
2024-12-12 17:00:44 -08:00
Awni Hannun
9111999af3 Fix small sort with metal validation (#1695) 2024-12-12 09:21:45 -08:00
Awni Hannun
6bd28d246e Allow no copy negative strides in as_strided and slice (#1688)
* allow no copy negative strides in as_strided and slice

* fix jit

* fix jit
2024-12-12 08:59:45 -08:00
Cheng
4d595a2a39 Make compiled preamble work in MSVC (#1675)
* Make compiled preamble work in MSVC

* Remove logging

* Only use powershell for MSVC
2024-12-12 08:55:49 -08:00
Awni Hannun
3a21f61772 Fix build (#1693) 2024-12-11 23:56:25 -08:00
Awni Hannun
4e1e9520e1 Flatten and unflatten (#1692)
* flatten and unflatten

* fix grad

* fix shape infer

* use squeeze + unsqueeze in get_item
2024-12-11 21:51:37 -08:00
Cheng
0bf19037ca Remove "using namespace mlx::core" in python/src (#1689) 2024-12-11 15:45:39 -08:00
Awni Hannun
f3dfa36a3a Fix x86 tests (#1691)
* fix x86 tests

* comment
2024-12-11 07:47:18 -08:00
Cheng
4f9b60dd53 Remove "using namespace mlx::core" in benchmarks/examples (#1685)
* Remove "using namespace mlx::core" in benchmarks/examples

* Fix building example extension

* A missing one in comment

* Fix building on M chips
2024-12-11 07:08:29 -08:00
Awni Hannun
f76a49e555 ExpandDims primitive (#1687)
* add squeeze primitive

* simplify squeeze, use in gather

* fix

* fix

* fix

* fix

* fix no cpu

* use squeeze in matmul and friends

* expand dims primitive

* comment
2024-12-10 16:39:07 -08:00
Cheng
310ad8d9db Build OpenBLAS from source code for MSVC (#1674)
* Download OpenBLAS binaries when building with MSVC

* Download dlfcn-win32

* Link with dlfcn-win32 correctly

* Build OpenBLAS from source code

* Link with openblas statically

* Link with BLAS privately
2024-12-10 16:14:44 -08:00
Cheng
56db268f47 Provide a pread implementation for MSVC (#1666) 2024-12-10 15:55:53 -08:00
Cheng
92ab6bdeb8 Fix shared library not exporting symbols on Windows (#1684)
* Fix shared library not exporting symbols on Windows

* Function name style
2024-12-10 13:59:14 -08:00
Cheng
0070e360a1 Disable MSVC warnings (#1680) 2024-12-09 19:41:14 -08:00
Amethyst Shen
9df8fed046 Metal-cpp version bump (#1668)
* Metal-cpp version bump

Apple has released the stable version of Metal-cpp for macOS 15 and iOS 18. CMakeLists.txt is updated to build with it instead of the beta one.

* Fix style with cmake-format
2024-12-09 19:40:35 -08:00
Cheng
a59fae040f Fix library output directory for MSVC (#1681) 2024-12-09 19:07:50 -08:00
Awni Hannun
29a620cab2 No reshapes in quantized embedding (#1682)
* no reshapes in quantized embedding

* fix inadvertant cast

* add tol
2024-12-09 18:57:38 -08:00
Cheng
87d7a2520e Use Py_ssize_t in python bindings (#1678)
* Use Py_ssize_t in python bindings

* Args passed to std::max must be same type
2024-12-09 12:59:19 -08:00
Awni Hannun
40c62c1321 Use int64 stride everywhere (#1671)
* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
2024-12-09 11:09:02 -08:00
Awni Hannun
35b412c099 Fix compile hasher for string constants. (#1677)
* fix hash

* add test

* nit
2024-12-09 09:26:18 -08:00
Cheng
d0f471cff7 Using math defines requires switch in MSVC (#1665)
* Using math defines requires switch in MSVC

* Fix more math macros

* Fix type

* Remove _MSC_VER guard for math defines
2024-12-08 08:16:28 -08:00
Cheng
6f316b8bf5 Use int64_t instead of ssize_t (#1673) 2024-12-07 20:10:44 -08:00
Cheng
7c10c93a1f Convert filesystem path to std::string explicitly (#1672) 2024-12-07 20:10:06 -08:00
Cheng
d92ea094f1 Use && instead of and (#1663)
* Use && instead of and

* Remove "and" in ops.cpp
2024-12-07 18:26:39 -08:00
Cheng
6ae5423b4a Do not pass integers to isnan (#1664) 2024-12-07 18:26:23 -08:00
Cheng
9635cffdc8 Include io.h in MSVC for IO functions (#1661) 2024-12-07 18:26:06 -08:00
Cheng
96986fb362 Use auto* for pointers (#1662) 2024-12-07 18:25:40 -08:00
Cheng
3ceb341a75 Use correct complex type for MSVC (#1660) 2024-12-07 18:25:22 -08:00
Awni Hannun
50fa705125 patch bump (#1656) 2024-12-06 13:16:19 -08:00
Awni Hannun
69a2991614 allow compiling lambdas in C++ (#1650)
* allow compiling lambdas in C++

* fix test

* more tests

* auto detect capture-less lambda
2024-12-06 13:13:21 -08:00
mt_caret
fd3377dd1f Support bias correction in Adam and AdamW optimizers (#1640) 2024-12-06 12:13:34 -08:00
Awni Hannun
d0b6cb0425 More primitives for compiling with shapeless (#1653)
* more shapeless and more Shape

* more shape

* fix

* fix
2024-12-06 11:29:18 -08:00
Alex Barron
95c4a2e3af add back conditionaltype (#1655) 2024-12-06 11:12:01 -08:00
Awni Hannun
bc2a29f033 fix (#1654) 2024-12-06 10:48:58 -08:00
Nripesh Niketan
3bb5b4a302 Chore: Add default language in pre-commit and bump hooks (#1652) 2024-12-06 07:54:29 -08:00
Awni Hannun
fc88fd9097 Shape and Strides 1 / N (#1645)
* shape and stride type def

* more shape
2024-12-05 12:53:43 -08:00
Awni Hannun
c5b0928c1f fix fallback (#1646) 2024-12-05 11:59:53 -08:00
181 changed files with 6180 additions and 5039 deletions
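The headline change in this range is a quantization_type argument threaded through the quantization ops and kernels (commits d75a509234, 410ccdbed5, and 4847199ec6 above). A minimal sketch of the resulting Python API, based on the benchmark added later in this compare; treat the "affine" / "affine-packed" names and defaults as read off that benchmark rather than as documented guarantees:

import mlx.core as mx

group_size, bits = 64, 4
x = mx.random.normal(shape=(8, 512)).astype(mx.float16)
w = mx.random.normal(shape=(1024, 512)).astype(mx.float16)

# mx.quantize returns the packed weight plus its scales/biases as a tuple.
wq = mx.quantize(w, group_size=group_size, bits=bits, quantization_type="affine")
wq_packed = mx.quantize(
    w, group_size=group_size, bits=bits, quantization_type="affine-packed"
)

# The same quantization_type is passed to the matmul so it can select the
# matching (packed or unpacked) kernel.
y = mx.quantized_matmul(
    x, *wq, group_size=group_size, bits=bits, quantization_type="affine"
)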

.gitignore

@@ -76,6 +76,9 @@ build/
*.out
*.app
# Debug symbols
*.pdb
# VSCode
.vscode/
.DS_Store

.pre-commit-config.yaml

@@ -1,13 +1,14 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.8
rev: v19.1.4
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.8.0
rev: 24.10.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:

CMakeLists.txt

@@ -20,11 +20,12 @@ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.21.0)
set(MLX_VERSION 0.21.1)
endif()
# --------------------- Processor tests -------------------------
@@ -93,8 +94,7 @@ elseif(MLX_BUILD_METAL)
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
set(METAL_CPP_URL
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
)
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
@@ -113,16 +113,52 @@ elseif(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
set(MLX_BUILD_GGUF OFF)
# There is no prebuilt OpenBLAS distribution for MSVC.
set(MLX_BUILD_BLAS_FROM_SOURCE ON)
endif()
# Windows implementation of dlfcn.h APIs.
FetchContent_Declare(
dlfcn-win32
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
GIT_TAG v1.4.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(dlfcn-win32)
target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
target_link_libraries(mlx PRIVATE dl)
endif()
if(MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate)
if(ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
endif()
if(MLX_BUILD_ACCELERATE)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
elseif(MLX_BUILD_BLAS_FROM_SOURCE)
# Download and build OpenBLAS from source code.
FetchContent_Declare(
openblas
GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
GIT_TAG v0.3.28
EXCLUDE_FROM_ALL)
set(BUILD_STATIC_LIBS ON) # link statically
set(NOFORTRAN ON) # msvc has no fortran compiler
FetchContent_MakeAvailable(openblas)
target_link_libraries(mlx PRIVATE openblas)
target_include_directories(
mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
"${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
else()
if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
@@ -140,7 +176,7 @@ if(MLX_BUILD_CPU)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally include an old
# version of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
@@ -153,14 +189,7 @@ if(MLX_BUILD_CPU)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
if(WIN32)
find_package(dlfcn-win32 REQUIRED)
message(STATUS "dlfcn-win32 lib " ${dlfcn-win32_LIBRARIES})
message(STATUS "dlfcn-win32 include " ${dlfcn-win32_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${dlfcn-win32_LIBRARIES})
endif()
target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
endif()
else()
set(MLX_BUILD_ACCELERATE OFF)


@@ -5,35 +5,35 @@
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
namespace mx = mlx::core;
void time_value_and_grad() {
auto x = ones({200, 1000});
eval(x);
auto fn = [](array x) {
auto x = mx::ones({200, 1000});
mx::eval(x);
auto fn = [](mx::array x) {
for (int i = 0; i < 20; ++i) {
x = log(exp(x));
x = mx::log(mx::exp(x));
}
return sum(x);
return mx::sum(x);
};
auto grad_fn = grad(fn);
auto grad_fn = mx::grad(fn);
auto independent_value_and_grad = [&]() {
auto value = fn(x);
auto dfdx = grad_fn(x);
return std::vector<array>{value, dfdx};
return std::vector<mx::array>{value, dfdx};
};
TIME(independent_value_and_grad);
auto value_and_grad_fn = value_and_grad(fn);
auto value_and_grad_fn = mx::value_and_grad(fn);
auto combined_value_and_grad = [&]() {
auto [value, dfdx] = value_and_grad_fn(x);
return std::vector<array>{value, dfdx};
return std::vector<mx::array>{value, dfdx};
};
TIME(combined_value_and_grad);
}
int main() {
std::cout << "Benchmarks for " << default_device() << std::endl;
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
time_value_and_grad();
}
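Aside from the mx:: namespace migration, this benchmark is a good illustration of why value_and_grad exists: it returns the value and the gradient from one traversal instead of running the forward pass twice. The Python equivalent, as a sketch using the public API:

import mlx.core as mx

def fn(x):
    for _ in range(20):
        x = mx.log(mx.exp(x))
    return mx.sum(x)

x = mx.ones((200, 1000))

# Separate: two traversals of the graph.
value, dfdx = fn(x), mx.grad(fn)(x)

# Fused: one traversal returns both.
value, dfdx = mx.value_and_grad(fn)(x)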


@@ -4,21 +4,21 @@
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
namespace mx = mlx::core;
void time_add_op() {
std::vector<int> sizes(1, 1);
for (int i = 0; i < 9; ++i) {
sizes.push_back(10 * sizes.back());
}
set_default_device(Device::cpu);
set_default_device(mx::Device::cpu);
for (auto size : sizes) {
auto a = random::uniform({size});
auto b = random::uniform({size});
eval(a, b);
auto a = mx::random::uniform({size});
auto b = mx::random::uniform({size});
mx::eval(a, b);
std::cout << "Size " << size << std::endl;
TIMEM("cpu", add, a, b, Device::cpu);
TIMEM("gpu", add, a, b, Device::gpu);
TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
}
}


@@ -6,105 +6,105 @@
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
namespace mx = mlx::core;
void time_irregular_binary_ops_1D() {
auto device = default_device();
auto device = mx::default_device();
int size = 1000000;
int step = 2;
auto a = random::uniform({size});
auto b = random::uniform({size});
eval(a, b);
auto a = mx::random::uniform({size});
auto b = mx::random::uniform({size});
mx::eval(a, b);
a = slice(a, {0}, {size}, {step});
b = slice(b, {0}, {size}, {step});
TIMEM("1D strided", add, a, b, device);
TIMEM("1D strided", mx::add, a, b, device);
}
void time_irregular_binary_ops_2D() {
auto device = default_device();
auto device = mx::default_device();
int size = 2048;
auto a = random::uniform({size, size});
auto b = random::uniform({size, size});
eval(a, b);
TIMEM("2D regular", add, a, b, device);
auto a = mx::random::uniform({size, size});
auto b = mx::random::uniform({size, size});
mx::eval(a, b);
TIMEM("2D regular", mx::add, a, b, device);
b = transpose(b);
eval(b);
TIMEM("2D transpose", add, a, b, device);
b = mx::transpose(b);
mx::eval(b);
TIMEM("2D mx::transpose", mx::add, a, b, device);
b = random::uniform({size});
eval(b);
TIMEM("2D broadcast dim 0", add, a, b, device);
b = mx::random::uniform({size});
mx::eval(b);
TIMEM("2D broadcast dim 0", mx::add, a, b, device);
b = reshape(b, {size, 1});
eval(b);
TIMEM("2D broadcast dim 1", add, a, b, device);
b = mx::reshape(b, {size, 1});
mx::eval(b);
TIMEM("2D broadcast dim 1", mx::add, a, b, device);
}
void time_irregular_binary_ops_3D() {
auto device = default_device();
auto device = mx::default_device();
int d0 = 32;
int d1 = 512;
int d2 = 512;
auto a = random::uniform({d0, d1, d2});
auto b = random::uniform({d0, d1, d2});
TIMEM("3D regular", add, a, b, device);
auto a = mx::random::uniform({d0, d1, d2});
auto b = mx::random::uniform({d0, d1, d2});
TIMEM("3D regular", mx::add, a, b, device);
b = transpose(b, {0, 2, 1});
TIMEM("3D transpose", add, a, b, device);
b = mx::transpose(b, {0, 2, 1});
TIMEM("3D mx::transpose", mx::add, a, b, device);
b = random::uniform({d1, d2});
TIMEM("3D broadcast dim 0", add, a, b, device);
b = mx::random::uniform({d1, d2});
TIMEM("3D broadcast dim 0", mx::add, a, b, device);
b = random::uniform({d0, 1, d2});
TIMEM("3D broadcast dim 1", add, a, b, device);
b = mx::random::uniform({d0, 1, d2});
TIMEM("3D broadcast dim 1", mx::add, a, b, device);
b = random::uniform({d0, d1, 1});
TIMEM("3D broadcast dim 2", add, a, b, device);
b = mx::random::uniform({d0, d1, 1});
TIMEM("3D broadcast dim 2", mx::add, a, b, device);
b = random::uniform({d2});
TIMEM("3D broadcast dims 0, 1", add, a, b, device);
b = mx::random::uniform({d2});
TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
b = random::uniform({d1, 1});
TIMEM("3D broadcast dims 0, 2", add, a, b, device);
b = mx::random::uniform({d1, 1});
TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
b = random::uniform({d0, 1, 1});
TIMEM("3D broadcast dims 1, 2", add, a, b, device);
b = mx::random::uniform({d0, 1, 1});
TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
}
void time_irregular_binary_ops_4D() {
auto device = default_device();
auto device = mx::default_device();
std::vector<int> shape = {8, 8, 512, 512};
auto a = random::uniform(shape);
auto b = random::uniform(shape);
auto a = mx::random::uniform(shape);
auto b = mx::random::uniform(shape);
TIMEM("4D regular", add, a, b, device);
TIMEM("4D regular", mx::add, a, b, device);
b = transpose(b, {0, 1, 3, 2});
TIMEM("4D transpose", add, a, b, device);
b = mx::transpose(b, {0, 1, 3, 2});
TIMEM("4D mx::transpose", mx::add, a, b, device);
std::string om = "4D broadcast dims ";
for (int i = 0; i < shape.size(); ++i) {
shape[i] = 1;
b = random::uniform(shape);
b = mx::random::uniform(shape);
std::ostringstream msg;
msg << om << i;
TIMEM(msg.str(), add, a, b, device);
TIMEM(msg.str(), mx::add, a, b, device);
for (int j = i + 1; j < shape.size(); ++j) {
shape[j] = 1;
std::ostringstream msg;
msg << om << i << ", " << j;
b = random::uniform(shape);
TIMEM(msg.str(), add, a, b, device);
b = mx::random::uniform(shape);
TIMEM(msg.str(), mx::add, a, b, device);
shape[j] = a.shape(j);
for (int k = j + 1; k < shape.size(); ++k) {
shape[k] = 1;
std::ostringstream msg;
msg << om << i << ", " << j << ", " << k;
b = random::uniform(shape);
TIMEM(msg.str(), add, a, b, device);
b = mx::random::uniform(shape);
TIMEM(msg.str(), mx::add, a, b, device);
shape[k] = a.shape(k);
}
}
@@ -113,83 +113,83 @@ void time_irregular_binary_ops_4D() {
}
void time_irregular_reshape() {
auto device = default_device();
auto device = mx::default_device();
std::vector<int> shape;
auto reshape_fn = [&shape, device](const array& a) {
return reshape(a, shape, device);
auto reshape_fn = [&shape, device](const mx::array& a) {
return mx::reshape(a, shape, device);
};
int size = 64;
int d = 2 * size;
auto a = random::uniform({d, d, d});
auto a = mx::random::uniform({d, d, d});
shape = {8 * size, size, size};
TIMEM("3D contiguous", reshape_fn, a);
a = transpose(a);
a = mx::transpose(a);
shape = {8 * size, size, size};
TIMEM("3D transpose", reshape_fn, a);
TIMEM("3D mx::transpose", reshape_fn, a);
a = transpose(a, {1, 2, 0});
a = mx::transpose(a, {1, 2, 0});
shape = {8 * size, size, size};
TIMEM("3D transpose dims 1 2", reshape_fn, a);
TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
a = broadcast_to(random::uniform({d, d}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
TIMEM("3D broadcast dim 0", reshape_fn, a);
a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
TIMEM("3D broadcast dim 1", reshape_fn, a);
a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
TIMEM("3D broadcast dim 2", reshape_fn, a);
a = broadcast_to(random::uniform({d}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
a = broadcast_to(random::uniform({d, 1}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
}
void time_irregular_astype_1D() {
auto device = default_device();
auto device = mx::default_device();
int size = 1000000;
int step = 2;
auto a = random::uniform({size});
auto a = mx::random::uniform({size});
a = slice(a, {0}, {size}, {step});
TIMEM("1D strided", astype, a, int32, device);
TIMEM("1D strided", mx::astype, a, mx::int32, device);
}
void time_irregular_astype_2D() {
auto device = default_device();
auto device = mx::default_device();
int size = 2048;
std::vector<int> shape = {size, size};
auto a = random::uniform(shape);
TIMEM("2D regular", astype, a, int32, device);
auto a = mx::random::uniform(shape);
TIMEM("2D regular", mx::astype, a, mx::int32, device);
a = transpose(a);
TIMEM("2D transpose", astype, a, int32, device);
a = mx::transpose(a);
TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
a = broadcast_to(random::uniform({size}), shape);
TIMEM("2D broadcast dim 0", astype, a, int32, device);
a = mx::broadcast_to(mx::random::uniform({size}), shape);
TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
a = broadcast_to(random::uniform({size, 1}), shape);
TIMEM("2D broadcast dim 1", astype, a, int32, device);
a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
}
int main(int argc, char** argv) {
if (argc > 1) {
bool use_gpu = !strcmp(argv[1], "gpu");
set_default_device(use_gpu ? Device::gpu : Device::cpu);
set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
}
std::cout << "Benchmarks for " << default_device() << std::endl;
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
time_irregular_binary_ops_1D();
time_irregular_binary_ops_2D();
time_irregular_binary_ops_3D();


@@ -3,20 +3,20 @@
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
namespace mx = mlx::core;
void time_creation_ops() {
int M = 2000;
int N = 500;
auto shape = {M, N};
auto full_fp32 = [&]() { return full(shape, 3.3f); };
auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
TIME(full_fp32);
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
TIME(zeros_fp32);
auto ones_fp32 = [&]() { return ones(shape, float32); };
auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
TIME(ones_fp32);
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
TIME(arange_fp32);
}
@@ -24,194 +24,196 @@ void time_type_conversions() {
int M = 2000;
int N = 500;
auto shape = {M, N};
auto device = default_device();
auto device = mx::default_device();
auto a = zeros(shape, float32);
eval(a);
TIMEM("float32 to int32", astype, a, int32, device);
TIMEM("float32 to uint32", astype, a, uint32, device);
auto a = mx::zeros(shape, mx::float32);
mx::eval(a);
TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
a = zeros(shape, int32);
eval(a);
TIMEM("int32 to float32", astype, a, float32, device);
a = mx::zeros(shape, mx::int32);
mx::eval(a);
TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
a = zeros(shape, bool_);
eval(a);
TIMEM("bool to float32", astype, a, float32, device);
TIMEM("bool to int32", astype, a, int32, device);
TIMEM("bool to uint32", astype, a, uint32, device);
a = mx::zeros(shape, mx::bool_);
mx::eval(a);
TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
}
void time_random_generation() {
int M = 2000;
int N = 500;
auto uniform = [&]() { return random::uniform({M, N}, float32); };
auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
TIME(uniform);
auto normal = [&]() { return random::normal({M, N}, float32); };
auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
TIME(normal);
}
void time_unary_ops() {
int M = 2000;
int N = 500;
auto device = default_device();
auto device = mx::default_device();
auto a = random::normal({M, N});
eval(a);
auto a = mx::random::normal({M, N});
mx::eval(a);
TIME(mlx::core::abs, a, device);
TIME(negative, a, device);
TIME(sign, a, device);
TIME(square, a, device);
TIME(mx::negative, a, device);
TIME(mx::sign, a, device);
TIME(mx::square, a, device);
TIME(mlx::core::sqrt, a, device);
TIME(rsqrt, a, device);
TIME(mx::rsqrt, a, device);
TIME(mlx::core::exp, a, device);
a = random::uniform({M, N});
a = mx::random::uniform({M, N});
TIME(mlx::core::log, a, device);
}
void time_binary_ops() {
int M = 1000, N = 100, K = 10;
auto condition = random::randint(0, 2, {M, N, K});
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
eval(a, b);
auto condition = mx::random::randint(0, 2, {M, N, K});
auto a = mx::random::uniform({M, N, K});
auto b = mx::random::uniform({M, N, K});
auto device = mx::default_device();
mx::eval(a, b);
TIME(add, a, b, device);
TIME(subtract, a, b, device);
TIME(multiply, a, b, device);
TIME(divide, a, b, device);
TIME(maximum, a, b, device);
TIME(minimum, a, b, device);
TIME(where, condition, a, b, device);
TIME(mx::add, a, b, device);
TIME(mx::subtract, a, b, device);
TIME(mx::multiply, a, b, device);
TIME(mx::divide, a, b, device);
TIME(mx::maximum, a, b, device);
TIME(mx::minimum, a, b, device);
TIME(mx::where, condition, a, b, device);
condition = array({true});
b = random::uniform({1});
eval(b);
TIMEM("scalar", add, a, b, device);
TIMEM("vector-scalar", subtract, a, b, device);
TIMEM("scalar-vector", subtract, b, a, device);
TIMEM("scalar", multiply, a, b, device);
TIMEM("vector-scalar", divide, a, b, device);
TIMEM("scalar-vector", divide, b, a, device);
TIMEM("scalar-vector", where, condition, a, b, device);
condition = mx::array({true});
b = mx::random::uniform({1});
mx::eval(b);
TIMEM("scalar", mx::add, a, b, device);
TIMEM("vector-scalar", mx::subtract, a, b, device);
TIMEM("scalar-vector", mx::subtract, b, a, device);
TIMEM("scalar", mx::multiply, a, b, device);
TIMEM("vector-scalar", mx::divide, a, b, device);
TIMEM("scalar-vector", mx::divide, b, a, device);
TIMEM("scalar-vector", mx::where, condition, a, b, device);
condition = broadcast_to(array({true}), {1000, 100});
a = broadcast_to(random::uniform({1}), {1000, 100});
b = broadcast_to(random::uniform({1}), {1000, 100});
eval(a, b);
TIMEM("scalar-scalar broadcast", add, a, b, device);
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
TIMEM("scalar-scalar broadcast", divide, a, b, device);
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
condition = mx::broadcast_to(mx::array({true}), {1000, 100});
a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
mx::eval(a, b);
TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
}
void time_strided_ops() {
int M = 50, N = 50, O = 50, P = 50;
auto a = random::uniform({M, N, O, P});
auto b = random::uniform({M, N, O, P});
auto device = default_device();
eval(a, b);
TIMEM("non-strided", add, a, b, device);
a = transpose(a, {1, 0, 2, 3});
b = transpose(b, {3, 2, 0, 1});
eval(a, b);
TIMEM("strided", add, a, b, device);
auto a = mx::random::uniform({M, N, O, P});
auto b = mx::random::uniform({M, N, O, P});
auto device = mx::default_device();
mx::eval(a, b);
TIMEM("non-strided", mx::add, a, b, device);
a = mx::transpose(a, {1, 0, 2, 3});
b = mx::transpose(b, {3, 2, 0, 1});
mx::eval(a, b);
TIMEM("strided", mx::add, a, b, device);
}
void time_comparisons() {
int M = 1000, N = 100, K = 10;
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
eval(a, b);
TIME(equal, a, b, device);
TIME(greater, a, b, device);
TIME(greater_equal, a, b, device);
TIME(less, a, b, device);
TIME(less_equal, a, b, device);
auto a = mx::random::uniform({M, N, K});
auto b = mx::random::uniform({M, N, K});
auto device = mx::default_device();
mx::eval(a, b);
TIME(mx::equal, a, b, device);
TIME(mx::greater, a, b, device);
TIME(mx::greater_equal, a, b, device);
TIME(mx::less, a, b, device);
TIME(mx::less_equal, a, b, device);
}
void time_matvec() {
int M = 2000, N = 200;
auto a = random::uniform({M, N});
auto b = random::uniform({N});
auto c = random::uniform({M});
eval(a, b, c);
auto matvec = [&]() { return matmul(a, b); };
auto a = mx::random::uniform({M, N});
auto b = mx::random::uniform({N});
auto c = mx::random::uniform({M});
mx::eval(a, b, c);
auto matvec = [&]() { return mx::matmul(a, b); };
TIME(matvec);
auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
TIME(matvec_transpose);
}
void time_matmul() {
int M = 1000, N = 1000, K = 1000;
auto a = random::uniform({M, K});
auto b = random::uniform({K, N});
auto device = default_device();
eval(a, b);
TIME(matmul, a, b, device);
auto a = mx::random::uniform({M, K});
auto b = mx::random::uniform({K, N});
auto device = mx::default_device();
mx::eval(a, b);
TIME(mx::matmul, a, b, device);
auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
TIME(transpose_matmul);
}
void time_reductions() {
auto a = random::normal({10000, 1000});
eval(a);
auto sum_all = [&a]() { return sum(a, false); };
auto a = mx::random::normal({10000, 1000});
mx::eval(a);
auto sum_all = [&a]() { return mx::sum(a, false); };
TIME(sum_all);
auto sum_along_0 = [&a]() { return sum(a, 0, false); };
auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
TIME(sum_along_0);
auto sum_along_1 = [&a]() { return sum(a, 1, false); };
auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
TIME(sum_along_1);
auto prod_all = [&a]() { return prod(a, false); };
auto prod_all = [&a]() { return mx::prod(a, false); };
TIME(prod_all);
auto all_true = [&a]() { return all(a, false); };
auto all_true = [&a]() { return mx::all(a, false); };
TIME(all_true);
auto all_along_0 = [&a]() { return all(a, 0, false); };
auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
TIME(all_along_0);
auto all_along_1 = [&a]() { return all(a, 1, false); };
auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
TIME(all_along_1);
auto any_true = [&a]() { return any(a, false); };
auto any_true = [&a]() { return mx::any(a, false); };
TIME(any_true);
auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
TIME(argmin_along_0);
auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
TIME(argmin_along_1);
}
void time_gather_scatter() {
auto a = random::normal({1000, 768});
eval(a);
auto indices = random::randint(0, 1000, {256});
eval(indices);
auto a = mx::random::normal({1000, 768});
mx::eval(a);
auto indices = mx::random::randint(0, 1000, {256});
mx::eval(indices);
auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
TIME(embedding_lookup);
indices = random::randint(0, 768 * 1000, {256 * 768});
eval(indices);
indices = mx::random::randint(0, 768 * 1000, {256 * 768});
mx::eval(indices);
auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
auto single_element_lookup = [&a, &indices]() {
return mx::take(a, indices);
};
TIME(single_element_lookup);
indices = random::randint(0, 1000, {256});
auto updates = random::normal({256, 1, 768});
eval(indices, updates);
indices = mx::random::randint(0, 1000, {256});
auto updates = mx::random::normal({256, 1, 768});
mx::eval(indices, updates);
auto embedding_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0);
@@ -223,10 +225,10 @@ void time_gather_scatter() {
};
TIME(embedding_add);
a = reshape(a, {-1});
indices = random::randint(0, 768 * 1000, {768 * 256});
updates = random::normal({256 * 768, 1});
eval(a, indices, updates);
a = mx::reshape(a, {-1});
indices = mx::random::randint(0, 768 * 1000, {768 * 256});
updates = mx::random::normal({256 * 768, 1});
mx::eval(a, indices, updates);
auto single_element_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0);
@@ -240,21 +242,21 @@ void time_gather_scatter() {
}
void time_divmod() {
auto a = random::normal({1000});
auto b = random::normal({1000});
eval({a, b});
auto a = mx::random::normal({1000});
auto b = mx::random::normal({1000});
mx::eval({a, b});
auto divmod_fused = [&a, &b]() { return divmod(a, b); };
auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
TIME(divmod_fused);
auto divmod_separate = [&a, &b]() {
return std::vector<array>{floor_divide(a, b), remainder(a, b)};
return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
};
TIME(divmod_separate);
}
int main() {
std::cout << "Benchmarks for " << default_device() << std::endl;
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
time_creation_ops();
time_type_conversions();
time_unary_ops();


@@ -0,0 +1,74 @@
import argparse
import math
import mlx.core as mx
from time_utils import time_fn
B = 1024
D = 1024
M = 4 * D
group_size = 64
bits = 4
dtype = mx.float16
loops = 10
def qmm_(x, wq1, wq2, q_type):
for i in range(loops):
x = mx.quantized_matmul(
x,
*wq1,
group_size=group_size,
bits=bits,
quantization_type=q_type,
)
x = mx.quantized_matmul(
x,
*wq2,
group_size=group_size,
bits=bits,
quantization_type=q_type,
)
return x
def affine_qmm(x, wq1, wq2):
return qmm_(x, wq1, wq2, "affine")
def affine_packed_qmm(x, wq1, wq2):
return qmm_(x, wq1, wq2, "affine-packed")
def time_qmm():
mx.random.seed(3)
x = mx.random.normal(shape=(B, D)).astype(dtype)
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
mx.eval(x, wq1, wq2)
time_fn(affine_qmm, x, wq1, wq2)
def time_packed_qmm():
mx.random.seed(3)
x = mx.random.normal(shape=(B, D)).astype(dtype)
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
wq1 = mx.quantize(
w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
)
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
wq2 = mx.quantize(
w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
)
mx.eval(x, wq1, wq2)
time_fn(affine_packed_qmm, x, wq1, wq2)
if __name__ == "__main__":
for b in [2, 4, 8]:
bits = b
print(f"Bits {bits}:")
time_qmm()
time_packed_qmm()
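Commit 4847199ec6 in this range also adds the quantization type option to the quantizable nn layers. A hypothetical sketch, assuming the layer keyword mirrors the op-level quantization_type argument (the layer signature itself is not part of this excerpt):

import mlx.nn as nn

# Hypothetical: quantization_type here is assumed to mirror the op-level name.
layer = nn.QuantizedLinear(
    512, 1024, group_size=64, bits=4, quantization_type="affine-packed"
)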


@@ -1,94 +1,58 @@
import argparse
import math
import mlx.core as mx
import numpy as np
from mlx.utils import tree_map
from time_utils import time_fn
L = 32768
L = 16384
H = 32
H_k = H // 4
D = 128
dtype = mx.float16
bits = 8
loops = 20
loops = 10
def attention(q, k, v):
for _ in range(loops):
def _sdpa(q, k, v):
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
ke = k[:, :, None, :, :]
ve = v[:, :, None, :, :]
s = q @ ke.transpose(0, 1, 2, 4, 3)
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
q = p @ ve
q = q.reshape(B, Hq, L, D)
o = p @ v
return o.reshape(B, Hq, L, D)
for i in range(loops):
q = _sdpa(q, k, v)
return q
def sdpa(q, k, v):
for _ in range(loops):
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
for i in range(loops):
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
return q
def quant_sdpa(q, k, v, bits=4):
for _ in range(loops):
q = mx.fast.quantized_scaled_dot_product_attention(
q, *k, *v, scale=1.0, mask=None, bits=bits
)
return q
def quant_attention(q, k, v, bits=4):
for _ in range(loops):
B, Hq, L, D = q.shape
Hk = k[0].shape[1]
q = q.reshape((B, Hk, Hq // Hk, L, D))
ke = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
ve = tree_map(lambda x: mx.expand_dims(x, axis=2), v)
scores = mx.quantized_matmul(q, *ke, transpose=True, bits=bits)
scores = mx.softmax(scores, axis=-1)
q = mx.quantized_matmul(scores, *ve, transpose=False, bits=bits)
q = q.reshape((B, Hq, L, D))
return q
def time_self_attention_primitives(q, k, v):
def time_self_attention_primitives():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
mx.eval(q, k, v)
time_fn(attention, q, k, v)
def time_self_attention_sdpa(q, k, v):
def time_self_attention_sdpa():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
mx.eval(q, k, v)
time_fn(sdpa, q, k, v)
def time_self_attention_quant_sdpa(q, k, v, bits=4):
time_fn(quant_sdpa, q, k, v, bits)
def time_self_attention_quant_primitives(q, k, v, bits=4):
time_fn(quant_attention, q, k, v, bits)
if __name__ == "__main__":
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D), dtype=dtype)
k = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
v = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
mx.eval(q, k, v)
k_quant = mx.quantize(k, bits=bits)
v_quant = mx.quantize(v, bits=bits)
mx.eval(k_quant, v_quant)
k = mx.dequantize(*k_quant, bits=bits)
v = mx.dequantize(*v_quant, bits=bits)
time_self_attention_sdpa(q, k, v)
time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
time_self_attention_primitives(q, k, v)
time_self_attention_quant_primitives(q, k_quant, v_quant, bits)
time_self_attention_sdpa()
time_self_attention_primitives()
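The rewritten benchmark times a reference attention built from base ops against the fused mx.fast.scaled_dot_product_attention kernel. A minimal correctness check in the same spirit, using only signatures visible in this diff (the shapes are arbitrary):

import mlx.core as mx

B, Hq, Hk, L, D = 1, 32, 8, 64, 128
q = mx.random.uniform(shape=(B, Hq, 1, D))
k = mx.random.uniform(shape=(B, Hk, L, D))
v = mx.random.uniform(shape=(B, Hk, L, D))

# Reference: group the query heads over the K/V heads, then softmax(Q K^T) V.
qg = q.reshape(B, Hk, Hq // Hk, 1, D)
s = qg @ k[:, :, None, :, :].transpose(0, 1, 2, 4, 3)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
ref = (p @ v[:, :, None, :, :]).reshape(B, Hq, 1, D)

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
assert mx.allclose(ref, out, atol=1e-4)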


@@ -420,8 +420,8 @@ element in the output.
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
constant const int* shape [[buffer(5)]],
constant const size_t* x_strides [[buffer(6)]],
constant const size_t* y_strides [[buffer(7)]],
constant const int64_t* x_strides [[buffer(6)]],
constant const int64_t* y_strides [[buffer(7)]],
constant const int& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) {
// Convert linear indices to offsets in array
@@ -438,24 +438,10 @@ each instantiation a unique host name so we can identify it.
.. code-block:: C++
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] \
[[kernel]] void axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
instantiate_kernel("axpby_general_float32", axpby_general, float)
instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)
The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU is contained in :meth:`Axpby::eval_gpu` as shown


@@ -168,6 +168,7 @@ Operations
tri
tril
triu
unflatten
var
view
where
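unflatten enters the operations list as the inverse of flatten (#1692). A quick sketch of the pair, together with the expand_dims / squeeze ops this range leans on (#1687, #1696); the (array, axis, shape) form of unflatten is assumed from that PR:

import mlx.core as mx

x = mx.zeros((2, 3, 4))
y = mx.flatten(x, start_axis=1, end_axis=2)  # shape (2, 12)
z = mx.unflatten(y, 1, (3, 4))               # back to (2, 3, 4)

e = mx.expand_dims(x, axis=0)  # shape (1, 2, 3, 4)
s = mx.squeeze(e, axis=0)      # back to (2, 3, 4)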


@@ -4,19 +4,19 @@
#include "mlx/mlx.h"
using namespace mlx::core;
namespace mx = mlx::core;
int main() {
if (!distributed::is_available()) {
if (!mx::distributed::is_available()) {
std::cout << "No communication backend found" << std::endl;
return 1;
}
auto global_group = distributed::init();
auto global_group = mx::distributed::init();
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
array x = ones({10});
array out = distributed::all_sum(x, global_group);
mx::array x = mx::ones({10});
mx::array out = mx::distributed::all_sum(x, global_group);
std::cout << out << std::endl;
}


@@ -10,7 +10,7 @@
/**
* An example of linear regression with MLX.
*/
using namespace mlx::core;
namespace mx = mlx::core;
int main() {
int num_features = 100;
@@ -19,35 +19,35 @@ int main() {
float learning_rate = 0.01;
// True parameters
auto w_star = random::normal({num_features});
auto w_star = mx::random::normal({num_features});
// The input examples (design matrix)
auto X = random::normal({num_examples, num_features});
auto X = mx::random::normal({num_examples, num_features});
// Noisy labels
auto eps = 1e-2 * random::normal({num_examples});
auto y = matmul(X, w_star) + eps;
auto eps = 1e-2 * mx::random::normal({num_examples});
auto y = mx::matmul(X, w_star) + eps;
// Initialize random parameters
array w = 1e-2 * random::normal({num_features});
mx::array w = 1e-2 * mx::random::normal({num_features});
auto loss_fn = [&](array w) {
auto yhat = matmul(X, w);
return (0.5f / num_examples) * sum(square(yhat - y));
auto loss_fn = [&](mx::array w) {
auto yhat = mx::matmul(X, w);
return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));
};
auto grad_fn = grad(loss_fn);
auto grad_fn = mx::grad(loss_fn);
auto tic = timer::time();
for (int it = 0; it < num_iters; ++it) {
auto grad = grad_fn(w);
w = w - learning_rate * grad;
eval(w);
auto grads = grad_fn(w);
w = w - learning_rate * grads;
mx::eval(w);
}
auto toc = timer::time();
auto loss = loss_fn(w);
auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());
auto throughput = num_iters / timer::seconds(toc - tic);
std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
<< ", Throughput " << throughput << " (it/s)." << std::endl;


@@ -10,7 +10,7 @@
/**
* An example of logistic regression with MLX.
*/
using namespace mlx::core;
namespace mx = mlx::core;
int main() {
int num_features = 100;
@@ -19,35 +19,35 @@ int main() {
float learning_rate = 0.1;
// True parameters
auto w_star = random::normal({num_features});
auto w_star = mx::random::normal({num_features});
// The input examples
auto X = random::normal({num_examples, num_features});
auto X = mx::random::normal({num_examples, num_features});
// Labels
auto y = matmul(X, w_star) > 0;
auto y = mx::matmul(X, w_star) > 0;
// Initialize random parameters
array w = 1e-2 * random::normal({num_features});
mx::array w = 1e-2 * mx::random::normal({num_features});
auto loss_fn = [&](array w) {
auto logits = matmul(X, w);
auto loss_fn = [&](mx::array w) {
auto logits = mx::matmul(X, w);
auto scale = (1.0f / num_examples);
return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
};
auto grad_fn = grad(loss_fn);
auto grad_fn = mx::grad(loss_fn);
auto tic = timer::time();
for (int it = 0; it < num_iters; ++it) {
auto grad = grad_fn(w);
w = w - learning_rate * grad;
eval(w);
auto grads = grad_fn(w);
w = w - learning_rate * grads;
mx::eval(w);
}
auto toc = timer::time();
auto loss = loss_fn(w);
auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
auto throughput = num_iters / timer::seconds(toc - tic);
std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
<< throughput << " (it/s)." << std::endl;


@@ -5,27 +5,27 @@
#include "mlx/mlx.h"
using namespace mlx::core;
namespace mx = mlx::core;
int main() {
// To use Metal debugging and profiling:
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
// 2. Run with MTL_CAPTURE_ENABLED=1.
metal::start_capture("mlx_trace.gputrace");
mx::metal::start_capture("mlx_trace.gputrace");
// Start at index two because the default GPU and CPU streams have indices
// zero and one, respectively. This naming matches the label assigned to each
// stream's command queue.
auto s2 = new_stream(Device::gpu);
auto s3 = new_stream(Device::gpu);
auto s2 = new_stream(mx::Device::gpu);
auto s3 = new_stream(mx::Device::gpu);
auto a = arange(1.f, 10.f, 1.f, float32, s2);
auto b = arange(1.f, 10.f, 1.f, float32, s3);
auto x = add(a, a, s2);
auto y = add(b, b, s3);
auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
auto x = mx::add(a, a, s2);
auto y = mx::add(b, b, s3);
// The multiply will happen on the default stream.
std::cout << multiply(x, y) << std::endl;
std::cout << mx::multiply(x, y) << std::endl;
metal::stop_capture();
mx::metal::stop_capture();
}


@@ -5,11 +5,11 @@
#include "mlx/mlx.h"
using namespace mlx::core;
namespace mx = mlx::core;
void array_basics() {
// Make a scalar array:
array x(1.0);
mx::array x(1.0);
// Get the value out of it:
auto s = x.item<float>();
@@ -29,31 +29,31 @@ void array_basics() {
// The datatype should be float32:
auto dtype = x.dtype();
assert(dtype == float32);
assert(dtype == mx::float32);
// Specify the dtype when constructing the array:
x = array(1, int32);
assert(x.dtype() == int32);
x = mx::array(1, mx::int32);
assert(x.dtype() == mx::int32);
x.item<int>(); // OK
// x.item<float>(); // Undefined!
// Make a multidimensional array:
x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
// mlx is row-major by default so the first row of this array
// is [1.0, 2.0] and the second row is [3.0, 4.0]
// Make an array of shape {2, 2} filled with ones:
auto y = ones({2, 2});
auto y = mx::ones({2, 2});
// Pointwise add x and y:
auto z = add(x, y);
auto z = mx::add(x, y);
// Same thing:
z = x + y;
// mlx is lazy by default. At this point `z` only
// has a shape and a type but no actual data:
assert(z.dtype() == float32);
assert(z.dtype() == mx::float32);
assert(z.shape(0) == 2);
assert(z.shape(1) == 2);
@@ -63,33 +63,33 @@ void array_basics() {
// and inputs. When `eval` is called on an array (or arrays), the array and
// all of its dependencies are recursively evaluated to produce the result.
// Once an array is evaluated, it has data and is detached from its inputs.
eval(z);
mx::eval(z);
// Of course the array can still be an input to other operations. You can even
// call eval on the array again, this will just be a no-op:
eval(z); // no-op
// Of course the array can still be an input to other operations. You can
// even call eval on the array again, this will just be a no-op:
mx::eval(z); // no-op
// Some functions or methods on arrays implicitly evaluate them. For example
// accessing a value in an array or printing the array implicitly evaluate it:
z = ones({1});
z = mx::ones({1});
z.item<float>(); // implicit evaluation
z = ones({2, 2});
z = mx::ones({2, 2});
std::cout << z << std::endl; // implicit evaluation
}
void automatic_differentiation() {
auto fn = [](array x) { return square(x); };
auto fn = [](mx::array x) { return mx::square(x); };
// Computing the derivative function of a function
auto grad_fn = grad(fn);
auto grad_fn = mx::grad(fn);
// Call grad_fn on the input to get the derivative
auto x = array(1.5);
auto x = mx::array(1.5);
auto dfdx = grad_fn(x);
// dfdx is 2 * x
// Get the second derivative by composing grad with grad
auto d2fdx2 = grad(grad(fn))(x);
auto d2fdx2 = mx::grad(mx::grad(fn))(x);
// d2fdx2 is 2
}
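For readers following this C++ tutorial from the Python bindings, the same autodiff flow looks like this (a sketch, not part of the diff):

import mlx.core as mx

fn = lambda x: mx.square(x)
x = mx.array(1.5)

dfdx = mx.grad(fn)(x)             # 2 * x -> 3.0
d2fdx2 = mx.grad(mx.grad(fn))(x)  # the second derivative is the constant 2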


@@ -19,7 +19,7 @@
#include "mlx/backend/metal/utils.h"
#endif
namespace mlx::core {
namespace my_ext {
///////////////////////////////////////////////////////////////////////////////
// Operation Implementation
@@ -32,24 +32,24 @@ namespace mlx::core {
* Follow numpy style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
const array& x, // Input array x
const array& y, // Input array y
mx::array axpby(
const mx::array& x, // Input mx::array x
const mx::array& y, // Input mx::array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Promote dtypes between x and y as needed
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = issubdtype(promoted_dtype, float32)
auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32)
? promoted_dtype
: promote_types(promoted_dtype, float32);
: promote_types(promoted_dtype, mx::float32);
// Cast x and y up to the determined dtype (on the same stream s)
auto x_casted = astype(x, out_dtype, s);
auto y_casted = astype(y, out_dtype, s);
auto x_casted = mx::astype(x, out_dtype, s);
auto y_casted = mx::astype(y, out_dtype, s);
// Broadcast the shapes of x and y (on the same stream s)
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
@@ -57,12 +57,12 @@ array axpby(
// Construct the array as the output of the Axpby primitive
// with the broadcasted and upcasted arrays as inputs
return array(
return mx::array(
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
/* mx::Dtype dtype = */ out_dtype,
/* std::unique_ptr<mx::Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
/* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
}
///////////////////////////////////////////////////////////////////////////////
@@ -71,16 +71,16 @@ array axpby(
template <typename T>
void axpby_impl(
const array& x,
const array& y,
array& out,
const mx::array& x,
const mx::array& y,
mx::array& out,
float alpha_,
float beta_) {
// We only allocate memory when we are ready to fill the output
// malloc_or_wait synchronously allocates available memory
// There may be a wait executed here if the allocation is requested
// under memory-pressured conditions
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// Collect input and output data pointers
const T* x_ptr = x.data<T>();
@@ -94,8 +94,8 @@ void axpby_impl(
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
@@ -105,8 +105,8 @@ void axpby_impl(
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
@@ -114,14 +114,14 @@ void Axpby::eval(
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
if (out.dtype() == mx::float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_);
} else if (out.dtype() == float16) {
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == bfloat16) {
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == complex64) {
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::float16) {
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::bfloat16) {
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::complex64) {
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
@@ -136,9 +136,9 @@ void Axpby::eval(
template <typename T>
void axpby_impl_accelerate(
const array& x,
const array& y,
array& out,
const mx::array& x,
const mx::array& y,
mx::array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
@@ -150,10 +150,10 @@ void axpby_impl_accelerate(
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
copy_inplace(y, out, mx::CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
@@ -175,15 +175,15 @@ void axpby_impl_accelerate(
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
if (out.dtype() == mx::float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
@@ -198,8 +198,8 @@ void Axpby::eval_cpu(
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
eval(inputs, outputs);
}
@@ -213,8 +213,8 @@ void Axpby::eval_cpu(
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
@@ -225,7 +225,7 @@ void Axpby::eval_gpu(
// and each stream carries its device identifiers
auto& s = stream();
// We get the needed metal device using the stream
auto& d = metal::device(s.device);
auto& d = mx::metal::device(s.device);
// Prepare to specialize based on contiguity
bool contiguous_kernel =
@@ -235,12 +235,12 @@ void Axpby::eval_gpu(
// Allocate output memory with strides based on specialization
if (contiguous_kernel) {
out.set_data(
allocator::malloc_or_wait(x.data_size() * out.itemsize()),
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
x.data_size(),
x.strides(),
x.flags());
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
}
// Resolve name of kernel (corresponds to axpby.metal)
@@ -302,8 +302,8 @@ void Axpby::eval_gpu(
/** Fail evaluation on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& out) {
const std::vector<mx::array>& inputs,
std::vector<mx::array>& out) {
throw std::runtime_error("Axpby has no GPU implementation.");
}
@@ -314,9 +314,9 @@ void Axpby::eval_gpu(
///////////////////////////////////////////////////////////////////////////////
/** The Jacobian-vector product. */
std::vector<array> Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
std::vector<mx::array> Axpby::jvp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can built with ops
@@ -328,8 +328,8 @@ std::vector<array> Axpby::jvp(
// scaled by beta
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
auto scale_arr = mx::array(scale, tangents[0].dtype());
return {mx::multiply(scale_arr, tangents[0], stream())};
}
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
@@ -339,24 +339,24 @@ std::vector<array> Axpby::jvp(
}
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
std::vector<mx::array> Axpby::vjp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<mx::array>&) {
// Reverse mode diff
std::vector<array> vjps;
std::vector<mx::array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
auto scale_arr = mx::array(scale, cotangents[0].dtype());
vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream()));
}
return vjps;
}
/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
const std::vector<array>& inputs,
std::pair<std::vector<mx::array>, std::vector<int>> Axpby::vmap(
const std::vector<mx::array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("Axpby has no vmap implementation.");
}
@@ -367,4 +367,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
}
} // namespace mlx::core
} // namespace my_ext
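
The diff above moves the whole extension out of namespace mlx::core and into its own my_ext namespace, with mx as a short alias for the core library. A minimal sketch of the same pattern for a new op (the scale op and its body are illustrative, not part of the extension):

#include "mlx/mlx.h"

namespace mx = mlx::core;  // alias instead of a using-directive

namespace my_ext {

// Hypothetical op: every MLX type is spelled with the mx:: prefix,
// so nothing here collides with or depends on mlx::core internals.
mx::array scale(const mx::array& x, float a, mx::StreamOrDevice s = {}) {
  return mx::multiply(mx::array(a, x.dtype()), x, s);
}

} // namespace my_ext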


@@ -5,7 +5,9 @@
#include "mlx/ops.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace mx = mlx::core;
namespace my_ext {
///////////////////////////////////////////////////////////////////////////////
// Operation
@@ -18,22 +20,22 @@ namespace mlx::core {
* Follow numpy style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
const array& x, // Input array x
const array& y, // Input array y
mx::array axpby(
const mx::array& x, // Input array x
const mx::array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s = {} // Stream on which to schedule the operation
mx::StreamOrDevice s = {} // Stream on which to schedule the operation
);
///////////////////////////////////////////////////////////////////////////////
// Primitive
///////////////////////////////////////////////////////////////////////////////
class Axpby : public Primitive {
class Axpby : public mx::Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta) {};
explicit Axpby(mx::Stream stream, float alpha, float beta)
: mx::Primitive(stream), alpha_(alpha), beta_(beta) {};
/**
* A primitive must know how to evaluate itself on the CPU/GPU
@@ -42,23 +44,25 @@ class Axpby : public Primitive {
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) override;
void eval_gpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) override;
/** The Jacobian-vector product. */
std::vector<array> jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
std::vector<mx::array> jvp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& tangents,
const std::vector<int>& argnums) override;
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
std::vector<mx::array> vjp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
const std::vector<mx::array>& outputs) override;
/**
* The primitive must know how to vectorize itself across
@@ -66,8 +70,8 @@ class Axpby : public Primitive {
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
std::pair<std::vector<mx::array>, std::vector<int>> vmap(
const std::vector<mx::array>& inputs,
const std::vector<int>& axes) override;
/** Print the primitive. */
@@ -76,14 +80,16 @@ class Axpby : public Primitive {
}
/** Equivalence check **/
bool is_equivalent(const Primitive& other) const override;
bool is_equivalent(const mx::Primitive& other) const override;
private:
float alpha_;
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
void eval(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs);
};
} // namespace mlx::core
} // namespace my_ext
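
Given this header, calling code qualifies the op explicitly. A small usage sketch (shapes and the header path follow the example extension's layout, so treat them as assumptions):

#include "axpby/axpby.h"

int main() {
  auto x = mx::ones({3, 4});
  auto y = mx::full({3, 4}, 2.0f);
  // out = 2 * x + 3 * y; lazy, like any built-in MLX op
  auto out = my_ext::axpby(x, y, 2.0f, 3.0f);
  out.eval();
}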


@@ -12,8 +12,8 @@ template <typename T>
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
constant const int* shape [[buffer(5)]],
constant const size_t* x_strides [[buffer(6)]],
constant const size_t* y_strides [[buffer(7)]],
constant const int64_t* x_strides [[buffer(6)]],
constant const int64_t* y_strides [[buffer(7)]],
constant const int& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) {
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
@@ -34,29 +34,14 @@ template <typename T>
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
}
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
uint index [[thread_position_in_grid]]);
// clang-format off
#define instantiate_axpby(type_name, type) \
instantiate_kernel("axpby_general_" #type_name, axpby_general, type) \
instantiate_kernel( \
"axpby_contiguous_" #type_name, axpby_contiguous, type)
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
// clang-format on
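
instantiate_kernel replaces the two hand-written host_name instantiations deleted above. As a rough sketch of what one invocation expands to (the exact macro lives in MLX's shared Metal kernel headers, so this is an approximation):

// instantiate_kernel("axpby_general_float32", axpby_general, float)
// expands to roughly:
//   template [[host_name("axpby_general_float32")]] [[kernel]]
//   decltype(axpby_general<float>) axpby_general<float>;
// i.e. the signature is deduced from the template, so the buffer list
// no longer needs to be repeated by hand.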


@@ -8,14 +8,12 @@
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
m.def(
"axpby",
&axpby,
&my_ext::axpby,
"x"_a,
"y"_a,
"alpha"_a,


@@ -18,6 +18,16 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
if(MSVC)
# Disable some MSVC warnings to speed up compilation.
target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
endif()
if(WIN32)
# Export symbols by default to behave like macOS/linux.
set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
endif()
if(MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else()


@@ -31,7 +31,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
}
array::array(
std::vector<int> shape,
Shape shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)
@@ -42,7 +42,7 @@ array::array(
std::move(inputs))) {}
std::vector<array> array::make_arrays(
std::vector<std::vector<int>> shapes,
std::vector<Shape> shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs) {
@@ -74,11 +74,7 @@ array::array(std::initializer_list<int> data, Dtype dtype)
}
/* Build an array from a shared buffer */
array::array(
allocator::Buffer data,
std::vector<int> shape,
Dtype dtype,
deleter_t deleter)
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
set_data(data, deleter);
}
@@ -126,7 +122,7 @@ bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing() || retain_graph();
}
void array::set_data(allocator::Buffer buffer, deleter_t d) {
void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = size();
@@ -139,9 +135,9 @@ void array::set_data(allocator::Buffer buffer, deleter_t d) {
void array::set_data(
allocator::Buffer buffer,
size_t data_size,
std::vector<size_t> strides,
Strides strides,
Flags flags,
deleter_t d) {
Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = data_size;
@@ -151,7 +147,7 @@ void array::set_data(
void array::copy_shared_buffer(
const array& other,
const std::vector<size_t>& strides,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
@@ -170,7 +166,7 @@ void array::copy_shared_buffer(const array& other) {
void array::move_shared_buffer(
array other,
const std::vector<size_t>& strides,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
@@ -237,13 +233,13 @@ void array::ArrayDesc::init() {
}
}
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
init();
}
array::ArrayDesc::ArrayDesc(
std::vector<int> shape,
Shape shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)


@@ -15,7 +15,10 @@ namespace mlx::core {
// Forward declaration
class Primitive;
using deleter_t = std::function<void(allocator::Buffer)>;
using Deleter = std::function<void(allocator::Buffer)>;
using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;
class array {
/* An array is really a node in a graph. It contains a shared ArrayDesc
@@ -33,7 +36,7 @@ class array {
template <typename It>
array(
It data,
std::vector<int> shape,
Shape shape,
Dtype dtype =
TypeToDtype<typename std::iterator_traits<It>::value_type>());
@@ -49,15 +52,15 @@ class array {
template <typename T>
array(
std::initializer_list<T> data,
std::vector<int> shape,
Shape shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a buffer */
array(
allocator::Buffer data,
std::vector<int> shape,
Shape shape,
Dtype dtype,
deleter_t deleter = allocator::free);
Deleter deleter = allocator::free);
/** Assignment to rvalue does not compile. */
array& operator=(const array& other) && = delete;
@@ -96,7 +99,7 @@ class array {
}
/** The shape of the array as a vector of integers. */
const std::vector<int>& shape() const {
const Shape& shape() const {
return array_desc_->shape;
}
@@ -105,12 +108,12 @@ class array {
*
* This function supports negative indexing and provides
* bounds checking. */
int shape(int dim) const {
auto shape(int dim) const {
return shape().at(dim < 0 ? dim + ndim() : dim);
}
/** The strides of the array. */
const std::vector<size_t>& strides() const {
const Strides& strides() const {
return array_desc_->strides;
}
@@ -119,7 +122,7 @@ class array {
*
* This function supports negative indexing and provides
* bounds checking. */
size_t strides(int dim) const {
auto strides(int dim) const {
return strides().at(dim < 0 ? dim + ndim() : dim);
}
@@ -184,13 +187,13 @@ class array {
*/
array(
std::vector<int> shape,
Shape shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
static std::vector<array> make_arrays(
std::vector<std::vector<int>> shapes,
std::vector<Shape> shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs);
@@ -207,8 +210,8 @@ class array {
struct Data {
allocator::Buffer buffer;
deleter_t d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
Deleter d;
Data(allocator::Buffer buffer, Deleter d = allocator::free)
: buffer(buffer), d(d) {}
// Not copyable
Data(const Data& d) = delete;
@@ -397,18 +400,18 @@ class array {
// Check if the array is a tracer array
bool is_tracer() const;
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
void set_data(allocator::Buffer buffer, Deleter d = allocator::free);
void set_data(
allocator::Buffer buffer,
size_t data_size,
std::vector<size_t> strides,
Strides strides,
Flags flags,
deleter_t d = allocator::free);
Deleter d = allocator::free);
void copy_shared_buffer(
const array& other,
const std::vector<size_t>& strides,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
@@ -417,7 +420,7 @@ class array {
void move_shared_buffer(
array other,
const std::vector<size_t>& strides,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
@@ -436,8 +439,8 @@ class array {
void init(const It src);
struct ArrayDesc {
std::vector<int> shape;
std::vector<size_t> strides;
Shape shape;
Strides strides;
size_t size;
Dtype dtype;
std::shared_ptr<Primitive> primitive;
@@ -471,10 +474,10 @@ class array {
// The arrays position in the output list
uint32_t position{0};
explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
explicit ArrayDesc(Shape shape, Dtype dtype);
explicit ArrayDesc(
std::vector<int> shape,
Shape shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
@@ -502,7 +505,7 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
template <typename It>
array::array(
It data,
std::vector<int> shape,
Shape shape,
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
init(data);
@@ -521,7 +524,7 @@ array::array(
template <typename T>
array::array(
std::initializer_list<T> data,
std::vector<int> shape,
Shape shape,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
if (data.size() != size()) {

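The new Shape and Strides aliases are more than a rename: strides move from size_t to int64_t, so a negative stride is representable and a reversed slice can stay a zero-copy view. A small illustration of why the signedness matters (values are illustrative):

// Strides = std::vector<int64_t> (from mlx/array.h above)
// With unsigned strides a backwards step cannot be expressed:
//   size_t s = -1;       // wraps to SIZE_MAX; ptr += s walks off the buffer
// With signed strides it can:
//   Strides rev = {-1};  // offset 3, stride -1 views {0,1,2,3} as {3,2,1,0}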

@@ -43,6 +43,7 @@ DEFAULT(NumberOfElements)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(ExpandDims)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
@@ -65,7 +66,6 @@ DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT_MULTI(QRF)
DEFAULT(RandomBits)
DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)
@@ -76,6 +76,7 @@ DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(Squeeze)
DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
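
DEFAULT and DEFAULT_MULTI register a primitive's Accelerate backend as a plain forward to the common CPU implementation. A sketch of the single-output form, assuming it mirrors the definition at the top of this file (elided from the diff):

#define DEFAULT(primitive)                                                  \
  void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
    primitive::eval(inputs, out);                                          \
  }

// So the new DEFAULT(ExpandDims) and DEFAULT(Squeeze) lines simply route
// eval_cpu to the shared eval implementations added in this change.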


@@ -5,13 +5,21 @@ else()
set(COMPILER ${CMAKE_CXX_COMPILER})
endif()
if(MSVC)
set(SHELL_EXT ps1)
set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)
else()
set(SHELL_EXT sh)
set(SHELL_CMD /bin/bash)
endif()
add_custom_command(
OUTPUT compiled_preamble.cpp
COMMAND
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
${PROJECT_SOURCE_DIR} ${CLANG}
DEPENDS make_compiled_preamble.sh
${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
DEPENDS make_compiled_preamble.${SHELL_EXT}
compiled_preamble.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h


@@ -13,8 +13,8 @@ template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis];
std::vector<size_t> strides = in.strides();
std::vector<int> shape = in.shape();
Strides strides = in.strides();
Shape shape = in.shape();
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
for (uint32_t i = 0; i < out.size(); ++i) {


@@ -178,10 +178,10 @@ void binary_op_dims(
const T* b,
U* out,
Op op,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& out_strides,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
@@ -212,10 +212,10 @@ void binary_op_dispatch_dims(
array& out,
Op op,
int dim,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& out_strides) {
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_ptr = out.data<U>();
@@ -258,10 +258,10 @@ void binary_op_dispatch_dims(
return;
}
ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
size_t stride = out_strides[dim - 4];
for (size_t elem = 0; elem < a.size(); elem += stride) {
ContiguousIterator a_it(shape, a_strides, dim - 3);
ContiguousIterator b_it(shape, b_strides, dim - 3);
auto stride = out_strides[dim - 4];
for (int64_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
@@ -327,7 +327,7 @@ void binary_op(
const auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) {
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
}
@@ -337,7 +337,7 @@ void binary_op(
auto b_rc_dim = leftmost_rc_dim(b_strides);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) {
auto leftmost_s_dim = [](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == 0; d--) {
}


@@ -16,10 +16,10 @@ void binary_op_dims(
U* out_a,
U* out_b,
Op op,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& out_strides,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
@@ -96,9 +96,9 @@ void binary_op_dispatch_dims(
return;
}
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
size_t stride = out_strides[ndim - 3];
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 2>(
a_ptr + a_it.loc,


@@ -49,7 +49,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
out.set_data(nullptr);
return;
}
std::vector<size_t> strides(out.ndim(), 0);
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
@@ -85,6 +85,16 @@ void Depends::eval(
}
}
void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto strides = in.strides();
for (auto ax : axes_) {
strides.insert(strides.begin() + ax, 1);
}
move_or_copy(in, out, strides, in.flags(), in.data_size());
}
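
ExpandDims::eval never touches the data: it inserts a size-1 axis by splicing an extra stride into the input's stride vector and sharing the buffer. Worked through on a small case:

// in:  shape {3, 4}, strides {4, 1}          (row contiguous)
// ExpandDims with axes_ = {1}:
// out: shape {3, 1, 4}, strides {4, 1, 1}
// The inserted stride's value is irrelevant because the new axis has extent 1.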
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -141,9 +151,7 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
}
}
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
const array& in,
const array& out) {
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
// Special case for empty arrays or row contiguous arrays
if (in.size() == 0 || in.flags().row_contiguous) {
return {false, out.strides()};
@@ -151,8 +159,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
// Special case for scalars
if (in.ndim() == 0) {
std::vector<size_t> out_strides(out.ndim(), 0);
return {false, out_strides};
return {false, Strides(out.ndim(), 0)};
}
// Firstly let's collapse all the contiguous dimensions of the input
@@ -160,7 +167,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
// If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check.
std::vector<size_t> out_strides;
Strides out_strides;
bool copy_necessary = false;
int j = 0;
for (int i = 0; i < out.ndim(); i++) {
@@ -181,9 +188,9 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
return {copy_necessary, out_strides};
}
void Reshape::shared_buffer_reshape(
void shared_buffer_reshape(
const array& in,
const std::vector<size_t>& out_strides,
const Strides& out_strides,
array& out) {
auto flags = in.flags();
if (flags.row_contiguous) {
@@ -249,16 +256,18 @@ void Split::eval(
}
}
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
void Squeeze::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
Strides strides;
for (int i = 0, j = 0; i < in.ndim(); ++i) {
if (j < axes_.size() && i == axes_[j]) {
j++;
} else {
strides.push_back(in.strides(i));
}
}
return std::make_tuple(data_offset, inp_strides);
move_or_copy(in, out, strides, in.flags(), in.data_size());
}
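
Squeeze::eval is the exact inverse: it copies every input stride except those at the squeezed axes, again producing a shared-buffer view:

// in:  shape {3, 1, 4}, strides {4, 1, 1}
// Squeeze with axes_ = {1}:
// out: shape {3, 4}, strides {4, 1}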
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
@@ -268,7 +277,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
std::vector<size_t> out_strides(out.ndim());
Strides out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
@@ -285,8 +294,8 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
// true, they stay true)
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
size_t f_stride = 1;
size_t b_stride = 1;
int64_t f_stride = 1;
int64_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {


@@ -165,7 +165,7 @@ void compiled_allocate_outputs(
bool move_buffers /* = false */) {
if (contiguous) {
int o = 0;
std::vector<size_t> strides;
Strides strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {


@@ -746,9 +746,9 @@ void explicit_gemm_conv_1D_cpu(
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view
std::vector<int> strided_shape = {N, oH, wH, C};
Shape strided_shape = {N, oH, wH, C};
std::vector<size_t> strided_strides = {
Strides strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[1],
@@ -865,9 +865,9 @@ void explicit_gemm_conv_2D_cpu(
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view
std::vector<int> strided_shape = {N, oH, oW, wH, wW, C};
Shape strided_shape = {N, oH, oW, wH, wW, C};
std::vector<size_t> strided_strides = {
Strides strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[2] * wt_strides[1],
@@ -974,7 +974,7 @@ void explicit_gemm_conv_ND_cpu(
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view
std::vector<int> strided_shape(oDim.size() + wDim.size() + 2);
Shape strided_shape(oDim.size() + wDim.size() + 2);
strided_shape.front() = N;
for (size_t i = 0; i < oDim.size(); i++) {
strided_shape[i + 1] = oDim[i];
@@ -984,7 +984,7 @@ void explicit_gemm_conv_ND_cpu(
}
strided_shape.back() = C;
std::vector<size_t> strided_strides(in.shape().size() * 2 - 2);
Strides strided_strides(in.shape().size() * 2 - 2);
strided_strides[0] = in_padded.strides()[0];
for (size_t i = 0; i < wt_strides.size(); i++) {
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
@@ -1000,7 +1000,7 @@ void explicit_gemm_conv_ND_cpu(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {N, C};
Shape strided_reshape = {N, C};
for (const auto& o : oDim) {
strided_reshape[0] *= o;
}

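These strided views implement the usual im2col trick without copying each window: overlap between windows is expressed purely through strides. Working the 1D case through with concrete (illustrative) numbers:

// in_padded: shape {N, H, C} = {1, 5, 2}, strides {10, 2, 1}
// window wH = 3, wt_strides[0] = 1  =>  oH = 3
//   strided_shape   = {N, oH, wH, C}    = {1, 3, 3, 2}
//   strided_strides = {10, 2 * 1, 2, 1} = {10, 2, 2, 1}
// Adjacent oH rows of the view alias overlapping memory; a later copy
// materializes them for the GEMM.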

@@ -26,13 +26,13 @@ void copy_vector(const array& src, array& dst) {
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
}
template <typename SrcT, typename DstT, typename StrideT, int D>
template <typename SrcT, typename DstT, int D>
inline void copy_dims(
const SrcT* src,
DstT* dst,
const std::vector<int>& shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>& o_strides,
const Shape& shape,
const Strides& i_strides,
const Strides& o_strides,
int axis) {
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
@@ -40,7 +40,7 @@ inline void copy_dims(
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
copy_dims<SrcT, DstT, StrideT, D - 1>(
copy_dims<SrcT, DstT, D - 1>(
src, dst, shape, i_strides, o_strides, axis + 1);
} else {
*dst = static_cast<DstT>(*src);
@@ -50,13 +50,13 @@ inline void copy_dims(
}
}
template <typename SrcT, typename DstT, typename StrideT>
template <typename SrcT, typename DstT>
void copy_general_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>& o_strides,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset) {
if (data_shape.empty()) {
@@ -65,30 +65,30 @@ void copy_general_general(
*dst_ptr = val;
return;
}
auto [shape, strides] = collapse_contiguous_dims(
data_shape, std::vector<std::vector<StrideT>>{i_strides, o_strides});
auto [shape, strides] =
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>() + o_offset;
int ndim = shape.size();
if (ndim == 1) {
copy_dims<SrcT, DstT, StrideT, 1>(
copy_dims<SrcT, DstT, 1>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 2) {
copy_dims<SrcT, DstT, StrideT, 2>(
copy_dims<SrcT, DstT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 3) {
copy_dims<SrcT, DstT, StrideT, 3>(
copy_dims<SrcT, DstT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
}
ContiguousIterator<StrideT> in(shape, strides[0], ndim - 3);
ContiguousIterator<StrideT> out(shape, strides[1], ndim - 3);
StrideT stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<StrideT>());
for (StrideT elem = 0; elem < src.size(); elem += stride) {
copy_dims<SrcT, DstT, StrideT, 3>(
ContiguousIterator in(shape, strides[0], ndim - 3);
ContiguousIterator out(shape, strides[1], ndim - 3);
auto stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
for (int64_t elem = 0; elem < src.size(); elem += stride) {
copy_dims<SrcT, DstT, 3>(
src_ptr + in.loc,
dst_ptr + out.loc,
shape,
@@ -102,37 +102,37 @@ void copy_general_general(
template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT, size_t>(
copy_general_general<SrcT, DstT>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
}
template <typename SrcT, typename DstT, typename StrideT>
template <typename SrcT, typename DstT>
void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>&,
const Shape& data_shape,
const Strides& i_strides,
const Strides&,
int64_t i_offset,
int64_t o_offset) {
copy_general_general<SrcT, DstT, StrideT>(
copy_general_general<SrcT, DstT>(
src,
dst,
data_shape,
i_strides,
make_contiguous_strides<StrideT>(data_shape),
make_contiguous_strides(data_shape),
i_offset,
o_offset);
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT, size_t>(
copy_general_general<SrcT, DstT>(
src,
dst,
src.shape(),
src.strides(),
make_contiguous_strides<size_t>(src.shape()),
make_contiguous_strides(src.shape()),
0,
0);
}
@@ -282,13 +282,12 @@ void copy(const array& src, array& dst, CopyType ctype) {
copy_inplace(src, dst, ctype);
}
template <typename StrideT>
void copy_inplace(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>& o_strides,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
@@ -311,24 +310,4 @@ void copy_inplace(
}
}
template void copy_inplace<size_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<size_t>& i_strides,
const std::vector<size_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
template void copy_inplace<int64_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<int64_t>& i_strides,
const std::vector<int64_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
} // namespace mlx::core
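
With a single Strides type, the StrideT template parameter and both explicit instantiations above become unnecessary; call sites just drop the template argument. Sketch of the shape of the change at a caller (argument list abridged):

// before:  copy_inplace<int64_t>(upd, out, upd.shape(), upd_strides, ...);
// after:   copy_inplace(upd, out, upd.shape(), upd_strides, ...);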


@@ -26,13 +26,12 @@ enum class CopyType {
void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);
template <typename stride_t>
void copy_inplace(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);


@@ -57,6 +57,7 @@ DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(Exp)
DEFAULT(ExpandDims)
DEFAULT(Expm1)
DEFAULT(FFT)
DEFAULT(Floor)
@@ -86,7 +87,6 @@ DEFAULT_MULTI(QRF)
DEFAULT(QuantizedMatmul)
DEFAULT(RandomBits)
DEFAULT(Reduce)
DEFAULT(Reshape)
DEFAULT(Round)
DEFAULT(Scan)
DEFAULT(Scatter)
@@ -101,6 +101,7 @@ DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT_MULTI(Split)
DEFAULT(Square)
DEFAULT(Squeeze)
DEFAULT(Sqrt)
DEFAULT(StopGradient)
DEFAULT(Subtract)
@@ -130,7 +131,7 @@ inline void matmul_common_general(
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};


@@ -32,7 +32,7 @@ void gather(
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const std::vector<int>& slice_sizes) {
const Shape& slice_sizes) {
// If the array is row contiguous then we can do a contiguous copy given
// two conditions on the slice size:
// - Any number of leading ones in the slice sizes are allowed
@@ -80,11 +80,10 @@ void gather(
T* dst_ptr = out.data<T>();
size_t out_idx = 0;
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> src_it;
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
ContiguousIterator src_it;
if (!can_copy && src.ndim() > 0) {
src_it = std::move(
ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
}
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
@@ -119,7 +118,7 @@ void dispatch_gather(
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const std::vector<int>& size) {
const Shape& size) {
switch (out.dtype()) {
case bool_:
gather<bool, IdxT>(src, inds, out, axes, size);
@@ -223,16 +222,16 @@ void scatter(
auto inds_ndim = updates.ndim() - out.ndim();
size_t n_updates = nind ? inds[0].size() : 1;
std::vector<int> update_shape(
Shape update_shape(
updates.shape().begin() + inds_ndim, updates.shape().end());
size_t update_size = 1;
for (auto us : update_shape) {
update_size *= us;
}
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> update_it(updates);
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
ContiguousIterator update_it(updates);
ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;


@@ -2,6 +2,15 @@
#pragma once
// Required for Visual Studio.
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
#ifdef _MSC_VER
#include <complex>
#define LAPACK_COMPLEX_CUSTOM
#define lapack_complex_float std::complex<float>
#define lapack_complex_double std::complex<double>
#endif
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else


@@ -0,0 +1,38 @@
# This script generates a C++ function that provides the CPU
# code for use with kernel generation.
#
# Copyright © 2024 Apple Inc.
$OUTPUT_FILE = $args[0]
$CL = $args[1]
$SRCDIR = $args[2]
# Get command result as array.
$CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/compiled_preamble.h"
# Remove empty lines.
# Otherwise there will be too many empty lines, making the result unreadable.
$CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }
# Concatenate to string.
$CONTENT = $CONTENT -join '`n'
# Append extra content.
$CONTENT = @"
$($CONTENT)
using namespace mlx::core;
using namespace mlx::core::detail;
"@
# Convert each char to ASCII code.
# Unlike the unix script, which emits a string literal directly, the MSVC
# output is far too large to embed as a string (compilation would fail), so
# we store it as a static char array instead.
$CHARCODES = ([System.Text.Encoding]::ASCII.GetBytes($CONTENT) -join ', ') + ', 0'
$OUTPUT = @"
const char* get_kernel_preamble() {
static char preamble[] = { $CHARCODES };
return preamble;
}
"@
Set-Content -Path $OUTPUT_FILE -Value $OUTPUT


@@ -10,15 +10,16 @@ OUTPUT_FILE=$1
GCC=$2
SRCDIR=$3
CLANG=$4
ARCH=$5
if [ "$CLANG" = "TRUE" ]; then
read -r -d '' INCLUDES <<- EOM
#include <cmath>
#include <complex>
#include <cstdint>
#include <vector>
#include <cmath>
#include <complex>
#include <cstdint>
#include <vector>
EOM
CC_FLAGS=""
CC_FLAGS="-arch ${ARCH}"
else
CC_FLAGS="-std=c++17"
fi


@@ -19,10 +19,10 @@ inline void mask_matrix(
int block_size,
const int X,
const int Y,
const size_t X_data_str,
const size_t Y_data_str,
const size_t X_mask_str,
const size_t Y_mask_str,
const int64_t X_data_str,
const int64_t Y_data_str,
const int64_t X_mask_str,
const int64_t Y_mask_str,
const size_t mask_offset) {
int tX = (X + block_size - 1) / block_size;
int tY = (Y + block_size - 1) / block_size;
@@ -84,7 +84,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
@@ -117,13 +117,13 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
int Y,
size_t X_data_str,
size_t Y_data_str) {
size_t mask_offset = elem_to_loc(
auto mask_offset = elem_to_loc(
mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(),
mask.strides());
size_t X_mask_str = mask.strides()[mask.ndim() - 2];
size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
auto X_mask_str = mask.strides()[mask.ndim() - 2];
auto Y_mask_str = mask.strides()[mask.ndim() - 1];
if (mask.dtype() == bool_) {
return mask_matrix(
@@ -230,7 +230,7 @@ void GatherMM::eval(const std::vector<array>& inputs, array& out) {
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
@@ -262,13 +262,13 @@ void GatherMM::eval(const std::vector<array>& inputs, array& out) {
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
std::vector<int> batch_shape = get_batch_dims(out.shape());
auto batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
auto batch_shape_A = get_batch_dims(a.shape());
auto batch_strides_A = get_batch_dims(a.strides());
auto batch_shape_B = get_batch_dims(b.shape());
auto batch_strides_B = get_batch_dims(b.strides());
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();

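The mask stride parameters switch from size_t to int64_t for the same reason as the rest of this change: stride arithmetic must be free to go negative, and unsigned types wrap silently instead. Illustration:

// size_t s = -1;    // wraps to SIZE_MAX
// ptr += s;         // jumps far past the buffer rather than back one slot
// int64_t t = -1;
// ptr += t;         // steps back one element, as intended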

@@ -500,7 +500,12 @@ struct Equal {
struct NaNEqual {
template <typename T>
bool operator()(T x, T y) {
return x == y || (std::isnan(x) && std::isnan(y));
if constexpr (std::is_integral_v<T>) {
// isnan always returns false for integers, and MSVC refuses to compile it
// for integral types.
return x == y;
} else {
return x == y || (std::isnan(x) && std::isnan(y));
}
}
};
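
The if constexpr branch keeps std::isnan out of integral instantiations entirely, which both fixes the MSVC build and skips a check that can never fire for integers. Resulting behavior, sketched:

// NaNEqual{}(3, 3)       -> true   (integral path: plain ==)
// NaNEqual{}(NAN, NAN)   -> true   (floating path: NaNs compare equal here)
// NaNEqual{}(NAN, 1.0f)  -> false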


@@ -19,6 +19,16 @@
namespace mlx::core {
void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
void Abs::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -258,6 +268,14 @@ void Expm1::eval(const std::vector<array>& inputs, array& out) {
}
}
void Flatten::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Unflatten::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Floor::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -417,18 +435,8 @@ void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Round::eval(const std::vector<array>& inputs, array& out) {
@@ -498,34 +506,17 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] =
prepare_slice(in, start_indices_, strides_);
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
copy_inplace<int64_t>(
/* const array& src = */ in,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
/* const std::vector<stride_t>& i_strides = */ inp_strides,
/* const std::vector<stride_t>& o_strides = */ ostrides,
/* int64_t i_offset = */ data_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General);
} else {
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
size_t data_size = data_end - data_offset;
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
size_t data_size = data_end - data_offset;
Strides ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
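
Because negative strides no longer force a copy, prepare_slice loses its copy_needed flag and Slice::eval always builds a shared-buffer view. Worked through for a reversed 1-D slice:

// in: shape {4}, strides {1}; slice with start_indices_ = {3}, strides_ = {-1}
//   data_offset = 3 * 1    = 3
//   inp_strides = {1 * -1} = {-1}
// The output aliases the input buffer and reads elements 3, 2, 1, 0.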
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
@@ -550,11 +541,11 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(out);
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
// Do copy
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
copy_inplace<int64_t>(
Strides upd_strides{upd.strides().begin(), upd.strides().end()};
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),


@@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r) {
// Copy the input to be column contiguous
flags.col_contiguous = num_matrices == 1;
flags.row_contiguous = false;
std::vector<size_t> strides = in.strides();
auto strides = in.strides();
strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M;
in.set_data(


@@ -174,19 +174,19 @@ void reduce_dispatch_min_max(
void nd_loop(
std::function<void(int)> callback,
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
const Shape& shape,
const Strides& strides) {
std::function<void(int, int)> loop_inner;
loop_inner = [&](int dim, int offset) {
if (dim < shape.size() - 1) {
int size = shape[dim];
size_t stride = strides[dim];
auto size = shape[dim];
auto stride = strides[dim];
for (int i = 0; i < size; i++) {
loop_inner(dim + 1, offset + i * stride);
}
} else {
int size = shape[dim];
size_t stride = strides[dim];
auto size = shape[dim];
auto stride = strides[dim];
for (int i = 0; i < size; i++) {
callback(offset + i * stride);
}


@@ -38,13 +38,10 @@ enum ReductionOpType {
struct ReductionPlan {
ReductionOpType type;
std::vector<int> shape;
std::vector<size_t> strides;
Shape shape;
Strides strides;
ReductionPlan(
ReductionOpType type_,
std::vector<int> shape_,
std::vector<size_t> strides_)
ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
@@ -55,10 +52,10 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
// Should this be in utils?
void nd_loop(
std::function<void(int)> callback,
const std::vector<int>& shape,
const std::vector<size_t>& strides);
const Shape& shape,
const Strides& strides);
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes);
@@ -113,9 +110,6 @@ void reduction_op(
return;
}
std::vector<int> shape;
std::vector<size_t> strides;
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape[0];
const T* x_ptr = x.data<T>();
@@ -135,7 +129,7 @@ void reduction_op(
U* out_ptr = out.data<U>();
// Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should yield an extra performance boost.
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
@@ -181,7 +175,7 @@ void reduction_op(
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
@@ -211,7 +205,7 @@ void reduction_op(
if (plan.type == GeneralReduce) {
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
U val = init;


@@ -4,11 +4,11 @@
namespace mlx::core {
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
std::vector<int> shape = x.shape();
std::vector<size_t> strides = x.strides();
auto shape = x.shape();
auto strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
@@ -29,8 +29,8 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// Row contiguous input so the output is row contiguous
if (x.flags().row_contiguous) {
// Merge consecutive axes
std::vector<int> shape = {x.shape(axes[0])};
std::vector<size_t> strides = {x.strides()[axes[0]]};
Shape shape = {x.shape(axes[0])};
Strides strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
shape.back() *= x.shape(axes[i]);
@@ -69,7 +69,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// Sort reduction axes by stride in order to merge them and figure out if we
// have a contiguous reduction.
std::vector<std::pair<int, size_t>> reductions;
std::vector<std::pair<int, int64_t>> reductions;
for (auto a : axes) {
if (x.shape(a) > 1) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
@@ -93,8 +93,8 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
}
}
std::vector<int> shape;
std::vector<size_t> strides;
Shape shape;
Strides strides;
for (auto r : reductions) {
shape.push_back(r.first);
strides.push_back(r.second);
@@ -109,15 +109,15 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// Delegate to the general strided reduction op if the axes after
// strides.back() are contiguous.
if (strides.back() > 1) {
int size = 1;
int64_t size = 1;
bool have_expand = false;
for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) {
continue;
}
size_t stride_i = x.strides()[i];
int shape_i = x.shape(i);
auto stride_i = x.strides()[i];
auto shape_i = x.shape(i);
if (stride_i == 0) {
if (shape_i == 1) {
continue;

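In the row-contiguous branch above, consecutive reduction axes fuse into one by multiplying into the trailing shape entry. For example (a sketch of the bookkeeping, not a full trace):

// x: shape {2, 3, 4}, row contiguous (strides {12, 4, 1}), axes = {1, 2}
//   start:          shape = {3},  strides = {4}
//   merge axis 2:   shape = {12}, strides = {1}   (axes[1] - 1 == axes[0])
// => one contiguous reduction of 12 elements per output row.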

@@ -4,24 +4,22 @@
namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
std::tuple<int64_t, Strides> prepare_slice(
const array& in,
const std::vector<int>& start_indices,
const std::vector<int>& strides) {
const Shape& start_indices,
const Shape& strides) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
Strides inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i];
copy_needed |= strides[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
return std::make_tuple(data_offset, inp_strides);
}
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
const Strides& out_strides,
size_t data_offset,
size_t data_size,
array& out) {


@@ -6,14 +6,14 @@
namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
std::tuple<int64_t, Strides> prepare_slice(
const array& in,
const std::vector<int>& start_indices,
const std::vector<int>& strides);
const Shape& start_indices,
const Shape& strides);
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
const Strides& out_strides,
size_t data_offset,
size_t data_size,
array& out);


@@ -25,7 +25,7 @@ struct StridedIterator {
// Constructors
StridedIterator() = default;
explicit StridedIterator(T* ptr, size_t stride, difference_type offset = 0)
explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
: ptr_(ptr + offset * stride), stride_(stride) {}
explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
@@ -99,7 +99,7 @@ struct StridedIterator {
}
private:
size_t stride_;
int64_t stride_;
T* ptr_;
};
@@ -120,11 +120,11 @@ void sort(const array& in, array& out, int axis) {
auto remaining_strides = out.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = out.strides()[axis];
int axis_size = out.shape(axis);
auto axis_stride = out.strides()[axis];
auto axis_size = out.shape(axis);
// Perform sorting in place
ContiguousIterator<size_t> src_it(
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc;
@@ -158,14 +158,14 @@ void argsort(const array& in, array& out, int axis) {
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
int axis_size = in.shape(axis);
auto in_stride = in.strides()[axis];
auto out_stride = out.strides()[axis];
auto axis_size = in.shape(axis);
// Perform sorting
ContiguousIterator<size_t> in_it(
ContiguousIterator in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator<size_t> out_it(
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc;
@@ -208,13 +208,13 @@ void partition(const array& in, array& out, int axis, int kth) {
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis];
auto axis_stride = in.strides()[axis];
int axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition in place
ContiguousIterator<size_t> src_it(
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc;
@@ -249,16 +249,16 @@ void argpartition(const array& in, array& out, int axis, int kth) {
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
int axis_size = in.shape(axis);
auto in_stride = in.strides()[axis];
auto out_stride = out.strides()[axis];
auto axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition
ContiguousIterator<size_t> in_it(
ContiguousIterator in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator<size_t> out_it(
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc;


@@ -78,11 +78,11 @@ void ternary_op_dims(
const T3* c,
U* out,
Op op,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& c_strides,
const std::vector<size_t>& out_strides,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& c_strides,
const Strides& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
@@ -164,10 +164,10 @@ void ternary_op_dispatch_dims(
return;
}
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
ContiguousIterator<size_t> c_it(shape, c_strides, ndim - 2);
size_t stride = out_strides[ndim - 3];
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
ContiguousIterator c_it(shape, c_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr + a_it.loc,


@@ -15,7 +15,7 @@ void move_or_copy(const array& in, array& out) {
void move_or_copy(
const array& in,
array& out,
const std::vector<size_t>& strides,
const Strides& strides,
array::Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
@@ -26,15 +26,13 @@ void move_or_copy(
}
}
template <typename StrideT>
std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>>
collapse_contiguous_dims_impl(
const std::vector<int>& shape,
const std::vector<std::vector<StrideT>>& strides,
StrideT size_cap) {
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
const Shape& shape,
const std::vector<Strides>& strides,
int64_t size_cap) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
Shape to_collapse;
if (shape.size() > 0) {
if (shape[0] != 1) {
to_collapse.push_back(0);
@@ -43,7 +41,7 @@ collapse_contiguous_dims_impl(
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
size *= shape[i];
for (const std::vector<StrideT>& st : strides) {
for (const auto& st : strides) {
if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
contiguous = false;
size = shape[i];
@@ -60,8 +58,8 @@ collapse_contiguous_dims_impl(
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<StrideT>> out_strides(strides.size());
Shape out_shape;
std::vector<Strides> out_strides(strides.size());
for (int i = 0;;) {
while (i < to_collapse.size() && to_collapse[i] == -1) {
++i;
@@ -76,7 +74,7 @@ collapse_contiguous_dims_impl(
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<StrideT>& st = strides[j];
const auto& st = strides[j];
out_strides[j].push_back(st[to_collapse[k - 1]]);
}
i = k + 1;
@@ -91,29 +89,12 @@ collapse_contiguous_dims_impl(
return std::make_tuple(out_shape, out_strides);
}
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<int64_t>>& strides,
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl(shape, strides, size_cap);
}
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>>& strides,
size_t size_cap /* = std::numeric_limits<int32>::max() */) {
return collapse_contiguous_dims_impl(shape, strides, size_cap);
}
template <typename StrideT>
std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
const std::vector<int>& shape,
const std::vector<StrideT>& strides,
StrideT size_cap) {
std::vector<int> collapsed_shape;
std::vector<StrideT> collapsed_strides;
std::pair<Shape, Strides> collapse_contiguous_dims(
const Shape& shape,
const Strides& strides,
int64_t size_cap) {
Shape collapsed_shape;
Strides collapsed_strides;
if (shape.size() > 0) {
collapsed_shape.push_back(shape[0]);
@@ -123,7 +104,7 @@ std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
continue;
} else if (
strides[i] * shape[i] != collapsed_strides.back() ||
collapsed_shape.back() * static_cast<StrideT>(shape[i]) > size_cap) {
collapsed_shape.back() * static_cast<int64_t>(shape[i]) > size_cap) {
collapsed_shape.push_back(shape[i]);
collapsed_strides.push_back(strides[i]);
} else {
@@ -136,25 +117,10 @@ std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
return std::make_pair(collapsed_shape, collapsed_strides);
}
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<int64_t>& strides,
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl<int64_t>(shape, strides, size_cap);
}
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl<size_t>(shape, strides, size_cap);
}
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
std::pair<Shape, Strides> collapse_contiguous_dims(
const array& a,
size_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
return collapse_contiguous_dims_impl<size_t>(
a.shape(), a.strides(), size_cap);
int64_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
}
} // namespace mlx::core
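
collapse_contiguous_dims fuses any axis whose stride times its extent equals the previous axis's stride, subject to size_cap. Two quick cases for the single-array overload:

// fully contiguous:  shape {2, 3, 4}, strides {12, 4, 1}
//   -> shape {24},    strides {1}        (one flat axis)
// sliced on axis 1:  shape {2, 3, 4}, strides {24, 4, 1}
//   -> shape {2, 12}, strides {24, 1}    (axes 1 and 2 still fuse)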


@@ -8,12 +8,9 @@
namespace mlx::core {
template <typename StrideT>
inline StrideT elem_to_loc(
int elem,
const std::vector<int>& shape,
const std::vector<StrideT>& strides) {
StrideT loc = 0;
inline int64_t
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
int64_t loc = 0;
for (int i = shape.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(elem, shape[i]);
loc += q_and_r.rem * strides[i];
@@ -22,16 +19,15 @@ inline StrideT elem_to_loc(
return loc;
}
inline size_t elem_to_loc(int elem, const array& a) {
inline int64_t elem_to_loc(int elem, const array& a) {
if (a.flags().row_contiguous) {
return elem;
}
return elem_to_loc(elem, a.shape(), a.strides());
}
template <typename StrideT>
std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) {
std::vector<StrideT> strides(shape.size(), 1);
inline Strides make_contiguous_strides(const Shape& shape) {
Strides strides(shape.size(), 1);
for (int i = shape.size() - 1; i > 0; i--) {
strides[i - 1] = strides[i] * shape[i];
}
@@ -44,22 +40,15 @@ std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) {
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<int64_t>>& strides,
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
const Shape& shape,
const std::vector<Strides>& strides,
int64_t size_cap = std::numeric_limits<int32_t>::max());
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>>& strides,
size_t size_cap = std::numeric_limits<int32_t>::max());
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
inline std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
const std::vector<array>& xs,
size_t size_cap = std::numeric_limits<int32_t>::max()) {
std::vector<std::vector<size_t>> strides;
std::vector<Strides> strides;
for (auto& x : xs) {
strides.emplace_back(x.strides());
}
@@ -73,19 +62,14 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
}
// The single array version of the above.
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<int64_t>& strides,
std::pair<Shape, Strides> collapse_contiguous_dims(
const Shape& shape,
const Strides& strides,
int64_t size_cap = std::numeric_limits<int32_t>::max());
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t size_cap = std::numeric_limits<int32_t>::max());
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
std::pair<Shape, Strides> collapse_contiguous_dims(
const array& a,
size_t size_cap = std::numeric_limits<int32_t>::max());
int64_t size_cap = std::numeric_limits<int32_t>::max());
template <typename StrideT>
struct ContiguousIterator {
inline void step() {
int dims = shape_.size();
@@ -102,7 +86,7 @@ struct ContiguousIterator {
loc += strides_[i];
}
void seek(StrideT n) {
void seek(int64_t n) {
loc = 0;
for (int i = shape_.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(n, shape_[i]);
@@ -128,32 +112,29 @@ struct ContiguousIterator {
}
explicit ContiguousIterator(
const std::vector<int>& shape,
const std::vector<StrideT>& strides,
const Shape& shape,
const Strides& strides,
int dims)
: shape_(shape.begin(), shape.begin() + dims),
strides_(strides.begin(), strides.begin() + dims) {
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0);
pos_ = Shape(shape_.size(), 0);
}
}
StrideT loc{0};
int64_t loc{0};
private:
std::vector<int> shape_;
std::vector<StrideT> strides_;
std::vector<int> pos_;
Shape shape_;
Strides strides_;
Shape pos_;
};
template <typename StrideT>
inline auto check_contiguity(
const std::vector<int>& shape,
const std::vector<StrideT>& strides) {
inline auto check_contiguity(const Shape& shape, const Strides& strides) {
size_t no_broadcast_data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
int64_t f_stride = 1;
int64_t b_stride = 1;
bool is_row_contiguous = true;
bool is_col_contiguous = true;
@@ -182,9 +163,15 @@ void move_or_copy(const array& in, array& out);
void move_or_copy(
const array& in,
array& out,
const std::vector<size_t>& strides,
const Strides& strides,
array::Flags flags,
size_t data_size,
size_t offset = 0);
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
void shared_buffer_reshape(
const array& in,
const Strides& out_strides,
array& out);
} // namespace mlx::core
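For the header above, a small self-contained sketch of the two helpers everything else leans on: `elem_to_loc` peels coordinates off a flat element index right-to-left with `ldiv`, and `make_contiguous_strides` is its row-major counterpart. Types are local stand-ins for the aliases, not the real headers.

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <vector>

using Shape = std::vector<int>;
using Strides = std::vector<int64_t>;

// Mirrors the header's elem_to_loc: walk axes right-to-left, taking
// elem % shape[i] as the coordinate along axis i.
int64_t elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    auto q_and_r = std::ldiv(elem, shape[i]);
    loc += q_and_r.rem * strides[i];
    elem = q_and_r.quot;
  }
  return loc;
}

// Mirrors make_contiguous_strides: strides[i - 1] = strides[i] * shape[i].
Strides make_contiguous_strides(const Shape& shape) {
  Strides strides(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 1; i > 0; i--) {
    strides[i - 1] = strides[i] * shape[i];
  }
  return strides;
}

int main() {
  Shape shape{2, 3};
  auto row = make_contiguous_strides(shape); // {3, 1}
  Strides col{1, 2};                         // column-major layout

  // Element 4 has coordinates (1, 1): offset 4 row-major, 3 column-major.
  std::cout << elem_to_loc(4, shape, row) << "\n"; // 4
  std::cout << elem_to_loc(4, shape, col) << "\n"; // 3
  return 0;
}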

View File

@@ -75,19 +75,21 @@ void binary_op_gpu_inplace(
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
return std::make_tuple(shape, strides[0], strides[1], strides[2]);
} else {
std::vector<size_t> e;
return std::make_tuple(std::vector<int>{}, e, e, e);
decltype(a.strides()) e{};
return std::make_tuple(decltype(a.shape()){}, e, e, e);
}
};
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
bool large = out.data_size() > UINT32_MAX;
bool large;
auto ndim = shape.size();
int work_per_thread;
if (bopt == BinaryOpType::General) {
large |= (a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX);
large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
out.size() > INT32_MAX;
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > UINT32_MAX;
work_per_thread = 1;
}
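A hedged reading of the dispatch just above: general (strided) kernels now index with signed types, so they fall back to 64-bit indexing past INT32_MAX and amortize the cost with more work per thread, while the contiguous path keeps the old UINT32_MAX threshold. The sketch below restates that rule with illustrative names and plain sizes, not the array API.

#include <cstdint>
#include <iostream>

struct Dispatch {
  bool large;
  int work_per_thread;
};

// Illustrative: a, b, out are element counts of the operands and output.
Dispatch plan_binary(bool general, int64_t a, int64_t b, int64_t out) {
  if (general) {
    bool large = a > INT32_MAX || b > INT32_MAX || out > INT32_MAX;
    return {large, large ? 4 : 2};
  }
  return {out > static_cast<int64_t>(UINT32_MAX), 1};
}

int main() {
  // One oversized strided operand forces the 64-bit, 4-per-thread variant.
  auto d = plan_binary(true, 5'000'000'000LL, 100, 5'000'000'000LL);
  std::cout << d.large << " " << d.work_per_thread << "\n"; // 1 4
  return 0;
}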
std::string kernel_name =

View File

@@ -67,7 +67,7 @@ inline void build_kernel(
if (add_indices) {
os += fmt::format(
" constant const size_t* in_strides [[buffer({0})]],\n", cnt++);
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
}
// Add the output arguments
@@ -81,7 +81,7 @@ inline void build_kernel(
// Add output strides and shape to extract the indices.
if (!contiguous) {
os += fmt::format(
" constant const size_t* output_strides [[buffer({0})]],\n", cnt++);
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
os += fmt::format(
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
}
@@ -93,11 +93,11 @@ inline void build_kernel(
os += " uint3 pos [[thread_position_in_grid]],\n";
os += " uint3 grid [[threads_per_grid]]) {\n";
std::string idx_type = use_big_index ? "size_t" : "uint";
std::string idx_type = use_big_index ? "int64_t" : "uint";
if (contiguous && use_big_index) {
// This is only used for contiguous kernels which don't have
// a third grid dimension
os += " size_t index = pos.x + grid.x * size_t(pos.y);\n";
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
} else if (work_per_thread > 1) {
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
os += fmt::format(
@@ -144,20 +144,18 @@ inline void build_kernel(
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
if (ndim == 1) {
int offset = i * ndim;
os += fmt::format(
"elem_to_loc_1<size_t, uint>(pos.x, in_strides[{0}]);\n", offset);
os +=
fmt::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
} else if (ndim == 2) {
int offset = i * ndim;
os += fmt::format(
"elem_to_loc_2<size_t, {0}>({{pos.x, pos.y}}, in_strides + {1});\n",
"elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n",
idx_type,
offset);
} else if (ndim == 3) {
int offset = i * ndim;
os += fmt::format(
"elem_to_loc_3<size_t, {0}>(pos, in_strides + {1});\n",
idx_type,
offset);
"elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset);
} else if (!dynamic_dims) {
int offset = (i + 1) * ndim;
os += fmt::format(
@@ -360,10 +358,10 @@ void Compiled::eval_gpu(
// Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
std::vector<std::vector<size_t>> initial_strides;
std::vector<Strides> initial_strides;
initial_strides.push_back(outputs[0].strides());
std::vector<int> shape;
std::vector<std::vector<size_t>> strides;
Shape shape;
std::vector<Strides> strides;
if (!contiguous) {
for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
@@ -378,7 +376,7 @@ void Compiled::eval_gpu(
}
// Broadcast the inputs to the output shape.
std::vector<size_t> xstrides;
Strides xstrides;
int j = 0;
for (; j < output_shape.size() - x.ndim(); j++) {
if (output_shape[j] == 1) {
@@ -440,7 +438,7 @@ void Compiled::eval_gpu(
// Put the inputs in
int cnt = 0;
int stride_idx = 1; // idx 0 is the output strides
std::vector<size_t> in_strides;
Strides in_strides;
for (int i = 0; i < inputs.size(); i++) {
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;

View File

@@ -64,8 +64,8 @@ void explicit_gemm_conv_ND_gpu(
compute_encoder.dispatch_threads(grid_dims, group_dims);
// Reshape weight
std::vector<int> wt_reshape{implicit_K, implicit_N};
std::vector<size_t> wt_restride{1, static_cast<size_t>(implicit_K)};
Shape wt_reshape{implicit_K, implicit_N};
Strides wt_restride{1, implicit_K};
array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
auto wt_flags = wt.flags();
wt_flags.row_contiguous = false;
@@ -147,10 +147,7 @@ void explicit_gemm_conv_group_ND_gpu(
array wt_view(
{wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
wt_view.copy_shared_buffer(
wt,
{wt.strides(0), 1, static_cast<size_t>(C_per_group)},
wt.flags(),
wt.size());
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
// Materialize
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
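Note on the view above: the group-conv path expresses the weight transpose purely through strides — `{wt.strides(0), 1, C_per_group}` reads the trailing two axes of the buffer swapped, with no copy until the materialize step. A minimal sketch of that trick, assuming a row-contiguous {O, K, C} weight buffer (a `view_at` helper invented here for illustration):

#include <cstdint>
#include <iostream>
#include <vector>

// Read element (o, c, k) of a {O, C, K} view over a row-contiguous
// {O, K, C} buffer, via explicit strides. Hypothetical helper.
float view_at(const std::vector<float>& buf,
              int o, int c, int k,
              int64_t s0, int64_t s1, int64_t s2) {
  return buf[o * s0 + c * s1 + k * s2];
}

int main() {
  const int O = 1, K = 2, C = 3; // {out, kernel positions, channels-per-group}
  std::vector<float> wt(O * K * C);
  for (int i = 0; i < static_cast<int>(wt.size()); ++i) {
    wt[i] = static_cast<float>(i); // 0..5 laid out as {O, K, C}
  }

  // View strides {K * C, 1, C}: axis order becomes {O, C, K} without a copy,
  // mirroring the {wt.strides(0), 1, C_per_group} view in the diff.
  for (int c = 0; c < C; ++c)
    for (int k = 0; k < K; ++k)
      std::cout << view_at(wt, 0, c, k, K * C, 1, C) << " ";
  std::cout << "\n"; // prints 0 3 1 4 2 5: the {K, C} block read out transposed
  return 0;
}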

View File

@@ -43,13 +43,12 @@ void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}
template <typename stride_t>
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int>& data_shape,
const std::vector<stride_t>& strides_in_pre,
const std::vector<stride_t>& strides_out_pre,
const Shape& data_shape,
const Strides& strides_in_pre,
const Strides& strides_out_pre,
int64_t inp_offset,
int64_t out_offset,
CopyType ctype,
@@ -68,8 +67,8 @@ void copy_gpu_inplace(
/* size_cap = */ INT32_MAX);
return std::make_tuple(shape, strides[0], strides[1]);
} else {
std::vector<stride_t> e;
return std::make_tuple(std::vector<int>{}, e, e);
Strides e{};
return std::make_tuple(Shape{}, e, e);
}
};
auto [shape, strides_in_, strides_out_] = maybe_collapse();
@@ -124,8 +123,8 @@ void copy_gpu_inplace(
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
Strides strides_in{strides_in_.begin(), strides_in_.end()};
Strides strides_out{strides_out_.begin(), strides_out_.end()};
if (ndim > 3) {
compute_encoder.set_vector_bytes(shape, ndim, 2);
}
@@ -180,14 +179,13 @@ void copy_gpu_inplace(
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int64_t>& istride,
const Strides& istride,
int64_t ioffset,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
return copy_gpu_inplace(
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
in, out, in.shape(), istride, out.strides(), ioffset, 0, ctype, s);
}
void fill_gpu(const array& val, array& out, const Stream& s) {

View File

@@ -8,13 +8,12 @@
namespace mlx::core {
// Generic copy inplace
template <typename stride_t>
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype,
@@ -32,7 +31,7 @@ void copy_gpu_inplace(
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int64_t>& istride,
const Strides& istride,
int64_t ioffset,
CopyType ctype,
const Stream& s);

View File

@@ -363,7 +363,7 @@ void multi_upload_bluestein_fft(
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
// Broadcast w_q and w_k to the batch size
std::vector<size_t> b_strides(in.ndim(), 0);
Strides b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast({}, complex64, nullptr, {});
array w_q_broadcast({}, complex64, nullptr, {});
@@ -386,8 +386,8 @@ void multi_upload_bluestein_fft(
copies.push_back(slice_temp);
copies.push_back(conj_temp);
std::vector<int> rstarts(in.ndim(), 0);
std::vector<int> rstrides(in.ndim(), 1);
Shape rstarts(in.ndim(), 0);
Shape rstrides(in.ndim(), 1);
rstarts[axis] = in.shape(axis) - back_offset;
rstrides[axis] = -1;
unary_op_gpu({in}, conj_temp, "Conjugate", s);
@@ -431,19 +431,19 @@ void multi_upload_bluestein_fft(
s);
int offset = plan.bluestein_n - (2 * n - 1);
std::vector<int> starts(in.ndim(), 0);
std::vector<int> strides(in.ndim(), 1);
Shape starts(in.ndim(), 0);
Shape strides(in.ndim(), 1);
starts[axis] = plan.bluestein_n - offset - n;
slice_gpu(pad_temp1, temp, starts, strides, s);
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
if (real && !inverse) {
std::vector<int> rstarts(in.ndim(), 0);
std::vector<int> rstrides(in.ndim(), 1);
Shape rstarts(in.ndim(), 0);
Shape rstrides(in.ndim(), 1);
slice_gpu(temp1, out, rstarts, strides, s);
} else if (real && inverse) {
std::vector<size_t> b_strides(in.ndim(), 0);
Strides b_strides(in.ndim(), 0);
auto inv_n = array({1.0f / n}, {1}, float32);
array temp_float(out.shape(), out.dtype(), nullptr, {});
copies.push_back(temp_float);
@@ -531,8 +531,8 @@ void fft_op(
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
std::vector<size_t> strides;
size_t cur_stride = x.shape(axis);
Strides strides;
int64_t cur_stride = x.shape(axis);
for (int a = 0; a < x.ndim(); a++) {
if (a == axis) {
strides.push_back(1);
@@ -777,7 +777,7 @@ void nd_fft_op(
// Mirror np.fft.(i)rfftn and perform a real transform
// only on the final axis.
bool step_real = (real && index == axes.size() - 1);
int step_shape = inverse ? out.shape(axis) : in.shape(axis);
auto step_shape = inverse ? out.shape(axis) : in.shape(axis);
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);

View File

@@ -53,9 +53,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
bool large_index = nidx && inputs[1].size() > UINT32_MAX;
bool large_src = src.size() > UINT32_MAX;
bool large_out = out.size() > UINT32_MAX;
bool large_index = nidx && inputs[1].size() > INT32_MAX;
bool large_src = src.size() > INT32_MAX;
bool large_out = out.size() > INT32_MAX;
bool large = large_index || large_src || large_out;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
@@ -65,7 +65,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_type_name,
nidx,
idx_ndim,
large ? "size_t" : "uint");
large ? "int64_t" : "int");
std::string lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
@@ -86,7 +86,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_args,
idx_arr,
idx_ndim,
large ? "size_t" : "uint");
large ? "int64_t" : "int");
return kernel_source;
});
@@ -234,9 +234,9 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
break;
}
auto upd_contig = upd.flags().row_contiguous;
bool large_out = out.size() > UINT32_MAX;
bool large_idx = nidx && (inputs[1].size() > UINT32_MAX);
bool large_upd = upd.size() > UINT32_MAX;
bool large_out = out.size() > INT32_MAX;
bool large_idx = nidx && (inputs[1].size() > INT32_MAX);
bool large_upd = upd.size() > INT32_MAX;
bool large = large_out || large_idx || large_upd;
std::string kernel_name = fmt::format(
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
@@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
nidx,
upd_contig ? "updc_true" : "updc_false",
nwork,
large ? "size_t" : "uint");
large ? "int64_t" : "int");
std::string lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
@@ -290,7 +290,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_arr,
upd_contig,
nwork,
large ? "size_t" : "uint");
large ? "int64_t" : "int");
return kernel_source;
});
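The pattern behind the `large ? "int64_t" : "int"` switches above: gather/scatter now specialize each kernel on a 32-bit signed index for the common case and compile an int64_t variant only when any participating buffer could overflow INT32_MAX. A hedged sketch of that selection rule (names illustrative, not the MLX API):

#include <cstdint>
#include <iostream>
#include <string>

// Pick the index type suffix baked into the kernel name: cheap 32-bit
// indexing unless any operand exceeds the signed 32-bit range.
std::string pick_index_type(int64_t out_size, int64_t idx_size, int64_t upd_size) {
  const int64_t cap = INT32_MAX;
  bool large = out_size > cap || idx_size > cap || upd_size > cap;
  return large ? "int64_t" : "int";
}

int main() {
  // Small tensors keep the cheaper 32-bit variant of the kernel.
  std::cout << pick_index_type(1 << 20, 1 << 10, 1 << 20) << "\n"; // int
  // One oversized operand forces 64-bit indexing for the whole launch.
  std::cout << pick_index_type(3'000'000'000LL, 1 << 10, 1 << 20) << "\n"; // int64_t
  return 0;
}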
@@ -312,8 +312,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
upd_size *= upd.shape(i);
}
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
Shape idx_shapes;
Strides idx_strides;
// To access .data() use char instead of bool
// bool is 1 byte in Metal so this is safe
std::vector<char> idx_contigs;
@@ -332,7 +332,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
if (upd_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
int64_t stride_ = 0;
compute_encoder.set_bytes(shape_, 3);
compute_encoder.set_bytes(stride_, 4);
} else {
@@ -347,7 +347,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
int64_t stride_ = 0;
compute_encoder.set_bytes(shape_, 7);
compute_encoder.set_bytes(stride_, 8);
} else {

View File

@@ -11,13 +11,13 @@ gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const device {outm_t}* out_mask [[buffer(20)]],
const device {opm_t}* mat_mask [[buffer(21)]],
const device {opm_t}* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
const constant int64_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],

View File

@@ -5,12 +5,12 @@ constexpr std::string_view gather_kernels = R"(
const device {1}* src [[buffer(0)]],
device {1}* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant int64_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const constant int* idx_shapes [[buffer(7)]],
const constant size_t* idx_strides [[buffer(8)]],
const constant int64_t* idx_strides [[buffer(8)]],
const constant bool* idx_contigs [[buffer(9)]],
const constant int& idx_ndim [[buffer(10)]],
{4}
@@ -38,15 +38,15 @@ constexpr std::string_view scatter_kernels = R"(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
const constant size_t* upd_strides [[buffer(4)]],
const constant int64_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int* out_shape [[buffer(7)]],
const constant size_t* out_strides [[buffer(8)]],
const constant int64_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const constant int* idx_shapes [[buffer(11)]],
const constant size_t* idx_strides [[buffer(12)]],
const constant int64_t* idx_strides [[buffer(12)]],
const constant bool* idx_contigs [[buffer(13)]],
const constant int& idx_ndim [[buffer(14)]],
const constant size_t& idx_size [[buffer(15)]],

View File

@@ -10,12 +10,12 @@ template [[host_name("{name}")]]
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant int64_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
@@ -43,7 +43,7 @@ block_masked_gemm<
device {itype}* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant int64_t* batch_strides [[buffer(7)]],
const device {outmasktype}* out_mask [[buffer(10)]],
const device {opmasktype}* lhs_mask [[buffer(11)]],
const device {opmasktype}* rhs_mask [[buffer(12)]],

View File

@@ -52,7 +52,7 @@ MTL::ComputePipelineState* get_unary_kernel(
kernel_source +=
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source += get_template_definition(
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint");
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "int");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
return kernel_source;
@@ -74,7 +74,7 @@ void append_binary_kernels(
{"vs2", "binary_vs2"},
{"sv2", "binary_sv2"},
{"vv2", "binary_vv2"},
{"g1", "binary_g_nd1"},
{"g1large", "binary_g_nd1"},
{"g2large", "binary_g_nd2"},
{"g3large", "binary_g_nd3"},
}};
@@ -86,11 +86,13 @@ void append_binary_kernels(
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
}
kernel_source += get_template_definition(
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint");
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
kernel_source += get_template_definition(
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint");
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint");
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "int");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
}
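Naming convention visible in this hunk, restated as a hedged sketch: the fixed-rank general kernels are now instantiated as g{1,2,3} with 32-bit "int" indices and as g{1,2,3}large with the 64-bit default, while the N-dimensional fallback splits into gn2 (small) and gn4large (large, matching the 4-per-thread dispatch earlier). The mapping below is illustrative only:

#include <iostream>
#include <string>

// Compose the general-kernel name suffix from rank and the "large" flag.
std::string general_kernel_name(int ndim, bool large, int work_per_thread) {
  if (ndim <= 3) {
    return "g" + std::to_string(ndim) + (large ? "large" : "");
  }
  // N-dimensional fallback: gn2 for small inputs, gn4large for large ones.
  return "gn" + std::to_string(work_per_thread) + (large ? "large" : "");
}

int main() {
  std::cout << general_kernel_name(2, false, 2) << "\n"; // g2
  std::cout << general_kernel_name(2, true, 4) << "\n";  // g2large
  std::cout << general_kernel_name(5, false, 2) << "\n"; // gn2
  std::cout << general_kernel_name(5, true, 4) << "\n";  // gn4large
  return 0;
}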
@@ -141,7 +143,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
{"v", "ternary_v"},
{"v2", "ternary_v2"},
{"g1", "ternary_g_nd1"},
{"g1large", "ternary_g_nd1"},
{"g2large", "ternary_g_nd2"},
{"g3large", "ternary_g_nd3"},
}};
@@ -150,11 +152,13 @@ MTL::ComputePipelineState* get_ternary_kernel(
get_template_definition(name + "_" + lib_name, func, t_str, op);
}
kernel_source += get_template_definition(
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint");
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
kernel_source += get_template_definition(
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint");
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint");
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "int");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "ternary_g", t_str, op, 4);
return kernel_source;
@@ -178,7 +182,7 @@ MTL::ComputePipelineState* get_copy_kernel(
kernel_source +=
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
kernel_source += get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type);
"g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
kernel_source += get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int");
kernel_source += get_template_definition(
@@ -186,19 +190,23 @@ MTL::ComputePipelineState* get_copy_kernel(
kernel_source += get_template_definition(
"gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int");
kernel_source += get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type);
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type, "int");
kernel_source += get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int");
kernel_source += get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int");
kernel_source += get_template_definition(
"ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int");
kernel_source += get_template_definition(
"g1large_" + lib_name, "copy_g_nd1", in_type, out_type);
kernel_source += get_template_definition(
"g2large_" + lib_name, "copy_g_nd2", in_type, out_type);
kernel_source += get_template_definition(
"g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
kernel_source += get_template_definition(
"gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
kernel_source += get_template_definition(
"gg1large_" + lib_name, "copy_gg_nd1", in_type, out_type);
kernel_source += get_template_definition(
"gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
kernel_source += get_template_definition(

View File

@@ -75,10 +75,10 @@ template <typename T, typename Op, int N_READS = 4>
const device T* in [[buffer(0)]],
device uint32_t* out [[buffer(1)]],
const constant int* shape [[buffer(2)]],
const constant size_t* in_strides [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant int64_t* in_strides [[buffer(3)]],
const constant int64_t* out_strides [[buffer(4)]],
const constant size_t& ndim [[buffer(5)]],
const constant size_t& axis_stride [[buffer(6)]],
const constant int64_t& axis_stride [[buffer(6)]],
const constant size_t& axis_size [[buffer(7)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],

View File

@@ -43,7 +43,7 @@ template <typename T, typename U, typename Op>
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[0], b[offset]);
}
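Why the int64_t cast in these "2" variants matters: `index.x + grid_dim.x * index.y` multiplies two 32-bit grid coordinates, which wraps modulo 2^32 once the flat offset passes UINT32_MAX; widening one operand first makes the whole expression 64-bit. A tiny host-side demonstration of the wraparound (grid sizes chosen only to force it):

#include <cstdint>
#include <iostream>

int main() {
  // Stand-ins for thread_position_in_grid / threads_per_grid in the kernels.
  uint32_t index_x = 0, index_y = 70'000, grid_dim_x = 70'000;

  // Without widening, the product wraps modulo 2^32 (4.9e9 > ~4.29e9).
  uint32_t wrapped = index_x + grid_dim_x * index_y;
  // The diff's form: widen one operand, so the multiply is done in 64 bits.
  int64_t correct = index_x + grid_dim_x * int64_t(index_y);

  std::cout << wrapped << " vs " << correct << "\n";
  // Prints 605032704 vs 4900000000: only the widened form addresses the
  // full buffer, which is why the "2" kernel variants exist for large arrays.
  return 0;
}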
@@ -54,7 +54,7 @@ template <typename T, typename U, typename Op>
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[offset], b[0]);
}
@@ -65,49 +65,49 @@ template <typename T, typename U, typename Op>
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[offset], b[offset]);
}
template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
device U* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
constant const int64_t& a_stride,
constant const int64_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_stride);
auto a_idx = elem_to_loc_1<IdxT>(index, a_stride);
auto b_idx = elem_to_loc_1<IdxT>(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
constant const int64_t a_strides[2],
constant const int64_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
constant const int64_t a_strides[3],
constant const int64_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
@@ -117,18 +117,18 @@ template <
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
typename IdxT = int64_t>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
device U* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
auto idx = elem_to_loc_2_nd<IdxT>(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);

View File

@@ -9,21 +9,21 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_integer(op) \

View File

@@ -56,7 +56,7 @@ template <typename T, typename U, typename Op>
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[0], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
@@ -70,7 +70,7 @@ template <typename T, typename U, typename Op>
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[offset], b[0]);
c[offset] = out[0];
d[offset] = out[1];
@@ -84,58 +84,58 @@ template <typename T, typename U, typename Op>
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[offset], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t& a_stride,
constant const size_t& b_stride,
constant const int64_t& a_stride,
constant const int64_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_stride);
auto a_idx = elem_to_loc_1<IdxT>(index, a_stride);
auto b_idx = elem_to_loc_1<IdxT>(index, b_stride);
auto out = Op()(a[a_idx], b[b_idx]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
constant const int64_t a_strides[2],
constant const int64_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
constant const int64_t a_strides[3],
constant const int64_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
@@ -147,19 +147,19 @@ template <
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
typename IdxT = int64_t>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
auto idx = elem_to_loc_2_nd<IdxT>(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);

View File

@@ -7,21 +7,21 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_float(op) \

View File

@@ -22,7 +22,7 @@ template <typename T, typename U>
device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
dst[offset] = static_cast<U>(src[0]);
}
@@ -32,7 +32,7 @@ template <typename T, typename U>
device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
dst[offset] = static_cast<U>(src[offset]);
}
@@ -42,7 +42,7 @@ template <typename T, typename U, typename IdxT = int64_t>
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
@@ -53,7 +53,7 @@ template <typename T, typename U, typename IdxT = int64_t>
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
@@ -65,7 +65,7 @@ template <typename T, typename U, typename IdxT = int64_t>
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
IdxT dst_idx =
index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
@@ -80,7 +80,7 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc<int64_t, IdxT>(
auto src_idx = elem_to_loc<IdxT>(
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
if (N == 1) {
IdxT dst_idx =
@@ -104,8 +104,8 @@ template <typename T, typename U, typename IdxT = int64_t>
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
auto dst_idx = elem_to_loc_1<int64_t, IdxT>(index, dst_stride);
auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
@@ -116,8 +116,8 @@ template <typename T, typename U, typename IdxT = int64_t>
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_2<int64_t, IdxT>(index, dst_strides);
auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
@@ -128,8 +128,8 @@ template <typename T, typename U, typename IdxT = int64_t>
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_3<int64_t, IdxT>(index, dst_strides);
auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
@@ -142,7 +142,7 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto idx = elem_to_loc_2_nd<int64_t, IdxT>(
auto idx = elem_to_loc_2_nd<IdxT>(
{N * index.x, index.y, index.z},
src_shape,
src_strides,

View File

@@ -9,7 +9,7 @@ METAL_FUNC void gather_impl(
const device T* src [[buffer(0)]],
device T* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant int64_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
@@ -27,7 +27,7 @@ METAL_FUNC void gather_impl(
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
idx_loc += indices.row_contiguous[i]
? index.y
: elem_to_loc<size_t, LocT>(
: elem_to_loc<LocT>(
index.y,
&indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1],
@@ -39,7 +39,7 @@ METAL_FUNC void gather_impl(
}
auto src_offset =
elem_to_loc<size_t, LocT>(index.z, slice_sizes, src_strides, src_ndim);
elem_to_loc<LocT>(index.z, slice_sizes, src_strides, src_ndim);
LocT out_idx = index.z;
if (IDX_NDIM == 1) {

View File

@@ -436,9 +436,9 @@ template <
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const constant size_t* bias_batch_stride [[buffer(13)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const constant int64_t* bias_batch_stride [[buffer(13)]],
const constant int& bias_stride [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
@@ -486,31 +486,21 @@ template <
simd_lid);
}
#define instantiate_gemv_helper( \
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
"_tm" #tm "_tn" #tn "_nc" #nc \
"_axpby" #axpby)]] [[kernel]] void \
gemv<itype, bm, bn, sm, sn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_helper( \
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
instantiate_kernel( \
"gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc "_axpby" #axpby, \
gemv, \
itype, \
bm, \
bn, \
sm, \
sn, \
tm, \
tn, \
nc, \
axpby)
// clang-format off
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
@@ -549,13 +539,13 @@ template <
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* index_batch_strides [[buffer(11)]],
const constant int64_t* index_batch_strides [[buffer(11)]],
const constant int& vector_batch_ndim [[buffer(12)]],
const constant int* vector_batch_shape [[buffer(13)]],
const constant size_t* vector_batch_stride [[buffer(14)]],
const constant int64_t* vector_batch_stride [[buffer(14)]],
const constant int& matrix_batch_ndim [[buffer(15)]],
const constant int* matrix_batch_shape [[buffer(16)]],
const constant size_t* matrix_batch_stride [[buffer(17)]],
const constant int64_t* matrix_batch_stride [[buffer(17)]],
const constant uint32_t* vec_indices [[buffer(18)]],
const constant uint32_t* mat_indices [[buffer(19)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -571,8 +561,8 @@ template <
// Update batch offsets
if (batch_ndim > 1) {
const constant size_t* veci_bstrides = index_batch_strides;
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;
const constant auto* veci_bstrides = index_batch_strides;
const constant auto* mati_bstrides = index_batch_strides + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
@@ -619,37 +609,14 @@ template <
simd_lid);
}
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
template [[host_name("gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
"_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \
gemv_gather<itype, bm, bn, sm, sn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* index_batch_strides [[buffer(11)]], \
const constant int& vector_batch_ndim [[buffer(12)]], \
const constant int* vector_batch_shape [[buffer(13)]], \
const constant size_t* vector_batch_stride [[buffer(14)]], \
const constant int& matrix_batch_ndim [[buffer(15)]], \
const constant int* matrix_batch_shape [[buffer(16)]], \
const constant size_t* matrix_batch_stride [[buffer(17)]], \
const constant uint32_t* vec_indices [[buffer(18)]], \
const constant uint32_t* mat_indices [[buffer(19)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_gemv_bs_blocks(name, itype) \
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
instantiate_kernel( \
"gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
"_sn" #sn "_tm" #tm "_tn" #tn, \
gemv_gather, itype, bm, bn, sm, sn, tm, tn)
#define instantiate_gemv_bs_blocks(name, itype) \
instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \
instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \
instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on
@@ -684,9 +651,9 @@ template <
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const constant size_t* bias_batch_stride [[buffer(13)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const constant int64_t* bias_batch_stride [[buffer(13)]],
const constant int& bias_stride [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
@@ -734,33 +701,14 @@ template <
simd_lid);
}
#define instantiate_gemv_t_helper( \
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
"_tm" #tm "_tn" #tn "_nc" #nc \
"_axpby" #axpby)]] [[kernel]] void \
gemv_t<itype, bm, bn, sm, sn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_gemv_t_helper( \
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
instantiate_kernel( \
"gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
"_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby, \
gemv_t, itype, bm, bn, sm, sn, tm, tn, nc, axpby)
#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
@@ -800,13 +748,13 @@ template <
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* index_batch_strides [[buffer(11)]],
const constant int64_t* index_batch_strides [[buffer(11)]],
const constant int& vector_batch_ndim [[buffer(12)]],
const constant int* vector_batch_shape [[buffer(13)]],
const constant size_t* vector_batch_stride [[buffer(14)]],
const constant int64_t* vector_batch_stride [[buffer(14)]],
const constant int& matrix_batch_ndim [[buffer(15)]],
const constant int* matrix_batch_shape [[buffer(16)]],
const constant size_t* matrix_batch_stride [[buffer(17)]],
const constant int64_t* matrix_batch_stride [[buffer(17)]],
const constant uint32_t* vec_indices [[buffer(18)]],
const constant uint32_t* mat_indices [[buffer(19)]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -822,8 +770,8 @@ template <
// Update batch offsets
if (batch_ndim > 1) {
const constant size_t* veci_bstrides = index_batch_strides;
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;
const constant auto* veci_bstrides = index_batch_strides;
const constant auto* mati_bstrides = index_batch_strides + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
@@ -870,36 +818,14 @@ template <
simd_lid);
}
#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
template [[host_name("gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
"_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \
gemv_t_gather<itype, bm, bn, sm, sn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* index_batch_strides [[buffer(11)]], \
const constant int& vector_batch_ndim [[buffer(12)]], \
const constant int* vector_batch_shape [[buffer(13)]], \
const constant size_t* vector_batch_stride [[buffer(14)]], \
const constant int& matrix_batch_ndim [[buffer(15)]], \
const constant int* matrix_batch_shape [[buffer(16)]], \
const constant size_t* matrix_batch_stride [[buffer(17)]], \
const constant uint32_t* vec_indices [[buffer(18)]], \
const constant uint32_t* mat_indices [[buffer(19)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_gemv_t_bs_helper( \
nm, itype, bm, bn, sm, sn, tm, tn) \
instantiate_kernel( \
"gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
"_sn" #sn "_tm" #tm "_tn" #tn, \
gemv_t_gather, itype, bm, bn, sm, sn, tm, tn)
#define instantiate_gemv_t_bs_blocks(name, itype) \
instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 1) \
instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 4) \

View File

@@ -642,13 +642,13 @@ template <
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
const constant int64_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -673,8 +673,8 @@ template <
}
if (has_operand_mask) {
const constant size_t* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
const constant auto* mask_strides_mat = mask_batch_strides;
const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
@@ -742,13 +742,13 @@ template <
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
const constant int64_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -773,8 +773,8 @@ template <
}
if (has_operand_mask) {
const constant size_t* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
const constant auto* mask_strides_mat = mask_batch_strides;
const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

View File

@@ -10,29 +10,11 @@
#define instantiate_gemv_helper( \
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
template [[host_name("gemv_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc)]] [[kernel]] void \
gemv_masked<itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const device outm_t* out_mask [[buffer(20)]], \
const device opm_t* mat_mask [[buffer(21)]], \
const device opm_t* vec_mask [[buffer(22)]], \
const constant int* mask_strides [[buffer(23)]], \
const constant size_t* mask_batch_strides [[buffer(24)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
instantiate_kernel( \
"gemv_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc, \
gemv_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc)
#define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
@@ -61,29 +43,11 @@ instantiate_gemv_blocks(bfloat16, bfloat16_t);
#define instantiate_gemv_t_helper( \
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
template [[host_name("gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc)]] [[kernel]] void \
gemv_t_masked<itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const device outm_t* out_mask [[buffer(20)]], \
const device opm_t* mat_mask [[buffer(21)]], \
const device opm_t* vec_mask [[buffer(22)]], \
const constant int* mask_strides [[buffer(23)]], \
const constant size_t* mask_batch_strides [[buffer(24)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
instantiate_kernel( \
"gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc, \
gemv_t_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc)
#define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \

View File

@@ -8,7 +8,7 @@ template <typename IdxT, int NIDX>
struct Indices {
const array<const device IdxT*, NIDX> buffers;
const constant int* shapes;
const constant size_t* strides;
const constant int64_t* strides;
const constant bool* row_contiguous;
const int ndim;
};

View File

@@ -854,15 +854,17 @@ METAL_FUNC void qvm_impl(
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int num_simdgroups = 2;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
constexpr int tn = 32 / pack_factor;
constexpr int block_size = SIMD_SIZE;
const device uint8_t* ws = (const device uint8_t*)w;
using W_T =
typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
const device W_T* ws = (const device W_T*)w;
typedef float U;
typedef struct {
uint8_t wi[tn * bytes_per_pack];
W_T wi[tn * bytes_per_pack];
} vec_w;
thread vec_w w_local;
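The ConditionalType above selects uint32_t as the load unit when the bit width is a power of two (so bytes_per_pack counts whole 32-bit words and becomes 1) and falls back to uint8_t with 3-byte packs for the 3- and 6-bit cases. A host-side sketch of the 3-bit layout this assumes, eight values per 24 bits (not MLX API, just an illustration):

#include <array>
#include <cstdint>
#include <cstdio>

std::array<uint8_t, 3> pack3(const std::array<uint8_t, 8>& v) {
  uint32_t bits = 0;
  for (int i = 0; i < 8; i++) {
    bits |= uint32_t(v[i] & 0x7) << (3 * i); // 8 values * 3 bits = 24 bits
  }
  return {uint8_t(bits), uint8_t(bits >> 8), uint8_t(bits >> 16)};
}

int main() {
  std::array<uint8_t, 8> v = {1, 7, 0, 5, 3, 2, 6, 4};
  auto b = pack3(v);
  uint32_t bits = b[0] | (uint32_t(b[1]) << 8) | (uint32_t(b[2]) << 16);
  for (int i = 0; i < 8; i++) {
    std::printf("%d ", int((bits >> (3 * i)) & 0x7)); // round-trips 1 7 0 5 3 2 6 4
  }
  return 0;
}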
@@ -1217,12 +1219,12 @@ METAL_FUNC void adjust_matrix_offsets(
int output_stride,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant size_t* x_strides,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant size_t* w_strides,
const constant size_t* s_strides,
const constant size_t* b_strides,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
const constant int64_t* b_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx = tid.z;
@@ -1246,6 +1248,41 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride;
}
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
const device uint32_t*& w,
const device T*& scales,
device T*& y,
int output_stride,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx = tid.z;
uint32_t w_idx = tid.z;
if (x_batch_ndims == 1) {
x += x_idx * x_strides[0];
} else {
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
}
if (w_batch_ndims == 1) {
w += w_idx * w_strides[0];
scales += w_idx * s_strides[0];
} else {
ulong2 idx = elem_to_loc_broadcast(
w_idx, w_shape, w_strides, s_strides, w_batch_ndims);
w += idx.x;
scales += idx.y;
}
y += tid.z * output_stride;
}
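This overload drops the bias strides for the packed-scales path but keeps the same offset logic: a flat batch index is decomposed over the batch shape and dotted with each operand's strides. A hypothetical host-side version of that elem_to_loc walk:

#include <cstdint>
#include <cstdio>

int64_t elem_to_loc(int64_t elem, const int* shape, const int64_t* strides, int ndim) {
  int64_t loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i]; // peel off the innermost dimension
    elem /= shape[i];
  }
  return loc;
}

int main() {
  int shape[] = {2, 3, 4};
  int64_t strides[] = {12, 4, 1}; // contiguous row-major strides
  std::printf("%lld\n", (long long)elem_to_loc(17, shape, strides, 3)); // prints 17
  return 0;
}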
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
@@ -1258,16 +1295,16 @@ METAL_FUNC void adjust_matrix_offsets(
int output_stride,
const constant int& batch_ndims,
const constant int* batch_shape,
const constant size_t* lhs_strides,
const constant size_t* rhs_strides,
const constant int64_t* lhs_strides,
const constant int64_t* rhs_strides,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant size_t* x_strides,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant size_t* w_strides,
const constant size_t* s_strides,
const constant size_t* b_strides,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
const constant int64_t* b_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx;
@@ -1311,12 +1348,12 @@ template <typename T, int group_size, int bits, int D, bool batched>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
@@ -1362,12 +1399,12 @@ template <typename T, int group_size, int bits, bool batched>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1413,12 +1450,12 @@ template <typename T, const int group_size, const int bits, bool batched>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1464,12 +1501,12 @@ template <typename T, const int group_size, const int bits, bool batched>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1515,12 +1552,12 @@ template <typename T, const int group_size, const int bits, int split_k = 32>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& final_block_size [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1579,12 +1616,12 @@ template <
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1637,12 +1674,12 @@ template <
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1689,18 +1726,18 @@ template <typename T, int group_size, int bits>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const constant size_t* lhs_strides [[buffer(19)]],
const constant size_t* rhs_strides [[buffer(20)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1750,18 +1787,18 @@ template <typename T, int group_size, int bits>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const constant size_t* lhs_strides [[buffer(19)]],
const constant size_t* rhs_strides [[buffer(20)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1811,18 +1848,18 @@ template <typename T, int group_size, int bits>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const constant size_t* lhs_strides [[buffer(19)]],
const constant size_t* rhs_strides [[buffer(20)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1880,18 +1917,18 @@ template <
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]],
const constant size_t* lhs_strides [[buffer(20)]],
const constant size_t* rhs_strides [[buffer(21)]],
const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1947,18 +1984,18 @@ template <
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]],
const constant size_t* lhs_strides [[buffer(20)]],
const constant size_t* rhs_strides [[buffer(21)]],
const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -2147,3 +2184,666 @@ template <typename T, const int group_size, const int bits>
}
}
}
template <typename T, typename U, int bits>
inline vec<U, 4> partial_qdot_vec(const thread U* x, vec<uint32_t, 4> w) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
vec<U, 4> accum = 0;
if (bits == 2) {
for (int i = 0; i < 4; i++) {
auto ws = as_type<vec<uint8_t, 4>>(w[i]);
for (int j = 0; j < 4; j++) {
accum[i] +=
(x[4 * j + 0] * (ws[j] & 0x03) + x[4 * j + 1] * (ws[j] & 0x0c) +
x[4 * j + 2] * (ws[j] & 0x30) + x[4 * j + 3] * (ws[j] & 0xc0));
}
}
}
else if (bits == 4) {
for (int i = 0; i < 4; i++) {
auto ws = as_type<vec<uint16_t, 2>>(w[i]);
for (int j = 0; j < 2; j++) {
accum[i] +=
(x[4 * j + 0] * (ws[j] & 0x000f) + x[4 * j + 1] * (ws[j] & 0x00f0) +
x[4 * j + 2] * (ws[j] & 0x0f00) + x[4 * j + 3] * (ws[j] & 0xf000));
}
}
}
else if (bits == 8) {
for (int i = 0; i < 4; i++) {
auto ws = as_type<vec<uint8_t, 4>>(w[i]);
for (int j = 0; j < 4; j++) {
accum[i] += x[j] * ws[j];
}
}
}
return accum;
}
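For the 2- and 4-bit cases the fields are masked but never shifted down, so each product carries an extra factor of 4^j (or 16^j) that must be divided out elsewhere. A minimal numeric check of the identity, assuming the inputs were pre-scaled the way load_vector prepares x_thread for this packing:

#include <cstdint>
#include <cstdio>

int main() {
  uint16_t w = 0xA3C7; // four 4-bit fields, low to high: 7, C, 3, A
  float x[4] = {1.f, 2.f, 3.f, 4.f};
  float ref = 0.f; // reference: shift each field down, multiply by raw x
  for (int j = 0; j < 4; j++) {
    ref += x[j] * ((w >> (4 * j)) & 0xf);
  }
  // Shift-free: mask only, with x[j] pre-divided by 16^j
  float xs[4] = {x[0], x[1] / 16.f, x[2] / 256.f, x[3] / 4096.f};
  float got = xs[0] * (w & 0x000f) + xs[1] * (w & 0x00f0) +
      xs[2] * (w & 0x0f00) + xs[3] * (w & 0xf000);
  std::printf("%f %f\n", ref, got); // identical: 80 80
  return 0;
}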
template <typename T, int group_size, int bits>
METAL_FUNC void affine_packed_qmv_fast_impl(
const device vec<uint32_t, 4>* w,
const device vec<T, 4>* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int packs_per_thread = (bits == 2) ? 1 : 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = 32 / bits;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
typedef float U;
thread U x_thread[values_per_thread];
vec<U, results_per_simdgroup> result = 0;
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size * 2 / group_size;
const int w_row = tid.x * num_simdgroups + simd_gid;
const int out_row = w_row * results_per_simdgroup;
w += w_row * in_vec_size_w + simd_lid * packs_per_thread;
scales += w_row * in_vec_size_g + 2 * (simd_lid / scale_step_per_thread);
x += tid.y * in_vec_size + simd_lid * values_per_thread;
y += tid.y * out_vec_size + out_row;
for (int k = 0; k < in_vec_size; k += block_size) {
// Load the input vector
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
// Load the scales and biases
vec<T, 4> s = scales[0];
vec<T, 4> b = scales[1];
// Load the weights and perform the partial dot product
vec<U, 4> accum = 0;
for (int pack = 0; pack < packs_per_thread; pack++) {
accum +=
partial_qdot_vec<T, U, bits>(x_thread + pack * pack_factor, w[pack]);
}
// Finalize the dot product and accumulate it
for (int i = 0; i < 4; i++) {
result[i] += static_cast<U>(s[i]) * accum[i] + static_cast<U>(b[i]) * sum;
}
w += block_size / pack_factor;
scales += 2 * block_size / group_size;
x += block_size;
}
result = simd_sum(result);
if (simd_lid == 0) {
for (int row = 0; row < results_per_simdgroup; row++) {
y[row] = static_cast<T>(result[row]);
}
}
}
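The epilogue uses the usual affine-dequant refactor: with w_k = s*q_k + b, the row dot product becomes s*(x . q) + b*sum(x), which is why load_vector also returns the running sum of the inputs. A quick numeric check:

#include <cstdio>

int main() {
  float x[4] = {0.5f, -1.f, 2.f, 3.f};
  int q[4] = {1, 3, 0, 2}; // quantized weights
  float s = 0.25f, b = -0.5f;
  float lhs = 0.f, dot = 0.f, sum = 0.f;
  for (int k = 0; k < 4; k++) {
    lhs += x[k] * (s * q[k] + b); // dequantize, then dot
    dot += x[k] * q[k];
    sum += x[k];
  }
  std::printf("%f %f\n", lhs, s * dot + b * sum); // equal
  return 0;
}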
template <typename T, int group_size, int bits, int results_per_simdgroup>
METAL_FUNC void affine_packed_byte_qmv_fast_impl(
const device uint8_t* w,
const device vec<T, 2 * results_per_simdgroup>* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int packs_per_thread = 2;
constexpr int num_simdgroups = 2;
constexpr int pack_factor = (bits == 3) ? 8 : 4;
constexpr int bytes_per_pack = 3;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
typedef float U;
thread U x_thread[values_per_thread];
vec<U, results_per_simdgroup> result = 0;
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int scales_row = tid.x * num_simdgroups + simd_gid;
const int out_row = scales_row * results_per_simdgroup;
w += out_row * in_vec_size_w + simd_lid * (packs_per_thread * bytes_per_pack);
scales += scales_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread;
y += tid.y * out_vec_size + out_row;
for (int k = 0; k < in_vec_size; k += block_size) {
// Load the input vector
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
// Load the scales and biases
vec<T, 2 * results_per_simdgroup> sb = scales[0];
// Load the weights and perform the partial dot product
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] += qdot<U, values_per_thread, bits>(
w + row * in_vec_size_w, x_thread, sb[row], sb[2 + row], sum);
}
w += block_size * bytes_per_pack / pack_factor;
scales += block_size / group_size;
x += block_size;
}
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
y[row] = static_cast<T>(result[row]);
}
}
}
template <typename T, int group_size, int bits>
[[kernel]] void affine_packed_qmv_fast(
const device vec<uint32_t, 4>* w [[buffer(0)]],
const device vec<T, 4>* scales [[buffer(1)]],
const device T* x [[buffer(2)]],
device T* y [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (bits & (bits - 1)) {
affine_packed_byte_qmv_fast_impl<T, group_size, bits, 2>(
(const device uint8_t*)w,
scales,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
} else {
affine_packed_qmv_fast_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
}
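bits & (bits - 1) is nonzero exactly for the non-power-of-two widths (3 and 6 here), which is what routes them to the byte-packed implementation:

#include <cstdio>

int main() {
  int widths[] = {2, 3, 4, 6, 8};
  for (int bits : widths) {
    std::printf("%d -> %s\n", bits,
                (bits & (bits - 1)) ? "byte-packed path" : "word-packed path");
  }
  return 0;
}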
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short group_size,
short bits>
struct AffinePackedQuantizedBlockLoader {
static_assert(
BCOLS <= group_size,
"The group size should be larger than the columns");
static_assert(
group_size % BCOLS == 0,
"The group size should be divisible by the columns");
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
MLX_MTL_CONST short pack_factor = 32 / bits;
MLX_MTL_CONST short row_pack_factor = 4;
MLX_MTL_CONST short BCOLS_PACKED = BCOLS * row_pack_factor / pack_factor;
MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor;
MLX_MTL_CONST short TOTAL_INTS = BCOLS_PACKED * BROWS_PACKED;
MLX_MTL_CONST short n_reads =
(TOTAL_INTS < tgp_size) ? 1 : TOTAL_INTS / tgp_size;
MLX_MTL_CONST short group_steps = group_size / BCOLS;
static_assert(
n_reads <= row_pack_factor,
"The loader only supports per thread reads <= row_pack_factor");
const int src_ld;
const int tile_stride;
short group_step_cnt;
const int group_stride;
const short thread_idx;
const short bi;
const short bj;
const short bii;
const short bjj;
const device uint32_t* src;
const device T* scales;
const device T* biases;
threadgroup T* dst;
AffinePackedQuantizedBlockLoader(
const device uint32_t* src_,
const device T* scales_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS_PACKED : BROWS_PACKED * src_ld),
group_step_cnt(0),
group_stride(BROWS_PACKED * 2 * src_ld / group_size),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(n_reads * thread_idx / BCOLS_PACKED),
bj((n_reads * thread_idx) % BCOLS_PACKED),
bii(bi * row_pack_factor + bj % row_pack_factor),
bjj(bj / row_pack_factor),
src(src_ + bi * src_ld * row_pack_factor / pack_factor + bj),
scales(
scales_ + bi * 2 * src_ld * row_pack_factor / group_size +
bj % row_pack_factor),
biases(scales + row_pack_factor),
dst(dst_ + bii * dst_ld + bjj * pack_factor) {}
void load_unsafe() const {
if (bits == 2 && BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
return;
}
for (int i = 0; i < n_reads; i++) {
T scale = scales[i];
T bias = biases[i];
dequantize<T, pack_factor, bits>(
(const device uint8_t*)(src + i), scale, bias, dst + i * dst_ld);
}
}
void load_safe(short2 src_tile_dim) const {
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
return;
}
if (reduction_dim == 1 && bii >= src_tile_dim.y) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
}
if (reduction_dim == 0 && bii >= src_tile_dim.x) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
}
for (int i = 0; i < n_reads; i++) {
T scale = scales[i];
T bias = biases[i];
dequantize<T, pack_factor, bits>(
(const device uint8_t*)(src + i), scale, bias, dst + i * dst_ld);
}
}
void next() {
src += tile_stride;
if (reduction_dim == 1) {
if (group_steps > 1) {
group_step_cnt++;
if (group_step_cnt == group_steps) {
group_step_cnt = 0;
scales += (2 * row_pack_factor);
biases += (2 * row_pack_factor);
}
} else {
scales += (2 * row_pack_factor);
biases += (2 * row_pack_factor);
}
} else {
scales += group_stride;
biases += group_stride;
}
}
};
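Each of the tgp_size threads claims n_reads packed words; bi/bj locate them in the packed tile, and bii/bjj the destination after undoing the row interleaving (row_pack_factor rows per packed row). A sketch of that mapping with illustrative sizes (these are not the kernel's actual template arguments):

#include <cstdio>

int main() {
  const int pack_factor = 16, row_pack_factor = 4; // e.g. bits == 2
  const int BROWS = 32, BCOLS = 32, tgp_size = 128;
  const int BCOLS_PACKED = BCOLS * row_pack_factor / pack_factor; // 8
  const int BROWS_PACKED = BROWS / row_pack_factor;               // 8
  const int total = BCOLS_PACKED * BROWS_PACKED;                  // 64 words
  const int n_reads = (total < tgp_size) ? 1 : total / tgp_size;  // 1
  for (int t = 0; t < 4; t++) { // first few threads
    int bi = n_reads * t / BCOLS_PACKED;
    int bj = (n_reads * t) % BCOLS_PACKED;
    int bii = bi * row_pack_factor + bj % row_pack_factor; // dst row
    int bjj = bj / row_pack_factor;                        // dst packed column
    std::printf("thread %d -> packed(%d,%d) dst(%d,%d)\n", t, bi, bj, bii, bjj);
  }
  return 0;
}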
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short group_size,
short bits>
struct AffineScalesPackedQuantizedBlockLoader {
static_assert(
BCOLS <= group_size,
"The group size should be larger than the columns");
static_assert(
group_size % BCOLS == 0,
"The group size should be divisible by the columns");
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
MLX_MTL_CONST short bytes_per_pack = (bits & (bits - 1)) ? 3 : 4;
MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
MLX_MTL_CONST short row_pack_factor = 2;
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor;
MLX_MTL_CONST short TOTAL_READS = BCOLS * BROWS / pack_factor;
MLX_MTL_CONST short n_reads =
(TOTAL_READS < tgp_size) ? 1 : TOTAL_READS / tgp_size;
MLX_MTL_CONST short group_steps = group_size / BCOLS;
const int src_ld;
const int tile_stride;
short group_step_cnt;
const int group_stride;
const short thread_idx;
const short bi;
const short bj;
const short bii;
const device uint8_t* src;
const device T* scales;
const device T* biases;
threadgroup T* dst;
AffineScalesPackedQuantizedBlockLoader(
const device uint32_t* src_,
const device T* scales_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(
reduction_dim ? BCOLS_PACKED * bytes_per_pack
: BROWS * src_ld * bytes_per_pack / pack_factor),
group_step_cnt(0),
group_stride(BROWS_PACKED * 2 * src_ld / group_size),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(n_reads * thread_idx / BCOLS_PACKED),
bj((n_reads * thread_idx) % BCOLS_PACKED),
bii(bi / row_pack_factor),
src(((const device uint8_t*)src_) +
bi * src_ld * bytes_per_pack / pack_factor + bj * bytes_per_pack),
scales(
scales_ + bii * 2 * src_ld * row_pack_factor / group_size +
bi % row_pack_factor),
biases(scales + row_pack_factor),
dst(dst_ + bi * dst_ld + bj * pack_factor) {}
void load_unsafe() const {
if (bits == 2 && TOTAL_READS < tgp_size && bi >= BROWS) {
return;
}
T scale = *scales;
T bias = *biases;
for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>(
(const device uint8_t*)(src + bytes_per_pack * i),
scale,
bias,
dst + i * pack_factor);
}
}
void load_safe(short2 src_tile_dim) const {
if (TOTAL_READS < tgp_size && bi >= BROWS) {
return;
}
if (reduction_dim == 1 && bii >= src_tile_dim.y) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
}
if (reduction_dim == 0 && bii >= src_tile_dim.x) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
}
for (int i = 0; i < n_reads; i++) {
T scale = scales[i];
T bias = biases[i];
dequantize<T, pack_factor, bits>(
(const device uint8_t*)(src + bytes_per_pack * i * src_ld),
scale,
bias,
dst + i * dst_ld);
}
}
void next() {
src += tile_stride;
if (reduction_dim == 1) {
if (group_steps > 1) {
group_step_cnt++;
if (group_step_cnt == group_steps) {
group_step_cnt = 0;
scales += (2 * row_pack_factor);
biases += (2 * row_pack_factor);
}
} else {
scales += (2 * row_pack_factor);
biases += (2 * row_pack_factor);
}
} else {
scales += group_stride;
biases += group_stride;
}
}
};
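Unlike the loader above, this variant packs only the scales row-pairwise (row_pack_factor = 2) and leaves the weights in their usual byte layout, so the 3- and 6-bit widths can reuse the existing dequantize path. The per-width constants it derives, tabulated from the definitions above:

#include <cstdio>

int main() {
  int widths[] = {2, 3, 4, 6, 8};
  for (int bits : widths) {
    int bytes_per_pack = (bits & (bits - 1)) ? 3 : 4;
    int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
    std::printf("bits=%d pack_factor=%d bytes_per_pack=%d\n",
                bits, pack_factor, bytes_per_pack);
  }
  return 0;
}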
template <
typename T,
const int group_size,
const int bits,
const bool aligned_N,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void affine_packed_qmm_t_impl(
const device uint32_t* w,
const device T* scales,
const device T* x,
device T* y,
threadgroup T* Xs,
threadgroup T* Ws,
const constant int& K,
const constant int& N,
const constant int& M,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
(void)lid;
constexpr bool power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int pack_factor = 32 / bits;
constexpr int row_pack_factor = (power_of_2_bits) ? 4 : 2;
constexpr int BK_padded = (BK + 16 / sizeof(T));
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::
BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
using loader_x_t =
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_fully_packed_t = AffinePackedQuantizedBlockLoader<
T,
BN,
BK,
BK_padded,
1,
WM * WN * SIMD_SIZE,
group_size,
bits>;
using loader_scales_packed_t = AffineScalesPackedQuantizedBlockLoader<
T,
BN,
BK,
BK_padded,
1,
WM * WN * SIMD_SIZE,
group_size,
bits>;
using loader_w_t = typename ConditionalType<
power_of_2_bits,
loader_fully_packed_t,
loader_scales_packed_t>::type;
// Set the block
const int K_w =
(power_of_2_bits) ? K * row_pack_factor / pack_factor : K * bits / 32;
const int K_g = K * 2 * row_pack_factor / group_size;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
const int packed_y_col = tid.x * (BN / row_pack_factor);
x += y_row * K;
w += (power_of_2_bits) ? packed_y_col * K_w : y_col * K_w;
scales += packed_y_col * K_g;
y += y_row * N + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
const short num_outs = min(BN, N - y_col);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
loader_w_t loader_w(w, scales, K, Ws, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
if (num_els < BM) {
if (!aligned_N && num_outs < BN) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_safe(short2(BK, num_outs));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
} else {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
}
} else {
if (!aligned_N && num_outs < BN) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_safe(short2(BK, num_outs));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
} else {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
}
}
// Store results to device memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (num_els < BM || num_outs < BN) {
mma_op.store_result_safe(y, N, short2(num_outs, num_els));
} else {
mma_op.store_result(y, N);
}
}
template <
typename T,
const int group_size,
const int bits,
const bool aligned_N,
const bool batched,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void affine_packed_qmm_t(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* x [[buffer(2)]],
device T* y [[buffer(3)]],
const constant int& K [[buffer(4)]],
const constant int& N [[buffer(5)]],
const constant int& M [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr int BK_padded = (BK + 16 / sizeof(T));
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
y,
M * N,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
tid);
}
affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
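The impl selects one of four specialized K-loops up front so the steady-state path carries no bounds checks; only edge tiles pay for load_safe. The control flow, collapsed into in-loop branches for brevity (stand-in loaders, not the Metal kernel):

#include <algorithm>
#include <cstdio>

struct Loader {
  const char* name;
  void load_safe(int n) { std::printf("%s load_safe(%d)\n", name, n); }
  void load_unsafe() { std::printf("%s load_unsafe\n", name); }
  void next() {}
};

int main() {
  int K = 64, BK = 32, M = 30, N = 32, BM = 32, BN = 32;
  Loader lx{"x"}, lw{"w"};
  int num_els = std::min(BM, M);  // rows left in this tile
  int num_outs = std::min(BN, N); // output columns left in this tile
  bool x_edge = num_els < BM, w_edge = num_outs < BN;
  for (int k = 0; k < K; k += BK) {
    // threadgroup_barrier(mem_flags::mem_threadgroup) in the Metal version
    x_edge ? lx.load_safe(num_els) : lx.load_unsafe();
    w_edge ? lw.load_safe(num_outs) : lw.load_unsafe();
    // threadgroup_barrier, then mma_op.mma(Xs, Ws)
    lx.next();
    lw.next();
  }
  return 0;
}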

View File

@@ -60,6 +60,14 @@
bits, \
split_k)
#define instantiate_quantized_affine_packed(name, type, group_size, bits) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits, \
name, \
type, \
group_size, \
bits)
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 0)
@@ -96,12 +104,20 @@
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_all_affine_packed(type, group_size, bits) \
instantiate_quantized_affine_packed(affine_packed_qmv_fast, type, group_size, bits) \
instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, true, true) \
instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, true, false) \
instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, false, true) \
instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, false, false)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_all_splitk(type, group_size, bits)
instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_all_affine_packed(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \
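instantiate_kernel pairs a host_name string with an explicit template instantiation, so the new affine-packed wrapper only has to splice the type, group size, and bits into the name. One expansion roughly produces the following (a sketch of the generated Metal, assuming the usual instantiate_kernel definition):

template [[host_name("affine_packed_qmv_fast_float16_t_gs_64_b_4")]]
[[kernel]] decltype(affine_packed_qmv_fast<float16_t, 64, 4>)
    affine_packed_qmv_fast<float16_t, 64, 4>;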

View File

@@ -71,7 +71,7 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
constant const uint& bytes_per_key,
constant const int& ndim,
constant const int* key_shape,
constant const size_t* key_strides,
constant const int64_t* key_strides,
uint2 grid_dim [[threads_per_grid]],
uint2 index [[thread_position_in_grid]]) {
auto kidx = 2 * index.x;

View File

@@ -53,32 +53,32 @@ instantiate_init_min_max(max, Max)
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, uint, dim) \
itype, otype, op, int, dim) \
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, uint, dim) \
itype, otype, op, int, dim) \
instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, size_t, dim) \
itype, otype, op, int64_t, dim) \
instantiate_kernel("col_reduce_longcolumn_large_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, size_t, dim)
itype, otype, op, int64_t, dim)
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, uint, dim, bm, bn) \
itype, otype, op, int, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, size_t, dim, bm, bn)
itype, otype, op, int64_t, dim, bm, bn)
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, uint, dim, bm, bn) \
itype, otype, op, int, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, size_t, dim, bm, bn)
itype, otype, op, int64_t, dim, bm, bn)
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
@@ -95,18 +95,18 @@ instantiate_init_min_max(max, Max)
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, uint, dim) \
itype, otype, op, int, dim) \
instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, size_t, dim)
itype, otype, op, int64_t, dim)
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, uint, dim) \
itype, otype, op, int, dim) \
instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, size_t, dim)
itype, otype, op, int64_t, dim)
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op, 1) \
@@ -125,7 +125,7 @@ instantiate_init_min_max(max, Max)
instantiate_col_reduce_general(name##tname, itype, otype, op<otype>)
#define instantiate_and_or(name, op) \
instantiate_reduce_functions(name, bool_, bool, bool, op) \
instantiate_reduce_functions(name, int16, int16_t, bool, op) \
instantiate_reduce_functions(name, int32, int32_t, bool, op) \
instantiate_reduce_functions(name, int64, int64_t, bool, op)
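This hunk swaps the unsigned index types (uint/size_t) for signed ones (int/int64_t) while keeping the small/large kernel split. A hypothetical host-side sketch of how such a dispatch picks the large variant only when 32-bit indexing could overflow (the real MLX condition may differ):

#include <cstdint>
#include <cstdio>
#include <string>

std::string reduce_kernel_name(int64_t in_size, int dims, const std::string& op) {
  bool large = in_size > INT32_MAX; // assumption: MLX's actual threshold may differ
  std::string base = large ? "col_reduce_looped_large_" : "col_reduce_looped_";
  return base + std::to_string(dims) + "_32_32_reduce_" + op;
}

int main() {
  std::printf("%s\n", reduce_kernel_name(int64_t(1) << 33, 1, "sum").c_str());
  return 0;
}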

View File

@@ -5,12 +5,12 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant int64_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
uint3 gid [[threadgroup_position_in_grid]],
@@ -34,7 +34,7 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
bool safe = column + n_reads <= reduction_stride;
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
in += in_idx + column;
IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
@@ -100,10 +100,10 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
const constant size_t& out_size [[buffer(11)]],
@@ -116,7 +116,7 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
const device T* row;
IdxT out_idx = gid.x + gsize.x * IdxT(gid.y);
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
in += in_idx + lid.x;
U total = Op::init;
@@ -164,12 +164,12 @@ template <
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant int64_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
uint3 gid [[threadgroup_position_in_grid]],
@@ -197,7 +197,7 @@ template <
bool safe = column + n_reads <= reduction_stride;
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
in += in_idx + column;
IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
@@ -303,12 +303,12 @@ template <
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant int64_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
const constant size_t& out_size [[buffer(11)]],
@@ -342,7 +342,7 @@ template <
IdxT full_idx = gid.y + gsize.y * IdxT(gid.z);
IdxT block_idx = full_idx / IdxT(out_size);
IdxT out_idx = full_idx % IdxT(out_size);
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
in += in_idx + column;
IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);

View File

@@ -98,11 +98,11 @@ template <
METAL_FUNC void per_thread_row_reduce(
thread U totals[N_WRITES],
const device T* in,
const size_t row_idx,
const int64_t row_idx,
int blocks,
int extra,
const constant int* shape,
const constant size_t* strides,
const constant int64_t* strides,
const constant int& ndim,
uint lsize_x,
uint lid_x) {
@@ -199,13 +199,13 @@ template <
[[kernel]] void row_reduce_small(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& row_size [[buffer(2)]],
const constant size_t& non_row_reductions [[buffer(3)]],
const constant int64_t& row_size [[buffer(2)]],
const constant int64_t& non_row_reductions [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint3 gid [[threadgroup_position_in_grid]],
@@ -225,7 +225,7 @@ template <
if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
// Simple loop over non_row_reductions and reduce the row in the thread.
IdxT out_idx = tid.x + tsize.y * IdxT(tid.y);
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
for (uint r = 0; r < non_row_reductions; r++) {
row = in + loop.location();
@@ -238,7 +238,7 @@ template <
// Collaboratively reduce over non_row_reductions in the simdgroup. Each
// thread reduces every 32nd row and then a simple simd reduce.
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
loop.next(simd_lane_id, reduce_shape, reduce_strides);
@@ -260,14 +260,14 @@ template <
typename T,
typename U,
typename Op,
typename IdxT = size_t,
typename IdxT = int64_t,
int N_READS = REDUCE_N_READS,
int N_WRITES = REDUCE_N_WRITES>
[[kernel]] void row_reduce_simple(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant int64_t& out_size [[buffer(3)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint3 lid [[thread_position_in_threadgroup]],
@@ -314,13 +314,13 @@ template <
[[kernel]] void row_reduce_looped(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& row_size [[buffer(2)]],
const constant size_t& non_row_reductions [[buffer(3)]],
const constant int64_t& row_size [[buffer(2)]],
const constant int64_t& non_row_reductions [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
@@ -337,8 +337,7 @@ template <
// lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
// needs a small refactor.
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim) +
lid.x * N_READS;
in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim) + lid.x * N_READS;
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
const device T* row;

View File

@@ -20,28 +20,4 @@ using namespace metal;
instantiate_sdpa_vector_heads(float)
instantiate_sdpa_vector_heads(bfloat16_t)
instantiate_sdpa_vector_heads(float16_t)
// Quantized SDPA vector instantiations
#define instantiate_quant_sdpa_vector(type, head_dim, group_size, bits) \
instantiate_kernel( \
"quant_sdpa_vector_2pass_1_" #type "_" #head_dim "_" #group_size "_" #bits, \
quant_sdpa_vector_2pass_1, type, head_dim, group_size, bits)
#define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \
instantiate_quant_sdpa_vector(type, heads, group_size, 4) \
instantiate_quant_sdpa_vector(type, heads, group_size, 8)
#define instantiate_quant_sdpa_vector_group_size(type, heads) \
instantiate_quant_sdpa_vector_bits(type, heads, 32) \
instantiate_quant_sdpa_vector_bits(type, heads, 64) \
instantiate_quant_sdpa_vector_bits(type, heads, 128)
#define instantiate_quant_sdpa_vector_heads(type) \
instantiate_quant_sdpa_vector_group_size(type, 64) \
instantiate_quant_sdpa_vector_group_size(type, 128)
instantiate_quant_sdpa_vector_heads(float)
instantiate_quant_sdpa_vector_heads(bfloat16_t)
instantiate_quant_sdpa_vector_heads(float16_t)
// clang-format on

View File

@@ -16,11 +16,11 @@ METAL_FUNC void scatter_impl(
const device T* updates,
device mlx_atomic<T>* out,
const constant int* upd_shape,
const constant size_t* upd_strides,
const constant int64_t* upd_strides,
const constant size_t& upd_ndim,
const constant size_t& upd_size,
const constant int* out_shape,
const constant size_t* out_strides,
const constant int64_t* out_strides,
const constant size_t& out_ndim,
const constant int* axes,
const constant size_t& idx_size,
@@ -31,7 +31,7 @@ METAL_FUNC void scatter_impl(
auto ind_idx = gid.y * NWORK;
LocT out_offset = 0;
if (upd_size > 1) {
out_offset = elem_to_loc<size_t, LocT>(
out_offset = elem_to_loc<LocT>(
gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
}
@@ -40,7 +40,7 @@ METAL_FUNC void scatter_impl(
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = indices.row_contiguous[i]
? ind_idx
: elem_to_loc<size_t, LocT>(
: elem_to_loc<LocT>(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
@@ -52,8 +52,7 @@ METAL_FUNC void scatter_impl(
}
auto upd_idx = ind_idx * static_cast<LocT>(upd_size) + gid.x;
if constexpr (!UPD_ROW_CONTIG) {
upd_idx =
elem_to_loc<size_t, LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);
upd_idx = elem_to_loc<LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);
}
op.atomic_update(out, updates[upd_idx], out_idx);
}

View File

@@ -113,67 +113,6 @@ template <typename T, int D>
}
}
template <typename T, typename U, int elem_per_thread, int bits>
METAL_FUNC U load_queries(const device T* queries, thread U* q, U scale) {
U query_sum = 0;
if (bits == 4) {
for (int i = 0; i < elem_per_thread; i += 4) {
q[i] = scale * queries[i];
q[i + 1] = scale * queries[i + 1];
q[i + 2] = scale * queries[i + 2];
q[i + 3] = scale * queries[i + 3];
query_sum += q[i] + q[i + 1] + q[i + 2] + q[i + 3];
q[i + 1] /= 16.0f;
q[i + 2] /= 256.0f;
q[i + 3] /= 4096.0f;
}
} else if (bits == 8) {
for (int i = 0; i < elem_per_thread; i++) {
q[i] = scale * queries[i];
query_sum += q[i];
}
}
return query_sum;
}
template <typename U, int elem_per_thread, int bits>
METAL_FUNC void load_keys(const device uint32_t* keys, thread U* k) {
if (bits == 4) {
auto ks = (const device uint16_t*)keys;
for (int i = 0; i < elem_per_thread / 4; i++) {
k[4 * i] = ks[i] & 0x000f;
k[4 * i + 1] = ks[i] & 0x00f0;
k[4 * i + 2] = ks[i] & 0x0f00;
k[4 * i + 3] = ks[i] & 0xf000;
}
} else if (bits == 8) {
auto ks = (const device uint8_t*)keys;
for (int i = 0; i < elem_per_thread; i++) {
k[i] = ks[i];
}
}
}
template <typename U, int elem_per_thread, int bits>
METAL_FUNC void load_values(
const device uint32_t* values,
thread U* v,
U value_scale,
U value_bias) {
auto vs = (const device uint8_t*)values;
if (bits == 4) {
U s[2] = {value_scale, value_scale / 16.0f};
for (int i = 0; i < elem_per_thread / 2; i++) {
v[2 * i] = s[0] * (vs[i] & 0x0f) + value_bias;
v[2 * i + 1] = s[1] * (vs[i] & 0xf0) + value_bias;
}
} else if (bits == 8) {
for (int i = 0; i < elem_per_thread; i++) {
v[i] = value_scale * vs[i] + value_bias;
}
}
}
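The removed load_values handles 4-bit data with alternating scales {s, s/16}: the high nibble is masked but not shifted, and the halved scale compensates. A minimal check of that equivalence:

#include <cstdint>
#include <cstdio>

int main() {
  uint8_t vs = 0xB4; // low nibble 4, high nibble B (11)
  float scale = 0.5f, bias = 1.f;
  float s[2] = {scale, scale / 16.f};
  float v0 = s[0] * (vs & 0x0f) + bias; // 0.5 * 4 + 1 = 3
  float v1 = s[1] * (vs & 0xf0) + bias; // (0.5/16) * (11*16) + 1 = 6.5
  std::printf("%f %f\n", v0, v1);      // same as shifting: 0.5*11 + 1 = 6.5
  return 0;
}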
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_1(
const device T* queries [[buffer(0)]],
@@ -351,158 +290,3 @@ template <typename T, int D>
}
}
}
template <typename T, int D, int group_size, int bits>
[[kernel]] void quant_sdpa_vector_2pass_1(
const device T* queries [[buffer(0)]],
const device uint32_t* keys [[buffer(1)]],
const device T* key_scales [[buffer(2)]],
const device T* key_biases [[buffer(3)]],
const device uint32_t* values [[buffer(4)]],
const device T* value_scales [[buffer(5)]],
const device T* value_biases [[buffer(6)]],
device float* out [[buffer(7)]],
device float* sums [[buffer(8)]],
device float* maxs [[buffer(9)]],
const constant int& gqa_factor,
const constant int& N,
const constant size_t& k_stride,
const constant size_t& v_stride,
const constant size_t& k_group_stride,
const constant size_t& v_group_stride,
const constant float& scale,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
constexpr int BN = 8;
constexpr int BD = 4;
constexpr int elem_per_thread = D / BD;
const int stride = BN * D;
constexpr int blocks = 32;
constexpr int pack_factor = 32 / bits;
typedef float U;
thread U q[elem_per_thread];
thread U k[elem_per_thread];
thread U v[elem_per_thread];
thread U o[elem_per_thread];
threadgroup U outputs[BN * BD];
threadgroup U max_scores[BN];
threadgroup U sum_exp_scores[BN];
// Adjust positions
const int block_idx = tid.z;
const int head_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + quad_lid * elem_per_thread;
const int kv_idx =
(block_idx * BN + quad_gid) * D + quad_lid * elem_per_thread;
const int packed_idx = kv_idx / pack_factor;
const int k_group_idx = kv_head_idx * k_group_stride + kv_idx / group_size;
const int v_group_idx = kv_head_idx * v_group_stride + kv_idx / group_size;
keys += kv_head_idx * k_stride + packed_idx;
key_scales += k_group_idx;
key_biases += k_group_idx;
values += kv_head_idx * v_stride + packed_idx;
value_scales += v_group_idx;
value_biases += v_group_idx;
out += head_idx * blocks * D + block_idx * D + quad_lid * elem_per_thread;
sums += head_idx * blocks + block_idx;
maxs += head_idx * blocks + block_idx;
// Read the query and zero the output accumulator
U query_sum = load_queries<T, U, elem_per_thread, bits>(
queries, q, static_cast<U>(scale));
for (int i = 0; i < elem_per_thread; i++) {
o[i] = 0;
}
U max_score = -1e9;
U sum_exp_score = 0;
// For each key
for (int i = block_idx * BN + quad_gid; i < N; i += blocks * BN) {
// Read the key
load_keys<U, elem_per_thread, bits>(keys, k);
// Assume D % group_size == 0 so all the keys are in the same group
U key_scale = key_scales[0];
U key_bias = key_biases[0];
// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = score * key_scale + query_sum * key_bias;
score = quad_sum(score);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
U value_scale = value_scales[0];
U value_bias = value_biases[0];
load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias);
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * v[i];
}
// Move the pointers to the next kv
keys += blocks * stride / pack_factor;
key_scales += blocks * stride / group_size;
key_biases += blocks * stride / group_size;
values += blocks * stride / pack_factor;
value_scales += blocks * stride / group_size;
value_biases += blocks * stride / group_size;
}
// Each thread holds a part of the output, so we need to combine them.
// First let's communicate the max and sum_exp
if (quad_lid == 0) {
max_scores[quad_gid] = max_score;
sum_exp_scores[quad_gid] = sum_exp_score;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
U new_max = simd_max(max_score);
U factor = fast::exp(max_score - new_max);
sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
sum_exp_score = simd_sum(sum_exp_score * factor);
// Write the sum and new max
if (simd_gid == 0) {
sums[0] = sum_exp_score;
maxs[0] = new_max;
}
// Now we need to aggregate all the outputs
for (int i = 0; i < elem_per_thread; i++) {
outputs[quad_lid * BN + quad_gid] =
o[i] * fast::exp(max_scores[quad_gid] - new_max);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (quad_gid == 0) {
U output = outputs[quad_lid * BN];
for (int j = 1; j < BN; j++) {
output += outputs[quad_lid * BN + j];
}
out[i] = static_cast<T>(output);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
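
Two identities carry this kernel. First, with affine quantization k_true = key_scale * k_q + key_bias, the dot product needs no per-element dequantization: q . k_true = key_scale * (q . k_q) + key_bias * sum(q), which is why load_queries returns query_sum and the kernel computes score = score * key_scale + query_sum * key_bias. Second, the running max / factor / exp_score updates are the standard streaming (online) softmax; in isolation, as a host-side sketch:

#include <algorithm>
#include <cmath>
#include <vector>

// One streaming-softmax step: fold a new (score, value) pair into the
// running max, normalizer and output accumulator. The caller divides o
// by sum_exp once, after the last step (here, in pass 2).
void online_softmax_step(
    float score,
    const std::vector<float>& v,
    float& max_score,
    float& sum_exp,
    std::vector<float>& o) {
  float new_max = std::max(max_score, score);
  float factor = std::exp(max_score - new_max);  // rescales the old state
  float exp_score = std::exp(score - new_max);   // weight of the new value
  max_score = new_max;
  sum_exp = sum_exp * factor + exp_score;
  for (size_t i = 0; i < o.size(); ++i) {
    o[i] = o[i] * factor + exp_score * v[i];
  }
}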


@@ -343,8 +343,8 @@ template <
const constant int& out_stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const constant int* nc_shape [[buffer(6)]],
const constant size_t* in_nc_strides [[buffer(7)]],
const constant size_t* out_nc_strides [[buffer(8)]],
const constant int64_t* in_nc_strides [[buffer(7)]],
const constant int64_t* out_nc_strides [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel =
@@ -486,7 +486,7 @@ template <
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const constant int* nc_shape [[buffer(6)]],
const constant size_t* nc_strides [[buffer(7)]],
const constant int64_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<


@@ -26,10 +26,10 @@ struct AttnParams {
int NQ_aligned; ///< Number of full query blocks
int NK_aligned; ///< Number of full key/value blocks
size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
size_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
size_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
size_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
};
} // namespace steel
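
This size_t to int64_t switch runs through every params struct and kernel signature in the diff. Beyond matching the host-side Strides type, signed 64-bit strides are what make no-copy reversed views expressible; a minimal illustration of what an unsigned stride cannot represent:

#include <cstdint>
#include <cstdio>

int main() {
  int data[5] = {10, 11, 12, 13, 14};
  int64_t stride = -1;         // not representable as a size_t stride
  const int* base = data + 4;  // offset points at the last element
  for (int i = 0; i < 5; ++i) {
    std::printf("%d ", base[i * stride]);  // prints: 14 13 12 11 10
  }
  std::printf("\n");
  return 0;
}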


@@ -14,9 +14,9 @@ struct MLXConvParams {
const int pad[NDIM]; // Input padding
const int kdil[NDIM]; // Kernel dilation
const int idil[NDIM]; // Input dilation
const size_t in_strides[NDIM + 2]; // In strides
const size_t wt_strides[NDIM + 2]; // Wt strides
const size_t out_strides[NDIM + 2]; // Out strides
const int64_t in_strides[NDIM + 2]; // In strides
const int64_t wt_strides[NDIM + 2]; // Wt strides
const int64_t out_strides[NDIM + 2]; // Out strides
const int groups; // Input channel groups
const bool flip;
};
@@ -59,4 +59,4 @@ struct Conv2DGeneralBaseInfo {
};
} // namespace steel
} // namespace mlx
} // namespace mlx


@@ -38,12 +38,12 @@ template <
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant int64_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
@@ -88,9 +88,8 @@ template <
uint32_t indx_A, indx_B, indx_C;
if (has_batch) {
const constant size_t* indx_A_bstrides = batch_strides;
const constant size_t* indx_B_bstrides =
batch_strides + params->batch_ndim;
const constant auto* indx_A_bstrides = batch_strides;
const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim;
ulong2 indx_offsets = elem_to_loc_broadcast(
tid.z,
@@ -102,7 +101,7 @@ template <
indx_B = rhs_indices[indx_offsets.y];
if (use_out_source) {
const constant size_t* indx_C_bstrides =
const constant auto* indx_C_bstrides =
indx_B_bstrides + params->batch_ndim;
auto indx_offset_C = elem_to_loc(
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
@@ -120,18 +119,18 @@ template <
// Translate indices to offsets
int batch_ndim_A = operand_batch_ndim.x;
const constant int* batch_shape_A = operand_shape;
const constant size_t* batch_strides_A = operand_strides;
const constant auto* batch_strides_A = operand_strides;
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
int batch_ndim_B = operand_batch_ndim.y;
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A;
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
if (use_out_source) {
int batch_ndim_C = operand_batch_ndim.z;
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B;
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
}
@@ -140,8 +139,8 @@ template <
// Handle regular batch
else {
if (has_batch) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
const constant auto* A_bstrides = batch_strides;
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
@@ -150,7 +149,7 @@ template <
B += batch_offsets.y;
if (use_out_source) {
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
}
} else {


@@ -7,26 +7,10 @@
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h"
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
template [[host_name("steel_gemm_fused_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
const device itype *C [[buffer(2), function_constant(use_out_source)]], \
device itype *D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], \
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], \
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], \
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], \
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], \
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
instantiate_kernel( \
"steel_gemm_fused_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \
gemm, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float)
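
Every instantiation file in this diff shrinks the same way: the explicit template [[host_name(...)]] [[kernel]] stanza, which re-listed every buffer argument, collapses into one instantiate_kernel call. A macro along these lines makes that work by letting decltype recover the full kernel signature (a sketch of the idea; the real definition lives in the kernels' shared utils header):

// Hypothetical simplified form of instantiate_kernel:
#define instantiate_kernel(name, func, ...) \
  template [[host_name(name)]] [[kernel]]   \
  decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;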
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \


@@ -56,7 +56,7 @@ block_masked_gemm(
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant int64_t* batch_strides [[buffer(7)]],
const device out_mask_t* out_mask [[buffer(10)]],
const device op_mask_t* lhs_mask [[buffer(11)]],
const device op_mask_t* rhs_mask [[buffer(12)]],
@@ -104,7 +104,7 @@ block_masked_gemm(
return;
}
const constant size_t* mask_batch_strides =
const constant auto* mask_batch_strides =
batch_strides + 2 * params->batch_ndim;
if (params->batch_ndim > 1) {
@@ -116,8 +116,8 @@ block_masked_gemm(
}
if (has_operand_mask) {
const constant size_t* mask_strides_lhs = mask_batch_strides;
const constant size_t* mask_strides_rhs =
const constant auto* mask_strides_lhs = mask_batch_strides;
const constant auto* mask_strides_rhs =
mask_strides_lhs + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
@@ -144,8 +144,8 @@ block_masked_gemm(
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
const constant auto* A_bstrides = batch_strides;
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
@@ -442,7 +442,7 @@ block_masked_gemm(
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant int64_t* batch_strides [[buffer(7)]],
const device bool* out_mask [[buffer(10)]],
const device bool* lhs_mask [[buffer(11)]],
const device bool* rhs_mask [[buffer(12)]],
@@ -476,15 +476,15 @@ block_masked_gemm(
}
if (params->batch_ndim > 1) {
const constant size_t* mask_batch_strides =
const constant auto* mask_batch_strides =
batch_strides + 2 * params->batch_ndim;
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
if (has_operand_mask) {
const constant size_t* mask_strides_lhs =
const constant auto* mask_strides_lhs =
mask_batch_strides + params->batch_ndim;
const constant size_t* mask_strides_rhs =
const constant auto* mask_strides_rhs =
mask_strides_lhs + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
@@ -507,8 +507,8 @@ block_masked_gemm(
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
const constant auto* A_bstrides = batch_strides;
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);


@@ -5,58 +5,45 @@
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"
#define instantiate_gemm( \
outmaskname, \
outmasktype, \
opmaskname, \
opmasktype, \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
template [[host_name("steel_gemm_block_outmask_" #outmaskname \
"_opmask_" #opmaskname "_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
block_masked_gemm< \
itype, \
outmasktype, \
opmasktype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const device outmasktype* out_mask [[buffer(10)]], \
const device opmasktype* lhs_mask [[buffer(11)]], \
const device opmasktype* rhs_mask [[buffer(12)]], \
const constant int* mask_strides [[buffer(13)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm( \
outmaskname, \
outmasktype, \
opmaskname, \
opmasktype, \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
instantiate_kernel( \
"steel_gemm_block_outmask_" #outmaskname \
"_opmask_" #opmaskname "_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname, \
block_masked_gemm, \
itype, \
outmasktype, \
opmasktype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned)
#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(bool_, bool, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \


@@ -5,46 +5,39 @@
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h"
#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
gemm_splitk< \
itype, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device otype* C [[buffer(2)]], \
const constant GEMMSpiltKParams* params [[buffer(3)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
instantiate_kernel( \
"steel_gemm_splitk_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname, \
gemm_splitk, \
itype, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned)
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
@@ -68,30 +61,13 @@ instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float);
#define instantiate_accum(oname, otype, aname, atype) \
template [[host_name("steel_gemm_splitk_accum_" #oname \
"_" #aname)]] [[kernel]] void \
gemm_splitk_accum<atype, otype>( \
const device atype* C_split [[buffer(0)]], \
device otype* D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
uint2 gid [[thread_position_in_grid]]); \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \
"_axbpy")]] [[kernel]] void \
gemm_splitk_accum_axpby<atype, otype>( \
const device atype* C_split [[buffer(0)]], \
device otype* D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
const device otype* C [[buffer(5)]], \
const constant int& ldc [[buffer(6)]], \
const constant int& fdc [[buffer(7)]], \
const constant float& alpha [[buffer(8)]], \
const constant float& beta [[buffer(9)]], \
uint2 gid [[thread_position_in_grid]]);
#define instantiate_accum(oname, otype, aname, atype) \
instantiate_kernel( \
"steel_gemm_splitk_accum_" #oname "_" #aname, \
gemm_splitk_accum, atype, otype) \
instantiate_kernel( \
"steel_gemm_splitk_accum_" #oname "_" #aname "_axbpy", \
gemm_splitk_accum_axpby, atype, otype) \
instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);


@@ -21,9 +21,9 @@ struct GEMMParams {
const int tiles_n;
const int tiles_m;
const size_t batch_stride_a;
const size_t batch_stride_b;
const size_t batch_stride_d;
const int64_t batch_stride_a;
const int64_t batch_stride_b;
const int64_t batch_stride_d;
const int swizzle_log;
const int gemm_k_iterations_aligned;
@@ -54,7 +54,7 @@ struct GEMMAddMMParams {
const int ldc;
const int fdc;
const size_t batch_stride_c;
const int64_t batch_stride_c;
const float alpha;
const float beta;


@@ -7,8 +7,8 @@
METAL_FUNC ulong2 elem_to_loc_broadcast(
uint elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
int ndim) {
ulong loc_a{0};
ulong loc_b{0};
@@ -24,9 +24,9 @@ METAL_FUNC ulong2 elem_to_loc_broadcast(
METAL_FUNC ulong3 elem_to_loc_broadcast(
uint elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
constant const int64_t* c_strides,
int ndim) {
ulong loc_a{0};
ulong loc_b{0};
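
elem_to_loc_broadcast walks one logical shape while accumulating a separate offset per array, so arrays that were broadcast together (stride 0 on the expanded axes) can be indexed in lockstep. A host-side sketch of the two-array version:

#include <cstdint>
#include <utility>

std::pair<uint64_t, uint64_t> elem_to_loc_broadcast(
    uint32_t elem,
    const int* shape,
    const int64_t* a_strides,
    const int64_t* b_strides,
    int ndim) {
  uint64_t loc_a = 0;
  uint64_t loc_b = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    int l = elem % shape[i];
    loc_a += l * a_strides[i];  // stride 0 on broadcast axes
    loc_b += l * b_strides[i];
    elem /= shape[i];
  }
  return {loc_a, loc_b};
}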


@@ -18,72 +18,72 @@ template <typename T, typename Op>
device T* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
d[offset] = Op()(a[offset], b[offset], c[offset]);
}
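
The contiguous v2 variants compute their flat offset in 64-bit because Metal grid coordinates are 32-bit: splitting the index across x and y and widening to int64_t addresses arrays past 2^32 elements. For instance:

#include <cassert>
#include <cstdint>

int main() {
  uint32_t x = 7, y = 3, grid_x = 1u << 31;
  int64_t offset = x + grid_x * int64_t(y);  // 3 * 2^31 + 7
  assert(offset == 6442450951LL);            // would wrap in 32 bits
  return 0;
}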
template <typename T, typename Op, typename IdxT = size_t>
template <typename T, typename Op, typename IdxT = int64_t>
[[kernel]] void ternary_g_nd1(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t& a_strides,
constant const size_t& b_strides,
constant const size_t& c_strides,
constant const int64_t& a_strides,
constant const int64_t& b_strides,
constant const int64_t& c_strides,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_1<size_t, IdxT>(index, c_strides);
auto a_idx = elem_to_loc_1<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_1<IdxT>(index, b_strides);
auto c_idx = elem_to_loc_1<IdxT>(index, c_strides);
d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op, typename IdxT = size_t>
template <typename T, typename Op, typename IdxT = int64_t>
[[kernel]] void ternary_g_nd2(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
constant const size_t c_strides[2],
constant const int64_t a_strides[2],
constant const int64_t b_strides[2],
constant const int64_t c_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_2<size_t, IdxT>(index, c_strides);
auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);
auto c_idx = elem_to_loc_2<IdxT>(index, c_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op, typename IdxT = size_t>
template <typename T, typename Op, typename IdxT = int64_t>
[[kernel]] void ternary_g_nd3(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
constant const size_t c_strides[3],
constant const int64_t a_strides[3],
constant const int64_t b_strides[3],
constant const int64_t c_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_3<size_t, IdxT>(index, c_strides);
auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);
auto c_idx = elem_to_loc_3<IdxT>(index, c_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op, int N = 1, typename IdxT = size_t>
template <typename T, typename Op, int N = 1, typename IdxT = int64_t>
[[kernel]] void ternary_g(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
constant const int64_t* c_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {


@@ -8,17 +8,17 @@
#include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h"
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, uint) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \
instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, int) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, int) \
instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
#define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \
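
The index-type arguments at the end of each instantiate_kernel call are the point of this hunk: the fixed-dimension kernels (g1/g2/g3, gn2) now index with plain int, while the *large variants keep the int64_t default, so small arrays get cheaper 32-bit address arithmetic. A hypothetical host-side name picker showing the intended split (the helper and its threshold are illustrative, not MLX's dispatch code):

#include <cstdint>
#include <string>

// Hypothetical helper: pick the kernel-name suffix for the 1-3 dim
// general ternary kernels based on whether 32-bit offsets suffice.
std::string ternary_kernel_name(const std::string& op, int64_t nelem, int ndim) {
  bool large = nelem > INT32_MAX;  // illustrative threshold
  return "g" + std::to_string(ndim) + (large ? "large_" : "_") + op;
}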


@@ -14,7 +14,7 @@ template <typename T, typename U, typename Op>
device U* out,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
out[offset] = Op()(in[offset]);
}
@@ -23,16 +23,16 @@ template <
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
typename IdxT = int64_t>
[[kernel]] void unary_g(
device const T* in,
device U* out,
constant const int* in_shape,
constant const size_t* in_strides,
constant const int64_t* in_strides,
device const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc<size_t, IdxT>(
auto idx = elem_to_loc<IdxT>(
{N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
auto xshape = in_shape[ndim - 1];
IdxT xstride = in_strides[ndim - 1];


@@ -9,19 +9,19 @@
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel( \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, uint) \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
instantiate_kernel( \
"gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
#define instantiate_unary_all_same(op, tname, type) \
instantiate_unary_all(op, tname, tname, type, type)
#define instantiate_unary_float(op) \
#define instantiate_unary_float(op) \
instantiate_unary_all_same(op, float16, half) \
instantiate_unary_all_same(op, float32, float) \
instantiate_unary_all_same(op, bfloat16, bfloat16_t)
#define instantiate_unary_types(op) \
#define instantiate_unary_types(op) \
instantiate_unary_all_same(op, bool_, bool) \
instantiate_unary_all_same(op, uint8, uint8_t) \
instantiate_unary_all_same(op, uint16, uint16_t) \


@@ -89,25 +89,11 @@ struct Limits<complex64_t> {
///////////////////////////////////////////////////////////////////////////////
// Single Array with generic dims
template <typename StrideT, typename IdxT = StrideT>
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
uint elem,
IdxT elem,
constant const int* shape,
constant const StrideT* strides,
int ndim) {
IdxT loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc(
StrideT elem,
constant const int* shape,
constant const StrideT* strides,
constant const int64_t* strides,
int ndim) {
IdxT loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
@@ -118,11 +104,11 @@ METAL_FUNC IdxT elem_to_loc(
}
// Non templated version to handle arbitrary dims
template <typename StrideT, typename IdxT = StrideT>
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
uint3 elem,
constant const int* shape,
constant const StrideT* strides,
constant const int64_t* strides,
int ndim) {
IdxT loc =
elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
@@ -136,18 +122,18 @@ METAL_FUNC IdxT elem_to_loc(
///////////////////////////////////////////////////////////////////////////////
// Single Array with fixed N dims
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const StrideT& stride) {
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t& stride) {
return elem * IdxT(stride);
}
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const StrideT strides[2]) {
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2]) {
return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
}
template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) {
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) {
return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
elem.z * IdxT(strides[0]);
}
@@ -155,12 +141,12 @@ METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) {
///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with generic dims
template <typename StrideT, typename IdxT = StrideT>
template <typename IdxT = int64_t>
METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,
constant const StrideT* a_strides,
constant const StrideT* b_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
int ndim) {
vec<IdxT, 2> loc = {
IdxT(
@@ -178,18 +164,21 @@ METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
return loc;
}
template <typename IdxT = size_t>
template <typename IdxT = int64_t>
METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
constant const int64_t* c_strides,
int ndim) {
vec<IdxT, 3> loc = {
elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]),
elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])};
IdxT(elem.x * IdxT(a_strides[ndim - 1])) +
IdxT(elem.y * IdxT(a_strides[ndim - 2])),
IdxT(elem.x * IdxT(b_strides[ndim - 1])) +
IdxT(elem.y * IdxT(b_strides[ndim - 2])),
IdxT(elem.x * IdxT(c_strides[ndim - 1])) +
IdxT(elem.y * IdxT(c_strides[ndim - 2]))};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * IdxT(a_strides[d]);
@@ -213,7 +202,7 @@ struct LoopedElemToLoc {
LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
void next(const constant int* shape, const constant size_t* strides) {
void next(const constant int* shape, const constant int64_t* strides) {
if (dim == 0) {
return;
}
@@ -226,7 +215,7 @@ struct LoopedElemToLoc {
}
}
void next(int n, const constant int* shape, const constant size_t* strides) {
void next(int n, const constant int* shape, const constant int64_t* strides) {
if (dim == 0) {
return;
}
@@ -262,19 +251,19 @@ struct LoopedElemToLoc<1, OffsetT, true> {
LoopedElemToLoc(int dim) : dim(dim) {}
void next(const constant int* shape, const constant size_t* strides) {
void next(const constant int* shape, const constant int64_t* strides) {
index++;
if (dim > 1) {
offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
} else {
offset += OffsetT(strides[0]);
}
}
void next(int n, const constant int* shape, const constant size_t* strides) {
void next(int n, const constant int* shape, const constant int64_t* strides) {
index += n;
if (dim > 1) {
offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
} else {
offset = index * OffsetT(strides[0]);
}
@@ -291,11 +280,11 @@ struct LoopedElemToLoc<1, OffsetT, false> {
LoopedElemToLoc(int) {}
void next(const constant int*, const constant size_t* strides) {
void next(const constant int*, const constant int64_t* strides) {
offset += OffsetT(strides[0]);
}
void next(int n, const constant int*, const constant size_t* strides) {
void next(int n, const constant int*, const constant int64_t* strides) {
offset += n * OffsetT(strides[0]);
}
@@ -421,3 +410,14 @@ inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
return complex64_t(
simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));
}
// std::conditional is not included with Metal
template <bool condition, typename T, typename U>
struct ConditionalType {
using type = U;
};
template <typename T, typename U>
struct ConditionalType<true, T, U> {
using type = T;
};
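
ConditionalType is a hand-rolled std::conditional, needed because Metal ships no <type_traits>. A plausible use, matching the int-versus-int64_t kernel variants above: select the index width at instantiation time.

// Sketch: 32-bit indices for small inputs, 64-bit otherwise.
template <bool large>
using IdxT = typename ConditionalType<large, int64_t, int>::type;

static_assert(sizeof(IdxT<false>) == sizeof(int), "small variant");
static_assert(sizeof(IdxT<true>) == sizeof(int64_t), "large variant");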


@@ -21,8 +21,8 @@ namespace {
inline auto collapse_batches(const array& a, const array& b) {
// Get and check the shape for the batched dims
std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
if (A_bshape != B_bshape) {
std::ostringstream msg;
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
@@ -30,8 +30,8 @@ inline auto collapse_batches(const array& a, const array& b) {
throw std::runtime_error(msg.str());
}
std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
auto [batch_shape, batch_strides] =
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
@@ -50,9 +50,9 @@ inline auto collapse_batches(const array& a, const array& b) {
inline auto collapse_batches(const array& a, const array& b, const array& c) {
// Get and check the shape for the batched dims
std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
std::vector<int> C_bshape{c.shape().begin(), c.shape().end() - 2};
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
if (A_bshape != B_bshape || A_bshape != C_bshape) {
std::ostringstream msg;
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
@@ -60,9 +60,9 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
throw std::runtime_error(msg.str());
}
std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
std::vector<size_t> C_bstride{c.strides().begin(), c.strides().end() - 2};
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
@@ -82,6 +82,25 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
}
std::tuple<bool, int64_t, array> check_transpose(
std::vector<array>& copies,
const Stream& s,
const array& arr,
bool is_vector) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy);
}
};
} // namespace
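
The lambda that Matmul, AddMM, BlockMaskedMM and GatherMM each carried is hoisted into this single free function (with std::vector<int> and std::vector<size_t> giving way to the Shape and Strides aliases). The decision it encodes, stripped of the copying (a sketch):

#include <cstdint>
#include <tuple>

// stx/sty: strides of the last two axes; rows/cols: shape of those axes.
// Returns {treat as transposed, leading dimension}; the fallthrough case
// means the caller must first materialize a row-major copy.
std::tuple<bool, int64_t> gemm_layout(
    int64_t stx, int64_t sty, int rows, int cols, bool is_vector) {
  if (sty == 1 && (!is_vector || stx == cols)) {
    return {false, stx};  // row-contiguous: use as-is
  }
  if (stx == 1 && (!is_vector || sty == rows)) {
    return {true, sty};  // column-contiguous: flag as transposed
  }
  return {false, cols};  // neither: copy, then ld = cols
}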
///////////////////////////////////////////////////////////////////////////////
@@ -180,11 +199,11 @@ void steel_matmul_regular(
int ldd,
bool transpose_a,
bool transpose_b,
std::vector<int> batch_shape,
std::vector<size_t> batch_strides,
size_t A_batch_stride,
size_t B_batch_stride,
size_t matrix_stride_out,
Shape batch_shape,
Strides batch_strides,
int64_t A_batch_stride,
int64_t B_batch_stride,
int64_t matrix_stride_out,
std::vector<array>& copies) {
using namespace mlx::steel;
@@ -268,9 +287,9 @@ void steel_matmul_regular(
/* const int ldd = */ ldd,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const size_t batch_stride_a = */ A_batch_stride,
/* const size_t batch_stride_b = */ B_batch_stride,
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int64_t batch_stride_a = */ A_batch_stride,
/* const int64_t batch_stride_b = */ B_batch_stride,
/* const int64_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -314,9 +333,9 @@ void steel_matmul(
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
std::vector<int> batch_shape /* = {} */,
std::vector<size_t> A_batch_stride /* = {} */,
std::vector<size_t> B_batch_stride /* = {} */) {
Shape batch_shape /* = {} */,
Strides A_batch_stride /* = {} */,
Strides B_batch_stride /* = {} */) {
using namespace mlx::steel;
if (batch_shape.empty()) {
@@ -447,7 +466,7 @@ void steel_matmul(
/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
std::vector<size_t> batch_strides = A_batch_stride;
auto batch_strides = A_batch_stride;
batch_strides.insert(
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
@@ -505,24 +524,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [a_transposed, a_cols, a] = check_transpose(a_pre, M == 1);
auto [b_transposed, b_cols, b] = check_transpose(b_pre, N == 1);
auto [a_transposed, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
auto [b_transposed, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
@@ -662,9 +665,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
/* bool transpose_a = */ a_transposed,
/* bool transpose_b = */ b_transposed,
/* std::vector<array>& = */ copies,
/* std::vector<int> batch_shape = */ batch_shape,
/* std::vector<size_t> A_batch_stride = */ A_batch_stride,
/* std::vector<size_t> B_batch_stride = */ B_batch_stride);
/* Shape batch_shape = */ batch_shape,
/* Strides A_batch_stride = */ A_batch_stride,
/* Strides B_batch_stride = */ B_batch_stride);
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -691,24 +694,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
array c = c_pre;
int ldc = c.strides()[c.ndim() - 2];
@@ -723,7 +710,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =
collapse_batches(a, b, c);
size_t matrix_stride_out = size_t(M) * size_t(N);
int64_t matrix_stride_out = M * static_cast<int64_t>(N);
auto batch_size_out = out.size() / (matrix_stride_out);
// Collapse batches into M if needed
@@ -1044,9 +1031,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const size_t batch_stride_a = */ A_batch_stride.back(),
/* const size_t batch_stride_b = */ B_batch_stride.back(),
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int64_t batch_stride_a = */ A_batch_stride.back(),
/* const int64_t batch_stride_b = */ B_batch_stride.back(),
/* const int64_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -1054,7 +1041,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
GEMMAddMMParams params{
/* const int ldc = */ ldc,
/* const int fdc = */ fdc,
/* const size_t batch_stride_c = */ C_batch_stride.back(),
/* const int64_t batch_stride_c = */ C_batch_stride.back(),
/* const float alpha = */ alpha_,
/* const float beta = */ beta_};
@@ -1065,7 +1052,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
std::vector<size_t> batch_strides = A_batch_stride;
Strides batch_strides = A_batch_stride;
batch_strides.insert(
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
batch_strides.insert(
@@ -1120,24 +1107,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
int lda = a_cols;
int ldb = b_cols;
@@ -1156,20 +1127,20 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
return decltype(v){v.begin(), v.end() - 2};
};
std::vector<int> batch_shape{1};
std::vector<size_t> A_batch_stride{0};
std::vector<size_t> B_batch_stride{0};
std::vector<size_t> outmask_bstride{0};
std::vector<size_t> Amask_bstride{0};
std::vector<size_t> Bmask_bstride{0};
size_t A_batch_str = 0;
size_t B_batch_str = 0;
Shape batch_shape{1};
Strides A_batch_stride{0};
Strides B_batch_stride{0};
Strides outmask_bstride{0};
Strides Amask_bstride{0};
Strides Bmask_bstride{0};
int64_t A_batch_str = 0;
int64_t B_batch_str = 0;
std::vector<size_t> batch_strides;
Strides batch_strides;
if (out.ndim() > 2) {
std::vector<int> bshape{out.shape().begin(), out.shape().end() - 2};
std::vector<std::vector<size_t>> bstrides;
Shape bshape{out.shape().begin(), out.shape().end() - 2};
std::vector<Strides> bstrides;
for (auto& arr : inputs) {
bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2);
@@ -1196,10 +1167,10 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
}
} else {
batch_strides = std::vector<size_t>(inputs.size(), 0);
batch_strides = Strides(inputs.size(), 0);
}
size_t matrix_stride_out = size_t(M) * N;
int64_t matrix_stride_out = static_cast<int64_t>(M) * N;
size_t batch_size_out = out.size() / (matrix_stride_out);
/////////////////////////////////////////////////////////////////////////////
@@ -1306,7 +1277,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Get mask params
std::vector<int> mask_strides;
std::vector<size_t> mask_batch_strides;
Strides mask_batch_strides;
if (has_out_mask) {
auto& out_mask = inputs[2];
@@ -1436,9 +1407,9 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const size_t batch_stride_a = */ A_batch_str,
/* const size_t batch_stride_b = */ B_batch_str,
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int64_t batch_stride_a = */ A_batch_str,
/* const int64_t batch_stride_b = */ B_batch_str,
/* const int64_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -1524,24 +1495,8 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
int lda = a_cols;
int ldb = b_cols;
@@ -1556,20 +1511,20 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
std::vector<int> batch_shape = get_batch_dims(out.shape());
std::vector<size_t> batch_strides;
Shape batch_shape = get_batch_dims(out.shape());
Strides batch_strides;
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
size_t lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
batch_strides.insert(
batch_strides.end(),
rhs_indices.strides().begin(),
rhs_indices.strides().end());
size_t rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
int batch_ndim = batch_shape.size();
@@ -1582,10 +1537,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int batch_ndim_B = b.ndim() - 2;
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B};
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
Shape batch_shape_A = get_batch_dims(a.shape());
Strides batch_strides_A = get_batch_dims(a.strides());
Shape batch_shape_B = get_batch_dims(b.shape());
Strides batch_strides_B = get_batch_dims(b.strides());
if (batch_ndim_A == 0) {
batch_shape_A = {1};
@@ -1597,7 +1552,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_strides_B = {0};
}
size_t matrix_stride_out = size_t(M) * N;
auto matrix_stride_out = static_cast<int64_t>(M) * N;
auto batch_size_out = out.size() / matrix_stride_out;
/////////////////////////////////////////////////////////////////////////////
@@ -1801,9 +1756,9 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const size_t batch_stride_a = */ lhs_indices_str,
/* const size_t batch_stride_b = */ rhs_indices_str,
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int64_t batch_stride_a = */ lhs_indices_str,
/* const int64_t batch_stride_b = */ rhs_indices_str,
/* const int64_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ batch_ndim};


@@ -21,11 +21,11 @@ void steel_matmul_regular(
int ldd,
bool transpose_a,
bool transpose_b,
std::vector<int> batch_shape,
std::vector<size_t> batch_strides,
size_t A_batch_stride,
size_t B_batch_stride,
size_t matrix_stride_out,
Shape batch_shape,
Strides batch_strides,
int64_t A_batch_stride,
int64_t B_batch_stride,
int64_t matrix_stride_out,
std::vector<array>& copies);
void steel_matmul(
@@ -43,8 +43,8 @@ void steel_matmul(
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
std::vector<int> batch_shape = {},
std::vector<size_t> A_batch_stride = {},
std::vector<size_t> B_batch_stride = {});
Shape batch_shape = {},
Strides A_batch_stride = {},
Strides B_batch_stride = {});
} // namespace mlx::core


@@ -5,6 +5,7 @@
#include <sstream>
#include "mlx/backend/common/load.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
@@ -24,6 +25,25 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
enc.set_bytes(step, 1);
}
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
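
This helper gives Reshape, Flatten and Unflatten one shared path: reuse the buffer when prepare_reshape finds compatible strides, otherwise copy into a freshly allocated row-major layout. make_contiguous_strides, now untemplated since Strides is fixed to int64_t, plausibly amounts to:

#include <cstdint>
#include <vector>

// Row-major strides for a shape, e.g. {2, 3, 4} -> {12, 4, 1}.
std::vector<int64_t> make_contiguous_strides(const std::vector<int>& shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 1; i > 0; --i) {
    strides[i - 1] = strides[i] * shape[i];
  }
  return strides;
}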
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -101,10 +121,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
}
// Prepare the shapes, strides and axis arguments.
std::vector<size_t> in_strides = in.strides();
std::vector<int> shape = in.shape();
std::vector<size_t> out_strides = out.strides();
size_t axis_stride = in_strides[axis_];
auto in_strides = in.strides();
auto shape = in.shape();
auto out_strides = out.strides();
auto axis_stride = in_strides[axis_];
size_t axis_size = shape[axis_];
if (out_strides.size() == in_strides.size()) {
out_strides.erase(out_strides.begin() + axis_);
@@ -136,7 +156,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
if (ndim == 0) {
// Pass placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
int64_t stride_ = 0;
compute_encoder.set_bytes(shape_, 2);
compute_encoder.set_bytes(stride_, 3);
compute_encoder.set_bytes(stride_, 4);
@@ -210,6 +230,18 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
copy_gpu(in, out, ctype);
}
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto read_task = [out = out,
@@ -304,27 +336,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
}
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto out_strides = make_contiguous_strides<size_t>(in.shape());
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
out_strides,
0,
0,
CopyType::General,
stream());
} else {
shared_buffer_reshape(in, out_strides, out);
}
reshape(inputs[0], out, stream());
}
void Split::eval_gpu(
@@ -366,22 +378,25 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(out);
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
// Do copy
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
copy_gpu_inplace<int64_t>(
copy_gpu_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd_strides,
/* const std::vector<stride_t>& o_strides = */ out_strides,
/* const Shape& data_shape = */ upd.shape(),
/* const Strides& i_strides = */ upd.strides(),
/* const Strides& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral,
/* const Stream& s = */ stream());
}
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}

Some files were not shown because too many files have changed in this diff.