Mirror of https://github.com/ml-explore/mlx.git, synced 2025-09-09 08:34:18 +08:00

Compare commits: q-sdpa...packed-qua (56 commits)
Commits in this comparison (author and date columns were empty in this view):

c02e14c264
d75a509234
14420949d2
4847199ec6
fb7be036af
410ccdbed5
f5da489a3c
c2e6d58441
17a1fa2f0b
fd161aa31f
bf6dc54110
d7ed624502
05cb54ae3f
cb358dbdda
e4b587819c
a06c968f4d
651c510940
11ec07ff9d
bdd68bd893
50f3535693
9111999af3
6bd28d246e
4d595a2a39
3a21f61772
4e1e9520e1
0bf19037ca
f3dfa36a3a
4f9b60dd53
f76a49e555
310ad8d9db
56db268f47
92ab6bdeb8
0070e360a1
9df8fed046
a59fae040f
29a620cab2
87d7a2520e
40c62c1321
35b412c099
d0f471cff7
6f316b8bf5
7c10c93a1f
d92ea094f1
6ae5423b4a
9635cffdc8
96986fb362
3ceb341a75
50fa705125
69a2991614
fd3377dd1f
d0b6cb0425
95c4a2e3af
bc2a29f033
3bb5b4a302
fc88fd9097
c5b0928c1f
.gitignore (vendored): 3 additions
```diff
@@ -76,6 +76,9 @@ build/
 *.out
 *.app
 
+# Debug symbols
+*.pdb
+
 # VSCode
 .vscode/
 .DS_Store
```
```diff
@@ -1,13 +1,14 @@
 repos:
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v18.1.8
+    rev: v19.1.4
     hooks:
       - id: clang-format
+  # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 24.8.0
+    rev: 24.10.0
     hooks:
       - id: black
 
   - repo: https://github.com/pycqa/isort
     rev: 5.13.2
     hooks:
```
```diff
@@ -20,11 +20,12 @@ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
 option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
 option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
 option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
+option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
 option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
 
 if(NOT MLX_VERSION)
-  set(MLX_VERSION 0.21.0)
+  set(MLX_VERSION 0.21.1)
 endif()
 
 # --------------------- Processor tests -------------------------
@@ -93,8 +94,7 @@ elseif(MLX_BUILD_METAL)
   message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
 
   set(METAL_CPP_URL
-      https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
-  )
+      https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
 
   if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
     set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
@@ -113,16 +113,52 @@ elseif(MLX_BUILD_METAL)
   target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
 endif()
 
+if(WIN32)
+  if(MSVC)
+    # GGUF does not build with MSVC.
+    set(MLX_BUILD_GGUF OFF)
+    # There is no prebuilt OpenBLAS distribution for MSVC.
+    set(MLX_BUILD_BLAS_FROM_SOURCE ON)
+  endif()
+  # Windows implementation of dlfcn.h APIs.
+  FetchContent_Declare(
+    dlfcn-win32
+    GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
+    GIT_TAG v1.4.1
+    EXCLUDE_FROM_ALL)
+  FetchContent_MakeAvailable(dlfcn-win32)
+  target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
+  target_link_libraries(mlx PRIVATE dl)
+endif()
+
 if(MLX_BUILD_CPU)
   find_library(ACCELERATE_LIBRARY Accelerate)
   if(ACCELERATE_LIBRARY)
     message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
     set(MLX_BUILD_ACCELERATE ON)
-    target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
-    add_compile_definitions(ACCELERATE_NEW_LAPACK)
   else()
     message(STATUS "Accelerate or arm neon not found, using default backend.")
     set(MLX_BUILD_ACCELERATE OFF)
+  endif()
+
+  if(MLX_BUILD_ACCELERATE)
+    target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
+    add_compile_definitions(ACCELERATE_NEW_LAPACK)
+  elseif(MLX_BUILD_BLAS_FROM_SOURCE)
+    # Download and build OpenBLAS from source code.
+    FetchContent_Declare(
+      openblas
+      GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
+      GIT_TAG v0.3.28
+      EXCLUDE_FROM_ALL)
+    set(BUILD_STATIC_LIBS ON) # link statically
+    set(NOFORTRAN ON) # msvc has no fortran compiler
+    FetchContent_MakeAvailable(openblas)
+    target_link_libraries(mlx PRIVATE openblas)
+    target_include_directories(
+      mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
+                  "${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
+  else()
     if(${CMAKE_HOST_APPLE})
       # The blas shipped in macOS SDK is not supported, search homebrew for
       # openblas instead.
@@ -140,7 +176,7 @@ if(MLX_BUILD_CPU)
     message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
     message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
     target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
-    target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
+    target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
     # List blas after lapack otherwise we may accidentally incldue an old
     # version of lapack.h from the include dirs of blas.
     find_package(BLAS REQUIRED)
@@ -153,14 +189,7 @@ if(MLX_BUILD_CPU)
     message(STATUS "Blas lib " ${BLAS_LIBRARIES})
     message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
     target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
-    target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
-
-    if(WIN32)
-      find_package(dlfcn-win32 REQUIRED)
-      message(STATUS "dlfcn-win32 lib " ${dlfcn-win32_LIBRARIES})
-      message(STATUS "dlfcn-win32 include " ${dlfcn-win32_INCLUDE_DIRS})
-      target_link_libraries(mlx PUBLIC ${dlfcn-win32_LIBRARIES})
-    endif()
+    target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
   endif()
 else()
   set(MLX_BUILD_ACCELERATE OFF)
```
```diff
@@ -5,35 +5,35 @@
 #include "mlx/mlx.h"
 #include "time_utils.h"
 
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 void time_value_and_grad() {
-  auto x = ones({200, 1000});
-  eval(x);
-  auto fn = [](array x) {
+  auto x = mx::ones({200, 1000});
+  mx::eval(x);
+  auto fn = [](mx::array x) {
     for (int i = 0; i < 20; ++i) {
-      x = log(exp(x));
+      x = mx::log(mx::exp(x));
     }
-    return sum(x);
+    return mx::sum(x);
   };
 
-  auto grad_fn = grad(fn);
+  auto grad_fn = mx::grad(fn);
   auto independent_value_and_grad = [&]() {
     auto value = fn(x);
     auto dfdx = grad_fn(x);
-    return std::vector<array>{value, dfdx};
+    return std::vector<mx::array>{value, dfdx};
   };
   TIME(independent_value_and_grad);
 
-  auto value_and_grad_fn = value_and_grad(fn);
+  auto value_and_grad_fn = mx::value_and_grad(fn);
   auto combined_value_and_grad = [&]() {
     auto [value, dfdx] = value_and_grad_fn(x);
-    return std::vector<array>{value, dfdx};
+    return std::vector<mx::array>{value, dfdx};
   };
   TIME(combined_value_and_grad);
 }
 
 int main() {
-  std::cout << "Benchmarks for " << default_device() << std::endl;
+  std::cout << "Benchmarks for " << mx::default_device() << std::endl;
   time_value_and_grad();
 }
```
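The recurring change across these C++ benchmarks and examples is swapping the blanket `using namespace mlx::core;` for an explicit `namespace mx = mlx::core;` alias, so every MLX call is spelled `mx::...`. A minimal sketch of that style, using only calls that appear elsewhere in this diff (`mx::ones`, `mx::value_and_grad`, `mx::eval`):

```cpp
#include <iostream>

#include "mlx/mlx.h"

// Alias instead of a using-directive: MLX names stay visibly qualified and
// cannot collide with std:: or local symbols such as abs, log, or where.
namespace mx = mlx::core;

int main() {
  // f(x) = sum(x^2), so df/dx = 2x elementwise.
  auto fn = [](mx::array x) { return mx::sum(mx::square(x)); };

  auto x = mx::ones({3});
  auto value_and_grad_fn = mx::value_and_grad(fn);
  auto [value, dfdx] = value_and_grad_fn(x);

  mx::eval(value, dfdx);  // MLX is lazy; force evaluation before printing
  std::cout << value << std::endl;  // 3
  std::cout << dfdx << std::endl;   // [2, 2, 2]
}
```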
```diff
@@ -4,21 +4,21 @@
 #include "mlx/mlx.h"
 #include "time_utils.h"
 
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 void time_add_op() {
   std::vector<int> sizes(1, 1);
   for (int i = 0; i < 9; ++i) {
     sizes.push_back(10 * sizes.back());
   }
-  set_default_device(Device::cpu);
+  set_default_device(mx::Device::cpu);
   for (auto size : sizes) {
-    auto a = random::uniform({size});
-    auto b = random::uniform({size});
-    eval(a, b);
+    auto a = mx::random::uniform({size});
+    auto b = mx::random::uniform({size});
+    mx::eval(a, b);
     std::cout << "Size " << size << std::endl;
-    TIMEM("cpu", add, a, b, Device::cpu);
-    TIMEM("gpu", add, a, b, Device::gpu);
+    TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
+    TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
   }
 }
```
```diff
@@ -6,105 +6,105 @@
 #include "mlx/mlx.h"
 #include "time_utils.h"
 
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 void time_irregular_binary_ops_1D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   int size = 1000000;
   int step = 2;
-  auto a = random::uniform({size});
-  auto b = random::uniform({size});
-  eval(a, b);
+  auto a = mx::random::uniform({size});
+  auto b = mx::random::uniform({size});
+  mx::eval(a, b);
   a = slice(a, {0}, {size}, {step});
   b = slice(b, {0}, {size}, {step});
-  TIMEM("1D strided", add, a, b, device);
+  TIMEM("1D strided", mx::add, a, b, device);
 }
 
 void time_irregular_binary_ops_2D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   int size = 2048;
-  auto a = random::uniform({size, size});
-  auto b = random::uniform({size, size});
-  eval(a, b);
-  TIMEM("2D regular", add, a, b, device);
+  auto a = mx::random::uniform({size, size});
+  auto b = mx::random::uniform({size, size});
+  mx::eval(a, b);
+  TIMEM("2D regular", mx::add, a, b, device);
 
-  b = transpose(b);
-  eval(b);
-  TIMEM("2D transpose", add, a, b, device);
+  b = mx::transpose(b);
+  mx::eval(b);
+  TIMEM("2D mx::transpose", mx::add, a, b, device);
 
-  b = random::uniform({size});
-  eval(b);
-  TIMEM("2D broadcast dim 0", add, a, b, device);
+  b = mx::random::uniform({size});
+  mx::eval(b);
+  TIMEM("2D broadcast dim 0", mx::add, a, b, device);
 
-  b = reshape(b, {size, 1});
-  eval(b);
-  TIMEM("2D broadcast dim 1", add, a, b, device);
+  b = mx::reshape(b, {size, 1});
+  mx::eval(b);
+  TIMEM("2D broadcast dim 1", mx::add, a, b, device);
 }
 
 void time_irregular_binary_ops_3D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   int d0 = 32;
   int d1 = 512;
   int d2 = 512;
-  auto a = random::uniform({d0, d1, d2});
-  auto b = random::uniform({d0, d1, d2});
-  TIMEM("3D regular", add, a, b, device);
+  auto a = mx::random::uniform({d0, d1, d2});
+  auto b = mx::random::uniform({d0, d1, d2});
+  TIMEM("3D regular", mx::add, a, b, device);
 
-  b = transpose(b, {0, 2, 1});
-  TIMEM("3D transpose", add, a, b, device);
+  b = mx::transpose(b, {0, 2, 1});
+  TIMEM("3D mx::transpose", mx::add, a, b, device);
 
-  b = random::uniform({d1, d2});
-  TIMEM("3D broadcast dim 0", add, a, b, device);
+  b = mx::random::uniform({d1, d2});
+  TIMEM("3D broadcast dim 0", mx::add, a, b, device);
 
-  b = random::uniform({d0, 1, d2});
-  TIMEM("3D broadcast dim 1", add, a, b, device);
+  b = mx::random::uniform({d0, 1, d2});
+  TIMEM("3D broadcast dim 1", mx::add, a, b, device);
 
-  b = random::uniform({d0, d1, 1});
-  TIMEM("3D broadcast dim 2", add, a, b, device);
+  b = mx::random::uniform({d0, d1, 1});
+  TIMEM("3D broadcast dim 2", mx::add, a, b, device);
 
-  b = random::uniform({d2});
-  TIMEM("3D broadcast dims 0, 1", add, a, b, device);
+  b = mx::random::uniform({d2});
+  TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
 
-  b = random::uniform({d1, 1});
-  TIMEM("3D broadcast dims 0, 2", add, a, b, device);
+  b = mx::random::uniform({d1, 1});
+  TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
 
-  b = random::uniform({d0, 1, 1});
-  TIMEM("3D broadcast dims 1, 2", add, a, b, device);
+  b = mx::random::uniform({d0, 1, 1});
+  TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
 }
 
 void time_irregular_binary_ops_4D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   std::vector<int> shape = {8, 8, 512, 512};
-  auto a = random::uniform(shape);
-  auto b = random::uniform(shape);
+  auto a = mx::random::uniform(shape);
+  auto b = mx::random::uniform(shape);
 
-  TIMEM("4D regular", add, a, b, device);
+  TIMEM("4D regular", mx::add, a, b, device);
 
-  b = transpose(b, {0, 1, 3, 2});
-  TIMEM("4D transpose", add, a, b, device);
+  b = mx::transpose(b, {0, 1, 3, 2});
+  TIMEM("4D mx::transpose", mx::add, a, b, device);
 
   std::string om = "4D broadcast dims ";
   for (int i = 0; i < shape.size(); ++i) {
     shape[i] = 1;
-    b = random::uniform(shape);
+    b = mx::random::uniform(shape);
     std::ostringstream msg;
     msg << om << i;
-    TIMEM(msg.str(), add, a, b, device);
+    TIMEM(msg.str(), mx::add, a, b, device);
 
     for (int j = i + 1; j < shape.size(); ++j) {
       shape[j] = 1;
       std::ostringstream msg;
       msg << om << i << ", " << j;
-      b = random::uniform(shape);
-      TIMEM(msg.str(), add, a, b, device);
+      b = mx::random::uniform(shape);
+      TIMEM(msg.str(), mx::add, a, b, device);
       shape[j] = a.shape(j);
 
       for (int k = j + 1; k < shape.size(); ++k) {
         shape[k] = 1;
         std::ostringstream msg;
         msg << om << i << ", " << j << ", " << k;
-        b = random::uniform(shape);
-        TIMEM(msg.str(), add, a, b, device);
+        b = mx::random::uniform(shape);
+        TIMEM(msg.str(), mx::add, a, b, device);
         shape[k] = a.shape(k);
       }
     }
@@ -113,83 +113,83 @@ void time_irregular_binary_ops_4D() {
 }
 
 void time_irregular_reshape() {
-  auto device = default_device();
+  auto device = mx::default_device();
   std::vector<int> shape;
-  auto reshape_fn = [&shape, device](const array& a) {
-    return reshape(a, shape, device);
+  auto reshape_fn = [&shape, device](const mx::array& a) {
+    return mx::reshape(a, shape, device);
   };
 
   int size = 64;
   int d = 2 * size;
 
-  auto a = random::uniform({d, d, d});
+  auto a = mx::random::uniform({d, d, d});
 
   shape = {8 * size, size, size};
   TIMEM("3D contiguous", reshape_fn, a);
 
-  a = transpose(a);
+  a = mx::transpose(a);
   shape = {8 * size, size, size};
-  TIMEM("3D transpose", reshape_fn, a);
+  TIMEM("3D mx::transpose", reshape_fn, a);
 
-  a = transpose(a, {1, 2, 0});
+  a = mx::transpose(a, {1, 2, 0});
   shape = {8 * size, size, size};
-  TIMEM("3D transpose dims 1 2", reshape_fn, a);
+  TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d, d}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
   TIMEM("3D broadcast dim 0", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
   TIMEM("3D broadcast dim 1", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
   TIMEM("3D broadcast dim 2", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
   TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d, 1}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
   TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
   TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
   TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
 }
 
 void time_irregular_astype_1D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   int size = 1000000;
   int step = 2;
-  auto a = random::uniform({size});
+  auto a = mx::random::uniform({size});
   a = slice(a, {0}, {size}, {step});
-  TIMEM("1D strided", astype, a, int32, device);
+  TIMEM("1D strided", mx::astype, a, mx::int32, device);
 }
 
 void time_irregular_astype_2D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   int size = 2048;
   std::vector<int> shape = {size, size};
 
-  auto a = random::uniform(shape);
-  TIMEM("2D regular", astype, a, int32, device);
+  auto a = mx::random::uniform(shape);
+  TIMEM("2D regular", mx::astype, a, mx::int32, device);
 
-  a = transpose(a);
-  TIMEM("2D transpose", astype, a, int32, device);
+  a = mx::transpose(a);
+  TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
 
-  a = broadcast_to(random::uniform({size}), shape);
-  TIMEM("2D broadcast dim 0", astype, a, int32, device);
+  a = mx::broadcast_to(mx::random::uniform({size}), shape);
+  TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
 
-  a = broadcast_to(random::uniform({size, 1}), shape);
-  TIMEM("2D broadcast dim 1", astype, a, int32, device);
+  a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
+  TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
 }
 
 int main(int argc, char** argv) {
   if (argc > 1) {
     bool use_gpu = !strcmp(argv[1], "gpu");
-    set_default_device(use_gpu ? Device::gpu : Device::cpu);
+    set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
   }
-  std::cout << "Benchmarks for " << default_device() << std::endl;
+  std::cout << "Benchmarks for " << mx::default_device() << std::endl;
   time_irregular_binary_ops_1D();
   time_irregular_binary_ops_2D();
   time_irregular_binary_ops_3D();
```
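These "irregular" cases all time the same operations on inputs whose memory layout is no longer plain row-major: transposed views, strided slices, and broadcast inputs. A small sketch of why that matters, assuming the public `array::flags()` accessors behave as they are used in the Axpby extension later in this diff:

```cpp
#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  auto a = mx::random::uniform({4, 8});
  mx::eval(a);
  std::cout << std::boolalpha << a.flags().row_contiguous << std::endl;  // true

  // Transposing only swaps strides; the data stays in place, so the result is
  // no longer row contiguous.
  auto at = mx::transpose(a);
  mx::eval(at);
  std::cout << at.flags().row_contiguous << std::endl;  // false

  // Reshaping a non-contiguous array has to gather elements into a fresh
  // buffer, which is what the irregular reshape benchmarks above measure.
  auto r = mx::reshape(at, {2, 16});
  mx::eval(r);
  std::cout << r.flags().row_contiguous << std::endl;  // true
}
```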
```diff
@@ -3,20 +3,20 @@
 #include "mlx/mlx.h"
 #include "time_utils.h"
 
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 void time_creation_ops() {
   int M = 2000;
   int N = 500;
   auto shape = {M, N};
-  auto full_fp32 = [&]() { return full(shape, 3.3f); };
+  auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
   TIME(full_fp32);
-  auto zeros_fp32 = [&]() { return zeros(shape, float32); };
+  auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
   TIME(zeros_fp32);
-  auto ones_fp32 = [&]() { return ones(shape, float32); };
+  auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
   TIME(ones_fp32);
 
-  auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
+  auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
   TIME(arange_fp32);
 }
@@ -24,194 +24,196 @@ void time_type_conversions() {
   int M = 2000;
   int N = 500;
   auto shape = {M, N};
-  auto device = default_device();
+  auto device = mx::default_device();
 
-  auto a = zeros(shape, float32);
-  eval(a);
-  TIMEM("float32 to int32", astype, a, int32, device);
-  TIMEM("float32 to uint32", astype, a, uint32, device);
+  auto a = mx::zeros(shape, mx::float32);
+  mx::eval(a);
+  TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
+  TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
 
-  a = zeros(shape, int32);
-  eval(a);
-  TIMEM("int32 to float32", astype, a, float32, device);
+  a = mx::zeros(shape, mx::int32);
+  mx::eval(a);
+  TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
 
-  a = zeros(shape, bool_);
-  eval(a);
-  TIMEM("bool to float32", astype, a, float32, device);
-  TIMEM("bool to int32", astype, a, int32, device);
-  TIMEM("bool to uint32", astype, a, uint32, device);
+  a = mx::zeros(shape, mx::bool_);
+  mx::eval(a);
+  TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
+  TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
+  TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
 }
 
 void time_random_generation() {
   int M = 2000;
   int N = 500;
 
-  auto uniform = [&]() { return random::uniform({M, N}, float32); };
+  auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
   TIME(uniform);
-  auto normal = [&]() { return random::normal({M, N}, float32); };
+  auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
   TIME(normal);
 }
 
 void time_unary_ops() {
   int M = 2000;
   int N = 500;
-  auto device = default_device();
+  auto device = mx::default_device();
 
-  auto a = random::normal({M, N});
-  eval(a);
+  auto a = mx::random::normal({M, N});
+  mx::eval(a);
   TIME(mlx::core::abs, a, device);
-  TIME(negative, a, device);
-  TIME(sign, a, device);
-  TIME(square, a, device);
+  TIME(mx::negative, a, device);
+  TIME(mx::sign, a, device);
+  TIME(mx::square, a, device);
   TIME(mlx::core::sqrt, a, device);
-  TIME(rsqrt, a, device);
+  TIME(mx::rsqrt, a, device);
   TIME(mlx::core::exp, a, device);
 
-  a = random::uniform({M, N});
+  a = mx::random::uniform({M, N});
   TIME(mlx::core::log, a, device);
 }
 
 void time_binary_ops() {
   int M = 1000, N = 100, K = 10;
-  auto condition = random::randint(0, 2, {M, N, K});
-  auto a = random::uniform({M, N, K});
-  auto b = random::uniform({M, N, K});
-  auto device = default_device();
-  eval(a, b);
+  auto condition = mx::random::randint(0, 2, {M, N, K});
+  auto a = mx::random::uniform({M, N, K});
+  auto b = mx::random::uniform({M, N, K});
+  auto device = mx::default_device();
+  mx::eval(a, b);
 
-  TIME(add, a, b, device);
-  TIME(subtract, a, b, device);
-  TIME(multiply, a, b, device);
-  TIME(divide, a, b, device);
-  TIME(maximum, a, b, device);
-  TIME(minimum, a, b, device);
-  TIME(where, condition, a, b, device);
+  TIME(mx::add, a, b, device);
+  TIME(mx::subtract, a, b, device);
+  TIME(mx::multiply, a, b, device);
+  TIME(mx::divide, a, b, device);
+  TIME(mx::maximum, a, b, device);
+  TIME(mx::minimum, a, b, device);
+  TIME(mx::where, condition, a, b, device);
 
-  condition = array({true});
-  b = random::uniform({1});
-  eval(b);
-  TIMEM("scalar", add, a, b, device);
-  TIMEM("vector-scalar", subtract, a, b, device);
-  TIMEM("scalar-vector", subtract, b, a, device);
-  TIMEM("scalar", multiply, a, b, device);
-  TIMEM("vector-scalar", divide, a, b, device);
-  TIMEM("scalar-vector", divide, b, a, device);
-  TIMEM("scalar-vector", where, condition, a, b, device);
+  condition = mx::array({true});
+  b = mx::random::uniform({1});
+  mx::eval(b);
+  TIMEM("scalar", mx::add, a, b, device);
+  TIMEM("vector-scalar", mx::subtract, a, b, device);
+  TIMEM("scalar-vector", mx::subtract, b, a, device);
+  TIMEM("scalar", mx::multiply, a, b, device);
+  TIMEM("vector-scalar", mx::divide, a, b, device);
+  TIMEM("scalar-vector", mx::divide, b, a, device);
+  TIMEM("scalar-vector", mx::where, condition, a, b, device);
 
-  condition = broadcast_to(array({true}), {1000, 100});
-  a = broadcast_to(random::uniform({1}), {1000, 100});
-  b = broadcast_to(random::uniform({1}), {1000, 100});
-  eval(a, b);
-  TIMEM("scalar-scalar broadcast", add, a, b, device);
-  TIMEM("scalar-scalar broadcast", subtract, a, b, device);
-  TIMEM("scalar-scalar broadcast", multiply, a, b, device);
-  TIMEM("scalar-scalar broadcast", divide, a, b, device);
-  TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
+  condition = mx::broadcast_to(mx::array({true}), {1000, 100});
+  a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
+  b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
+  mx::eval(a, b);
+  TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
+  TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
+  TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
+  TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
+  TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
 }
 
 void time_strided_ops() {
   int M = 50, N = 50, O = 50, P = 50;
-  auto a = random::uniform({M, N, O, P});
-  auto b = random::uniform({M, N, O, P});
-  auto device = default_device();
-  eval(a, b);
-  TIMEM("non-strided", add, a, b, device);
-  a = transpose(a, {1, 0, 2, 3});
-  b = transpose(b, {3, 2, 0, 1});
-  eval(a, b);
-  TIMEM("strided", add, a, b, device);
+  auto a = mx::random::uniform({M, N, O, P});
+  auto b = mx::random::uniform({M, N, O, P});
+  auto device = mx::default_device();
+  mx::eval(a, b);
+  TIMEM("non-strided", mx::add, a, b, device);
+  a = mx::transpose(a, {1, 0, 2, 3});
+  b = mx::transpose(b, {3, 2, 0, 1});
+  mx::eval(a, b);
+  TIMEM("strided", mx::add, a, b, device);
 }
 
 void time_comparisons() {
   int M = 1000, N = 100, K = 10;
-  auto a = random::uniform({M, N, K});
-  auto b = random::uniform({M, N, K});
-  auto device = default_device();
-  eval(a, b);
-  TIME(equal, a, b, device);
-  TIME(greater, a, b, device);
-  TIME(greater_equal, a, b, device);
-  TIME(less, a, b, device);
-  TIME(less_equal, a, b, device);
+  auto a = mx::random::uniform({M, N, K});
+  auto b = mx::random::uniform({M, N, K});
+  auto device = mx::default_device();
+  mx::eval(a, b);
+  TIME(mx::equal, a, b, device);
+  TIME(mx::greater, a, b, device);
+  TIME(mx::greater_equal, a, b, device);
+  TIME(mx::less, a, b, device);
+  TIME(mx::less_equal, a, b, device);
 }
 
 void time_matvec() {
   int M = 2000, N = 200;
-  auto a = random::uniform({M, N});
-  auto b = random::uniform({N});
-  auto c = random::uniform({M});
-  eval(a, b, c);
-  auto matvec = [&]() { return matmul(a, b); };
+  auto a = mx::random::uniform({M, N});
+  auto b = mx::random::uniform({N});
+  auto c = mx::random::uniform({M});
+  mx::eval(a, b, c);
+  auto matvec = [&]() { return mx::matmul(a, b); };
   TIME(matvec);
 
-  auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
+  auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
   TIME(matvec_transpose);
 }
 
 void time_matmul() {
   int M = 1000, N = 1000, K = 1000;
-  auto a = random::uniform({M, K});
-  auto b = random::uniform({K, N});
-  auto device = default_device();
-  eval(a, b);
-  TIME(matmul, a, b, device);
+  auto a = mx::random::uniform({M, K});
+  auto b = mx::random::uniform({K, N});
+  auto device = mx::default_device();
+  mx::eval(a, b);
+  TIME(mx::matmul, a, b, device);
 
-  auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
+  auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
   TIME(transpose_matmul);
 }
 
 void time_reductions() {
-  auto a = random::normal({10000, 1000});
-  eval(a);
-  auto sum_all = [&a]() { return sum(a, false); };
+  auto a = mx::random::normal({10000, 1000});
+  mx::eval(a);
+  auto sum_all = [&a]() { return mx::sum(a, false); };
   TIME(sum_all);
 
-  auto sum_along_0 = [&a]() { return sum(a, 0, false); };
+  auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
   TIME(sum_along_0);
 
-  auto sum_along_1 = [&a]() { return sum(a, 1, false); };
+  auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
   TIME(sum_along_1);
 
-  auto prod_all = [&a]() { return prod(a, false); };
+  auto prod_all = [&a]() { return mx::prod(a, false); };
   TIME(prod_all);
 
-  auto all_true = [&a]() { return all(a, false); };
+  auto all_true = [&a]() { return mx::all(a, false); };
   TIME(all_true);
 
-  auto all_along_0 = [&a]() { return all(a, 0, false); };
+  auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
   TIME(all_along_0);
 
-  auto all_along_1 = [&a]() { return all(a, 1, false); };
+  auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
   TIME(all_along_1);
 
-  auto any_true = [&a]() { return any(a, false); };
+  auto any_true = [&a]() { return mx::any(a, false); };
   TIME(any_true);
 
-  auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
+  auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
   TIME(argmin_along_0);
 
-  auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
+  auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
   TIME(argmin_along_1);
 }
 
 void time_gather_scatter() {
-  auto a = random::normal({1000, 768});
-  eval(a);
-  auto indices = random::randint(0, 1000, {256});
-  eval(indices);
+  auto a = mx::random::normal({1000, 768});
+  mx::eval(a);
+  auto indices = mx::random::randint(0, 1000, {256});
+  mx::eval(indices);
 
-  auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
+  auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
   TIME(embedding_lookup);
 
-  indices = random::randint(0, 768 * 1000, {256 * 768});
-  eval(indices);
+  indices = mx::random::randint(0, 768 * 1000, {256 * 768});
+  mx::eval(indices);
 
-  auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
+  auto single_element_lookup = [&a, &indices]() {
+    return mx::take(a, indices);
+  };
   TIME(single_element_lookup);
 
-  indices = random::randint(0, 1000, {256});
-  auto updates = random::normal({256, 1, 768});
-  eval(indices, updates);
+  indices = mx::random::randint(0, 1000, {256});
+  auto updates = mx::random::normal({256, 1, 768});
+  mx::eval(indices, updates);
 
   auto embedding_update = [&a, &indices, &updates]() {
     return scatter(a, indices, updates, 0);
@@ -223,10 +225,10 @@ void time_gather_scatter() {
   };
   TIME(embedding_add);
 
-  a = reshape(a, {-1});
-  indices = random::randint(0, 768 * 1000, {768 * 256});
-  updates = random::normal({256 * 768, 1});
-  eval(a, indices, updates);
+  a = mx::reshape(a, {-1});
+  indices = mx::random::randint(0, 768 * 1000, {768 * 256});
+  updates = mx::random::normal({256 * 768, 1});
+  mx::eval(a, indices, updates);
 
   auto single_element_update = [&a, &indices, &updates]() {
     return scatter(a, indices, updates, 0);
@@ -240,21 +242,21 @@ void time_gather_scatter() {
 }
 
 void time_divmod() {
-  auto a = random::normal({1000});
-  auto b = random::normal({1000});
-  eval({a, b});
+  auto a = mx::random::normal({1000});
+  auto b = mx::random::normal({1000});
+  mx::eval({a, b});
 
-  auto divmod_fused = [&a, &b]() { return divmod(a, b); };
+  auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
   TIME(divmod_fused);
 
   auto divmod_separate = [&a, &b]() {
-    return std::vector<array>{floor_divide(a, b), remainder(a, b)};
+    return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
   };
   TIME(divmod_separate);
 }
 
 int main() {
-  std::cout << "Benchmarks for " << default_device() << std::endl;
+  std::cout << "Benchmarks for " << mx::default_device() << std::endl;
   time_creation_ops();
   time_type_conversions();
   time_unary_ops();
```
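`TIME` and `TIMEM` come from the benchmark suite's `time_utils.h`, which is not part of this diff. The sketch below is a hypothetical stand-in (the `time_it` helper and its output format are assumptions, not the real macros), shown only to illustrate the shape of such a measurement: run the op repeatedly, force evaluation, and report the average wall-clock time.

```cpp
#include <chrono>
#include <iostream>
#include <string>

#include "mlx/mlx.h"

namespace mx = mlx::core;

// Hypothetical stand-in for the TIME macro: build the lazy graph, eval it,
// and average the elapsed time over a few iterations.
template <typename F>
void time_it(const std::string& msg, F&& fn, int iters = 10) {
  auto start = std::chrono::high_resolution_clock::now();
  for (int i = 0; i < iters; ++i) {
    mx::eval(fn());
  }
  auto end = std::chrono::high_resolution_clock::now();
  auto us =
      std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
  std::cout << msg << ": " << us / static_cast<double>(iters) << " us"
            << std::endl;
}

int main() {
  auto a = mx::random::uniform({1000, 1000});
  auto b = mx::random::uniform({1000, 1000});
  mx::eval(a, b);  // exclude input generation from the measurement
  time_it("add", [&]() { return mx::add(a, b); });
  time_it("matmul", [&]() { return mx::matmul(a, b); });
}
```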
benchmarks/python/packed_qmm_bench.py (new file, 74 lines):

```python
import argparse
import math

import mlx.core as mx
from time_utils import time_fn

B = 1024
D = 1024
M = 4 * D
group_size = 64
bits = 4
dtype = mx.float16
loops = 10


def qmm_(x, wq1, wq2, q_type):
    for i in range(loops):
        x = mx.quantized_matmul(
            x,
            *wq1,
            group_size=group_size,
            bits=bits,
            quantization_type=q_type,
        )
        x = mx.quantized_matmul(
            x,
            *wq2,
            group_size=group_size,
            bits=bits,
            quantization_type=q_type,
        )
    return x


def affine_qmm(x, wq1, wq2):
    return qmm_(x, wq1, wq2, "affine")


def affine_packed_qmm(x, wq1, wq2):
    return qmm_(x, wq1, wq2, "affine-packed")


def time_qmm():
    mx.random.seed(3)
    x = mx.random.normal(shape=(B, D)).astype(dtype)
    w1 = mx.random.normal(shape=(M, D)).astype(dtype)
    wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
    w2 = mx.random.normal(shape=(D, M)).astype(dtype)
    wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
    mx.eval(x, wq1, wq2)
    time_fn(affine_qmm, x, wq1, wq2)


def time_packed_qmm():
    mx.random.seed(3)
    x = mx.random.normal(shape=(B, D)).astype(dtype)
    w1 = mx.random.normal(shape=(M, D)).astype(dtype)
    wq1 = mx.quantize(
        w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
    )
    w2 = mx.random.normal(shape=(D, M)).astype(dtype)
    wq2 = mx.quantize(
        w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
    )
    mx.eval(x, wq1, wq2)
    time_fn(affine_packed_qmm, x, wq1, wq2)


if __name__ == "__main__":
    for b in [2, 4, 8]:
        bits = b
        print(f"Bits {bits}:")
        time_qmm()
        time_packed_qmm()
```
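The benchmark above exercises a `quantization_type` argument ("affine" vs "affine-packed") that exists on this branch. For reference, here is a rough C++ sketch of the same two-layer round trip against the standard affine quantization API, assuming `mx::quantize` and `mx::quantized_matmul` mirror the Python signatures (packed weights plus per-group scales and biases); the branch-specific packed layout is not shown.

```cpp
#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  int B = 1024, D = 1024, M = 4 * D;
  int group_size = 64, bits = 4;

  auto x = mx::astype(mx::random::normal({B, D}), mx::float16);
  auto w1 = mx::astype(mx::random::normal({M, D}), mx::float16);
  auto w2 = mx::astype(mx::random::normal({D, M}), mx::float16);

  // quantize returns (packed weights, scales, biases) for affine quantization.
  auto [wq1, s1, b1] = mx::quantize(w1, group_size, bits);
  auto [wq2, s2, b2] = mx::quantize(w2, group_size, bits);

  // (B, D) x (M, D)^T -> (B, M), then (B, M) x (D, M)^T -> (B, D).
  auto h = mx::quantized_matmul(x, wq1, s1, b1, /* transpose = */ true,
                                group_size, bits);
  auto y = mx::quantized_matmul(h, wq2, s2, b2, /* transpose = */ true,
                                group_size, bits);
  mx::eval(y);
}
```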
```diff
@@ -1,94 +1,58 @@
 import argparse
 import math
 
 import mlx.core as mx
-import numpy as np
-from mlx.utils import tree_map
 from time_utils import time_fn
 
-L = 32768
+L = 16384
 H = 32
 H_k = H // 4
 D = 128
 dtype = mx.float16
-bits = 8
 
-loops = 20
+loops = 10
 
 
 def attention(q, k, v):
-    for _ in range(loops):
+    def _sdpa(q, k, v):
         B, Hq, L, D = q.shape
         _, Hk, S, _ = k.shape
         q = q.reshape(B, Hk, Hq // Hk, L, D)
-        ke = k[:, :, None, :, :]
-        ve = v[:, :, None, :, :]
-        s = q @ ke.transpose(0, 1, 2, 4, 3)
+        k = k[:, :, None, :, :]
+        v = v[:, :, None, :, :]
+        s = q @ k.transpose(0, 1, 2, 4, 3)
         p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
-        q = p @ ve
-        q = q.reshape(B, Hq, L, D)
+        o = p @ v
+        return o.reshape(B, Hq, L, D)
+
+    for i in range(loops):
+        q = _sdpa(q, k, v)
     return q
 
 
 def sdpa(q, k, v):
-    for _ in range(loops):
-        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
+    for i in range(loops):
+        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
     return q
 
 
-def quant_sdpa(q, k, v, bits=4):
-    for _ in range(loops):
-        q = mx.fast.quantized_scaled_dot_product_attention(
-            q, *k, *v, scale=1.0, mask=None, bits=bits
-        )
-    return q
-
-
-def quant_attention(q, k, v, bits=4):
-    for _ in range(loops):
-        B, Hq, L, D = q.shape
-        Hk = k[0].shape[1]
-
-        q = q.reshape((B, Hk, Hq // Hk, L, D))
-        ke = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
-        ve = tree_map(lambda x: mx.expand_dims(x, axis=2), v)
-
-        scores = mx.quantized_matmul(q, *ke, transpose=True, bits=bits)
-        scores = mx.softmax(scores, axis=-1)
-
-        q = mx.quantized_matmul(scores, *ve, transpose=False, bits=bits)
-        q = q.reshape((B, Hq, L, D))
-    return q
-
-
-def time_self_attention_primitives(q, k, v):
+def time_self_attention_primitives():
+    mx.random.seed(3)
+    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
+    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
+    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
+    mx.eval(q, k, v)
     time_fn(attention, q, k, v)
 
 
-def time_self_attention_sdpa(q, k, v):
+def time_self_attention_sdpa():
+    mx.random.seed(3)
+    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
+    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
+    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
+    mx.eval(q, k, v)
    time_fn(sdpa, q, k, v)
 
 
-def time_self_attention_quant_sdpa(q, k, v, bits=4):
-    time_fn(quant_sdpa, q, k, v, bits)
-
-
-def time_self_attention_quant_primitives(q, k, v, bits=4):
-    time_fn(quant_attention, q, k, v, bits)
-
-
 if __name__ == "__main__":
-    mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 1, D), dtype=dtype)
-    k = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
-    v = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
-    mx.eval(q, k, v)
-
-    k_quant = mx.quantize(k, bits=bits)
-    v_quant = mx.quantize(v, bits=bits)
-    mx.eval(k_quant, v_quant)
-
-    k = mx.dequantize(*k_quant, bits=bits)
-    v = mx.dequantize(*v_quant, bits=bits)
-
-    time_self_attention_sdpa(q, k, v)
-    time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
-    time_self_attention_primitives(q, k, v)
-    time_self_attention_quant_primitives(q, k_quant, v_quant, bits)
+    time_self_attention_sdpa()
+    time_self_attention_primitives()
```
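Both sides of this file benchmark single-token (decode-style) attention with grouped KV heads; the q-sdpa side additionally calls `mx.fast.quantized_scaled_dot_product_attention`, which only exists on that branch. A rough C++ sketch of the non-quantized fast path, assuming the C++ `fast::scaled_dot_product_attention` mirrors the Python call used above:

```cpp
#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  int H = 32, H_k = H / 4, L = 4096, D = 128;

  // Single decode step: one query token attending over an L-token KV cache,
  // with fewer KV heads than query heads (grouped-query attention).
  auto q = mx::astype(mx::random::uniform({1, H, 1, D}), mx::float16);
  auto k = mx::astype(mx::random::uniform({1, H_k, L, D}), mx::float16);
  auto v = mx::astype(mx::random::uniform({1, H_k, L, D}), mx::float16);

  auto o = mx::fast::scaled_dot_product_attention(q, k, v, /* scale = */ 1.0f);
  mx::eval(o);
}
```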
```diff
@@ -420,8 +420,8 @@ element in the output.
       constant const float& alpha [[buffer(3)]],
       constant const float& beta [[buffer(4)]],
       constant const int* shape [[buffer(5)]],
-      constant const size_t* x_strides [[buffer(6)]],
-      constant const size_t* y_strides [[buffer(7)]],
+      constant const int64_t* x_strides [[buffer(6)]],
+      constant const int64_t* y_strides [[buffer(7)]],
       constant const int& ndim [[buffer(8)]],
       uint index [[thread_position_in_grid]]) {
     // Convert linear indices to offsets in array
@@ -438,24 +438,10 @@ each instantiation a unique host name so we can identify it.
 
 .. code-block:: C++
 
-  #define instantiate_axpby(type_name, type) \
-    template [[host_name("axpby_general_" #type_name)]] \
-    [[kernel]] void axpby_general<type>( \
-        device const type* x [[buffer(0)]], \
-        device const type* y [[buffer(1)]], \
-        device type* out [[buffer(2)]], \
-        constant const float& alpha [[buffer(3)]], \
-        constant const float& beta [[buffer(4)]], \
-        constant const int* shape [[buffer(5)]], \
-        constant const size_t* x_strides [[buffer(6)]], \
-        constant const size_t* y_strides [[buffer(7)]], \
-        constant const int& ndim [[buffer(8)]], \
-        uint index [[thread_position_in_grid]]);
-
-  instantiate_axpby(float32, float);
-  instantiate_axpby(float16, half);
-  instantiate_axpby(bfloat16, bfloat16_t);
-  instantiate_axpby(complex64, complex64_t);
+  instantiate_kernel("axpby_general_float32", axpby_general, float)
+  instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
+  instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
+  instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)
 
 The logic to determine the kernel, set the inputs, resolve the grid dimensions,
 and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
```
```diff
@@ -168,6 +168,7 @@ Operations
    tri
    tril
    triu
+   unflatten
    var
    view
    where
```
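`unflatten` is the operation newly listed above; it splits one axis into several. A short sketch, assuming the C++ signature mirrors the Python op (`unflatten(a, axis, shape)`), with `flatten` as its inverse:

```cpp
#include <cassert>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // Split one axis of size 12 into two axes of shape (3, 4).
  auto x = mx::zeros({2, 12});
  auto y = mx::unflatten(x, 1, {3, 4});  // shape becomes (2, 3, 4)
  assert(y.shape(1) == 3 && y.shape(2) == 4);

  // flatten is the inverse: collapse axes 1..2 back into a single axis of 12.
  auto z = mx::flatten(y, 1, 2);
  assert(z.shape(1) == 12);
}
```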
```diff
@@ -4,19 +4,19 @@
 
 #include "mlx/mlx.h"
 
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 int main() {
-  if (!distributed::is_available()) {
+  if (!mx::distributed::is_available()) {
     std::cout << "No communication backend found" << std::endl;
     return 1;
   }
 
-  auto global_group = distributed::init();
+  auto global_group = mx::distributed::init();
   std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
 
-  array x = ones({10});
-  array out = distributed::all_sum(x, global_group);
+  mx::array x = mx::ones({10});
+  mx::array out = mx::distributed::all_sum(x, global_group);
 
   std::cout << out << std::endl;
 }
```
```diff
@@ -10,7 +10,7 @@
 /**
  * An example of linear regression with MLX.
  */
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 int main() {
   int num_features = 100;
@@ -19,35 +19,35 @@ int main() {
   float learning_rate = 0.01;
 
   // True parameters
-  auto w_star = random::normal({num_features});
+  auto w_star = mx::random::normal({num_features});
 
   // The input examples (design matrix)
-  auto X = random::normal({num_examples, num_features});
+  auto X = mx::random::normal({num_examples, num_features});
 
   // Noisy labels
-  auto eps = 1e-2 * random::normal({num_examples});
-  auto y = matmul(X, w_star) + eps;
+  auto eps = 1e-2 * mx::random::normal({num_examples});
+  auto y = mx::matmul(X, w_star) + eps;
 
   // Initialize random parameters
-  array w = 1e-2 * random::normal({num_features});
+  mx::array w = 1e-2 * mx::random::normal({num_features});
 
-  auto loss_fn = [&](array w) {
-    auto yhat = matmul(X, w);
-    return (0.5f / num_examples) * sum(square(yhat - y));
+  auto loss_fn = [&](mx::array w) {
+    auto yhat = mx::matmul(X, w);
+    return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));
   };
 
-  auto grad_fn = grad(loss_fn);
+  auto grad_fn = mx::grad(loss_fn);
 
   auto tic = timer::time();
   for (int it = 0; it < num_iters; ++it) {
-    auto grad = grad_fn(w);
-    w = w - learning_rate * grad;
-    eval(w);
+    auto grads = grad_fn(w);
+    w = w - learning_rate * grads;
+    mx::eval(w);
   }
   auto toc = timer::time();
 
   auto loss = loss_fn(w);
-  auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
+  auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());
   auto throughput = num_iters / timer::seconds(toc - tic);
   std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
             << ", Throughput " << throughput << " (it/s)." << std::endl;
```
```diff
@@ -10,7 +10,7 @@
 /**
  * An example of logistic regression with MLX.
  */
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 int main() {
   int num_features = 100;
@@ -19,35 +19,35 @@ int main() {
   float learning_rate = 0.1;
 
   // True parameters
-  auto w_star = random::normal({num_features});
+  auto w_star = mx::random::normal({num_features});
 
   // The input examples
-  auto X = random::normal({num_examples, num_features});
+  auto X = mx::random::normal({num_examples, num_features});
 
   // Labels
-  auto y = matmul(X, w_star) > 0;
+  auto y = mx::matmul(X, w_star) > 0;
 
   // Initialize random parameters
-  array w = 1e-2 * random::normal({num_features});
+  mx::array w = 1e-2 * mx::random::normal({num_features});
 
-  auto loss_fn = [&](array w) {
-    auto logits = matmul(X, w);
+  auto loss_fn = [&](mx::array w) {
+    auto logits = mx::matmul(X, w);
     auto scale = (1.0f / num_examples);
-    return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
+    return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
   };
 
-  auto grad_fn = grad(loss_fn);
+  auto grad_fn = mx::grad(loss_fn);
 
   auto tic = timer::time();
   for (int it = 0; it < num_iters; ++it) {
-    auto grad = grad_fn(w);
-    w = w - learning_rate * grad;
-    eval(w);
+    auto grads = grad_fn(w);
+    w = w - learning_rate * grads;
+    mx::eval(w);
  }
   auto toc = timer::time();
 
   auto loss = loss_fn(w);
-  auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
+  auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
   auto throughput = num_iters / timer::seconds(toc - tic);
   std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
             << throughput << " (it/s)." << std::endl;
```
```diff
@@ -5,27 +5,27 @@
 
 #include "mlx/mlx.h"
 
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 int main() {
   // To use Metal debugging and profiling:
   // 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
   // 2. Run with MTL_CAPTURE_ENABLED=1.
-  metal::start_capture("mlx_trace.gputrace");
+  mx::metal::start_capture("mlx_trace.gputrace");
 
   // Start at index two because the default GPU and CPU streams have indices
   // zero and one, respectively. This naming matches the label assigned to each
   // stream's command queue.
-  auto s2 = new_stream(Device::gpu);
-  auto s3 = new_stream(Device::gpu);
+  auto s2 = new_stream(mx::Device::gpu);
+  auto s3 = new_stream(mx::Device::gpu);
 
-  auto a = arange(1.f, 10.f, 1.f, float32, s2);
-  auto b = arange(1.f, 10.f, 1.f, float32, s3);
-  auto x = add(a, a, s2);
-  auto y = add(b, b, s3);
+  auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
+  auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
+  auto x = mx::add(a, a, s2);
+  auto y = mx::add(b, b, s3);
 
   // The multiply will happen on the default stream.
-  std::cout << multiply(x, y) << std::endl;
+  std::cout << mx::multiply(x, y) << std::endl;
 
-  metal::stop_capture();
+  mx::metal::stop_capture();
 }
```
```diff
@@ -5,11 +5,11 @@
 
 #include "mlx/mlx.h"
 
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 void array_basics() {
   // Make a scalar array:
-  array x(1.0);
+  mx::array x(1.0);
 
   // Get the value out of it:
   auto s = x.item<float>();
@@ -29,31 +29,31 @@ void array_basics() {
 
   // The datatype should be float32:
   auto dtype = x.dtype();
-  assert(dtype == float32);
+  assert(dtype == mx::float32);
 
   // Specify the dtype when constructing the array:
-  x = array(1, int32);
-  assert(x.dtype() == int32);
+  x = mx::array(1, mx::int32);
+  assert(x.dtype() == mx::int32);
   x.item<int>(); // OK
   // x.item<float>(); // Undefined!
 
   // Make a multidimensional array:
-  x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+  x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
   // mlx is row-major by default so the first row of this array
   // is [1.0, 2.0] and the second row is [3.0, 4.0]
 
   // Make an array of shape {2, 2} filled with ones:
-  auto y = ones({2, 2});
+  auto y = mx::ones({2, 2});
 
   // Pointwise add x and y:
-  auto z = add(x, y);
+  auto z = mx::add(x, y);
 
   // Same thing:
   z = x + y;
 
   // mlx is lazy by default. At this point `z` only
   // has a shape and a type but no actual data:
-  assert(z.dtype() == float32);
+  assert(z.dtype() == mx::float32);
   assert(z.shape(0) == 2);
   assert(z.shape(1) == 2);
 
@@ -63,33 +63,33 @@ void array_basics() {
   // and inputs. When `eval` is called on an array (or arrays), the array and
   // all of its dependencies are recursively evaluated to produce the result.
   // Once an array is evaluated, it has data and is detached from its inputs.
-  eval(z);
+  mx::eval(z);
 
-  // Of course the array can still be an input to other operations. You can even
-  // call eval on the array again, this will just be a no-op:
-  eval(z); // no-op
+  // Of course the array can still be an input to other operations. You can
+  // even call eval on the array again, this will just be a no-op:
+  mx::eval(z); // no-op
 
   // Some functions or methods on arrays implicitly evaluate them. For example
   // accessing a value in an array or printing the array implicitly evaluate it:
-  z = ones({1});
+  z = mx::ones({1});
   z.item<float>(); // implicit evaluation
 
-  z = ones({2, 2});
+  z = mx::ones({2, 2});
   std::cout << z << std::endl; // implicit evaluation
 }
 
 void automatic_differentiation() {
-  auto fn = [](array x) { return square(x); };
+  auto fn = [](mx::array x) { return mx::square(x); };
 
   // Computing the derivative function of a function
-  auto grad_fn = grad(fn);
+  auto grad_fn = mx::grad(fn);
   // Call grad_fn on the input to get the derivative
-  auto x = array(1.5);
+  auto x = mx::array(1.5);
   auto dfdx = grad_fn(x);
   // dfdx is 2 * x
 
   // Get the second derivative by composing grad with grad
-  auto d2fdx2 = grad(grad(fn))(x);
+  auto d2fdx2 = mx::grad(mx::grad(fn))(x);
   // d2fdx2 is 2
 }
 
```
```diff
@@ -19,7 +19,7 @@
 #include "mlx/backend/metal/utils.h"
 #endif
 
-namespace mlx::core {
+namespace my_ext {
 
 ///////////////////////////////////////////////////////////////////////////////
 // Operation Implementation
@@ -32,24 +32,24 @@ namespace mlx::core {
  * Follow numpy style broadcasting between x and y
  * Inputs are upcasted to floats if needed
 **/
-array axpby(
-    const array& x, // Input array x
-    const array& y, // Input array y
+mx::array axpby(
+    const mx::array& x, // Input mx::array x
+    const mx::array& y, // Input mx::array y
     const float alpha, // Scaling factor for x
     const float beta, // Scaling factor for y
-    StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
+    mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
 ) {
   // Promote dtypes between x and y as needed
   auto promoted_dtype = promote_types(x.dtype(), y.dtype());
 
   // Upcast to float32 for non-floating point inputs x and y
-  auto out_dtype = issubdtype(promoted_dtype, float32)
+  auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32)
       ? promoted_dtype
-      : promote_types(promoted_dtype, float32);
+      : promote_types(promoted_dtype, mx::float32);
 
   // Cast x and y up to the determined dtype (on the same stream s)
-  auto x_casted = astype(x, out_dtype, s);
-  auto y_casted = astype(y, out_dtype, s);
+  auto x_casted = mx::astype(x, out_dtype, s);
+  auto y_casted = mx::astype(y, out_dtype, s);
 
   // Broadcast the shapes of x and y (on the same stream s)
   auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
@@ -57,12 +57,12 @@ array axpby(
 
   // Construct the array as the output of the Axpby primitive
   // with the broadcasted and upcasted arrays as inputs
-  return array(
+  return mx::array(
       /* const std::vector<int>& shape = */ out_shape,
-      /* Dtype dtype = */ out_dtype,
-      /* std::unique_ptr<Primitive> primitive = */
+      /* mx::Dtype dtype = */ out_dtype,
+      /* std::unique_ptr<mx::Primitive> primitive = */
      std::make_shared<Axpby>(to_stream(s), alpha, beta),
-      /* const std::vector<array>& inputs = */ broadcasted_inputs);
+      /* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
 }
 
 ///////////////////////////////////////////////////////////////////////////////
@@ -71,16 +71,16 @@ array axpby(
 
 template <typename T>
 void axpby_impl(
-    const array& x,
-    const array& y,
-    array& out,
+    const mx::array& x,
+    const mx::array& y,
+    mx::array& out,
     float alpha_,
     float beta_) {
   // We only allocate memory when we are ready to fill the output
   // malloc_or_wait synchronously allocates available memory
   // There may be a wait executed here if the allocation is requested
   // under memory-pressured conditions
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
 
   // Collect input and output data pointers
   const T* x_ptr = x.data<T>();
@@ -94,8 +94,8 @@ void axpby_impl(
   // Do the element-wise operation for each output
   for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
     // Map linear indices to offsets in x and y
-    auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
-    auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
+    auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides());
+    auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides());
 
     // We allocate the output to be contiguous and regularly strided
     // (defaults to row major) and hence it doesn't need additional mapping
@@ -105,8 +105,8 @@ void axpby_impl(
 
 /** Fall back implementation for evaluation on CPU */
 void Axpby::eval(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
+    const std::vector<mx::array>& inputs,
+    std::vector<mx::array>& outputs) {
   // Check the inputs (registered in the op while constructing the out array)
   assert(inputs.size() == 2);
   auto& x = inputs[0];
@@ -114,14 +114,14 @@ void Axpby::eval(
   auto& out = outputs[0];
 
   // Dispatch to the correct dtype
-  if (out.dtype() == float32) {
+  if (out.dtype() == mx::float32) {
     return axpby_impl<float>(x, y, out, alpha_, beta_);
-  } else if (out.dtype() == float16) {
-    return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
-  } else if (out.dtype() == bfloat16) {
-    return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
-  } else if (out.dtype() == complex64) {
-    return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
+  } else if (out.dtype() == mx::float16) {
+    return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_);
+  } else if (out.dtype() == mx::bfloat16) {
+    return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_);
+  } else if (out.dtype() == mx::complex64) {
+    return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_);
   } else {
     throw std::runtime_error(
         "Axpby is only supported for floating point types.");
@@ -136,9 +136,9 @@ void Axpby::eval(
 
 template <typename T>
 void axpby_impl_accelerate(
-    const array& x,
-    const array& y,
-    array& out,
+    const mx::array& x,
+    const mx::array& y,
+    mx::array& out,
     float alpha_,
     float beta_) {
   // Accelerate library provides catlas_saxpby which does
@@ -150,10 +150,10 @@ void axpby_impl_accelerate(
   // The data in the output array is allocated to match the strides in y
   // such that x, y, and out are contiguous in the same mode and
   // no transposition is needed
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
 
   // We then copy over the elements using the contiguous vector specialization
-  copy_inplace(y, out, CopyType::Vector);
+  copy_inplace(y, out, mx::CopyType::Vector);
 
   // Get x and y pointers for catlas_saxpby
   const T* x_ptr = x.data<T>();
@@ -175,15 +175,15 @@ void axpby_impl_accelerate(
 
 /** Evaluate primitive on CPU using accelerate specializations */
 void Axpby::eval_cpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
+    const std::vector<mx::array>& inputs,
+    std::vector<mx::array>& outputs) {
   assert(inputs.size() == 2);
   auto& x = inputs[0];
   auto& y = inputs[1];
   auto& out = outputs[0];
 
   // Accelerate specialization for contiguous single precision float arrays
-  if (out.dtype() == float32 &&
+  if (out.dtype() == mx::float32 &&
       ((x.flags().row_contiguous && y.flags().row_contiguous) ||
        (x.flags().col_contiguous && y.flags().col_contiguous))) {
     axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
@@ -198,8 +198,8 @@ void Axpby::eval_cpu(
 
 /** Evaluate primitive on CPU falling back to common backend */
 void Axpby::eval_cpu(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs) {
+    const std::vector<mx::array>& inputs,
+    std::vector<mx::array>& outputs) {
   eval(inputs, outputs);
 }
 
@@ -213,8 +213,8 @@ void Axpby::eval_cpu(
 
 /** Evaluate primitive on GPU */
 void Axpby::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
+    const std::vector<mx::array>& inputs,
+    std::vector<mx::array>& outputs) {
   // Prepare inputs
   assert(inputs.size() == 2);
   auto& x = inputs[0];
@@ -225,7 +225,7 @@ void Axpby::eval_gpu(
   // and each stream carries its device identifiers
   auto& s = stream();
   // We get the needed metal device using the stream
-  auto& d = metal::device(s.device);
+  auto& d = mx::metal::device(s.device);
 
   // Prepare to specialize based on contiguity
   bool contiguous_kernel =
@@ -235,12 +235,12 @@ void Axpby::eval_gpu(
   // Allocate output memory with strides based on specialization
   if (contiguous_kernel) {
     out.set_data(
-        allocator::malloc_or_wait(x.data_size() * out.itemsize()),
+        mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
         x.data_size(),
         x.strides(),
         x.flags());
```
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
// Resolve name of kernel (corresponds to axpby.metal)
|
||||
@@ -302,8 +302,8 @@ void Axpby::eval_gpu(
|
||||
|
||||
/** Fail evaluation on GPU */
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& out) {
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& out) {
|
||||
throw std::runtime_error("Axpby has no GPU implementation.");
|
||||
}
|
||||
|
||||
@@ -314,9 +314,9 @@ void Axpby::eval_gpu(
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
std::vector<array> Axpby::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
std::vector<mx::array> Axpby::jvp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
// Forward mode diff that pushes along the tangents
|
||||
// The jvp transform on the primitive can built with ops
|
||||
@@ -328,8 +328,8 @@ std::vector<array> Axpby::jvp(
|
||||
// scaled by beta
|
||||
if (argnums.size() > 1) {
|
||||
auto scale = argnums[0] == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, tangents[0].dtype());
|
||||
return {multiply(scale_arr, tangents[0], stream())};
|
||||
auto scale_arr = mx::array(scale, tangents[0].dtype());
|
||||
return {mx::multiply(scale_arr, tangents[0], stream())};
|
||||
}
|
||||
// If, argnums = {0, 1}, we take contributions from both
|
||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||
@@ -339,24 +339,24 @@ std::vector<array> Axpby::jvp(
|
||||
}
|
||||
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> Axpby::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
std::vector<mx::array> Axpby::vjp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
const std::vector<mx::array>&) {
|
||||
// Reverse mode diff
|
||||
std::vector<array> vjps;
|
||||
std::vector<mx::array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
auto scale = arg == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, cotangents[0].dtype());
|
||||
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
|
||||
auto scale_arr = mx::array(scale, cotangents[0].dtype());
|
||||
vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream()));
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
/** Vectorize primitive along given axis */
|
||||
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
std::pair<std::vector<mx::array>, std::vector<int>> Axpby::vmap(
|
||||
const std::vector<mx::array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
throw std::runtime_error("Axpby has no vmap implementation.");
|
||||
}
|
||||
@@ -367,4 +367,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
|
||||
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace my_ext
|
||||
|
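The hunks above complete the move of the example op and primitive into the `my_ext` namespace, with explicit `mx::` qualifiers wherever the core library is touched. A minimal usage sketch of the resulting op, assuming the extension target from this example is built and `axpby/axpby.h` is the include path the example project exposes (everything else here is illustrative):

```cpp
#include "mlx/mlx.h"
#include "axpby/axpby.h" // header changed later in this diff; path assumed

namespace mx = mlx::core;

int main() {
  auto x = mx::ones({3, 4});
  auto y = mx::ones({3, 4});
  // Dispatches through Axpby::eval_cpu / eval_gpu depending on the default device.
  auto out = my_ext::axpby(x, y, /* alpha = */ 4.0f, /* beta = */ 2.0f);
  mx::eval(out); // every element is 4 * 1 + 2 * 1 = 6
  return 0;
}
```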
@@ -5,7 +5,9 @@

#include "mlx/ops.h"
#include "mlx/primitives.h"

namespace mlx::core {
namespace mx = mlx::core;

namespace my_ext {

///////////////////////////////////////////////////////////////////////////////
// Operation

@@ -18,22 +20,22 @@ namespace mlx::core {

* Follow numpy style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
const array& x, // Input array x
const array& y, // Input array y
mx::array axpby(
const mx::array& x, // Input array x
const mx::array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s = {} // Stream on which to schedule the operation
mx::StreamOrDevice s = {} // Stream on which to schedule the operation
);

///////////////////////////////////////////////////////////////////////////////
// Primitive
///////////////////////////////////////////////////////////////////////////////

class Axpby : public Primitive {
class Axpby : public mx::Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta) {};
explicit Axpby(mx::Stream stream, float alpha, float beta)
: mx::Primitive(stream), alpha_(alpha), beta_(beta) {};

/**
* A primitive must know how to evaluate itself on the CPU/GPU

@@ -42,23 +44,25 @@ class Axpby : public Primitive {

* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) override;
void eval_gpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) override;

/** The Jacobian-vector product. */
std::vector<array> jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
std::vector<mx::array> jvp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& tangents,
const std::vector<int>& argnums) override;

/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
std::vector<mx::array> vjp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
const std::vector<mx::array>& outputs) override;

/**
* The primitive must know how to vectorize itself across

@@ -66,8 +70,8 @@ class Axpby : public Primitive {

* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
std::pair<std::vector<mx::array>, std::vector<int>> vmap(
const std::vector<mx::array>& inputs,
const std::vector<int>& axes) override;

/** Print the primitive. */

@@ -76,14 +80,16 @@ class Axpby : public Primitive {

}

/** Equivalence check **/
bool is_equivalent(const Primitive& other) const override;
bool is_equivalent(const mx::Primitive& other) const override;

private:
float alpha_;
float beta_;

/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
void eval(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs);
};

} // namespace mlx::core
} // namespace my_ext
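Because `axpby` computes `alpha * x + beta * y`, which is linear in both inputs, the `jvp` and `vjp` declared above reduce to rescaling tangents and cotangents. A small sketch of that identity built from ordinary ops (a hypothetical helper for illustration, not part of this header):

```cpp
#include "mlx/mlx.h"

namespace mx = mlx::core;

// Forward mode: pushing tangents (dx, dy) through out = alpha * x + beta * y
// yields d_out = alpha * dx + beta * dy. Reverse mode scales a cotangent g by
// alpha for x and by beta for y, which is what Axpby::vjp returns.
mx::array axpby_jvp_by_hand(
    const mx::array& dx,
    const mx::array& dy,
    float alpha,
    float beta) {
  return mx::add(
      mx::multiply(mx::array(alpha), dx), mx::multiply(mx::array(beta), dy));
}
```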
@@ -12,8 +12,8 @@ template <typename T>

constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
constant const int* shape [[buffer(5)]],
constant const size_t* x_strides [[buffer(6)]],
constant const size_t* y_strides [[buffer(7)]],
constant const int64_t* x_strides [[buffer(6)]],
constant const int64_t* y_strides [[buffer(7)]],
constant const int& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) {
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);

@@ -34,29 +34,14 @@ template <typename T>

static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
}

#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
uint index [[thread_position_in_grid]]);
// clang-format off
#define instantiate_axpby(type_name, type) \
instantiate_kernel("axpby_general_" #type_name, axpby_general, type) \
instantiate_kernel( \
"axpby_contiguous_" #type_name, axpby_contiguous, type)

instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
// clang-format on
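The switch to `instantiate_kernel` keeps the same host names, so the host side can still assemble `axpby_general_<type>` or `axpby_contiguous_<type>` from the output dtype. A hedged sketch of that lookup; `type_to_name` and its header come from the main repo's Metal backend, so treat the include path as an assumption:

```cpp
#include <sstream>
#include <string>

#include "mlx/backend/metal/utils.h" // for type_to_name; path assumed
#include "mlx/mlx.h"

namespace mx = mlx::core;

// Mirrors the kernel-name resolution referenced in Axpby::eval_gpu: contiguous
// inputs select the "contiguous" variant, everything else the "general" one.
std::string axpby_kernel_name(const mx::array& out, bool contiguous_kernel) {
  std::ostringstream kname;
  kname << "axpby_" << (contiguous_kernel ? "contiguous_" : "general_")
        << mx::type_to_name(out);
  return kname.str(); // e.g. "axpby_general_float32"
}
```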
@@ -8,14 +8,12 @@

namespace nb = nanobind;
using namespace nb::literals;

using namespace mlx::core;

NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";

m.def(
"axpby",
&axpby,
&my_ext::axpby,
"x"_a,
"y"_a,
"alpha"_a,
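The binding now points at `my_ext::axpby` and the file-level `using namespace mlx::core;` is dropped. A hedged sketch of how such a `def()` can also expose a keyword-only `stream` argument with a docstring; `nb::kw_only` and `nb::none` are standard nanobind, while the module name, include path, and defaults here are illustrative rather than what this diff ships:

```cpp
#include <nanobind/nanobind.h>
#include <nanobind/stl/variant.h> // lets nanobind convert the StreamOrDevice variant

#include "axpby/axpby.h" // my_ext::axpby, declared earlier in this diff

namespace nb = nanobind;
using namespace nb::literals;

NB_MODULE(_ext_sketch, m) {
  m.doc() = "Sample extension for MLX";
  m.def(
      "axpby",
      &my_ext::axpby,
      "x"_a,
      "y"_a,
      "alpha"_a,
      "beta"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      "Scale and sum two arrays element-wise: alpha * x + beta * y");
}
```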
@@ -18,6 +18,16 @@ target_sources(

${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)

if(MSVC)
# Disable some MSVC warnings to speed up compilation.
target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
endif()

if(WIN32)
# Export symbols by default to behave like macOS/linux.
set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
endif()

if(MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else()
@@ -31,7 +31,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)

}

array::array(
std::vector<int> shape,
Shape shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)

@@ -42,7 +42,7 @@ array::array(

std::move(inputs))) {}

std::vector<array> array::make_arrays(
std::vector<std::vector<int>> shapes,
std::vector<Shape> shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs) {

@@ -74,11 +74,7 @@ array::array(std::initializer_list<int> data, Dtype dtype)

}

/* Build an array from a shared buffer */
array::array(
allocator::Buffer data,
std::vector<int> shape,
Dtype dtype,
deleter_t deleter)
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
set_data(data, deleter);
}

@@ -126,7 +122,7 @@ bool array::is_tracer() const {

return array_desc_->is_tracer && in_tracing() || retain_graph();
}

void array::set_data(allocator::Buffer buffer, deleter_t d) {
void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = size();

@@ -139,9 +135,9 @@ void array::set_data(allocator::Buffer buffer, deleter_t d) {

void array::set_data(
allocator::Buffer buffer,
size_t data_size,
std::vector<size_t> strides,
Strides strides,
Flags flags,
deleter_t d) {
Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = data_size;

@@ -151,7 +147,7 @@ void array::set_data(

void array::copy_shared_buffer(
const array& other,
const std::vector<size_t>& strides,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {

@@ -170,7 +166,7 @@ void array::copy_shared_buffer(const array& other) {

void array::move_shared_buffer(
array other,
const std::vector<size_t>& strides,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {

@@ -237,13 +233,13 @@ void array::ArrayDesc::init() {

}
}

array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
init();
}

array::ArrayDesc::ArrayDesc(
std::vector<int> shape,
Shape shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)
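The shared-buffer constructor collapsed above now takes the `Shape` and `Deleter` aliases directly. A minimal sketch of wrapping a freshly allocated buffer with a custom deleter, assuming only the signatures shown in this hunk (the buffer contents are left uninitialized; this only demonstrates the call):

```cpp
#include "mlx/mlx.h"

namespace mx = mlx::core;

mx::array wrap_buffer() {
  mx::Shape shape = {2, 3};
  // Allocate through the MLX allocator so the custom deleter can hand the
  // buffer back to it.
  auto buf = mx::allocator::malloc_or_wait(2 * 3 * sizeof(float));
  mx::Deleter del = [](mx::allocator::Buffer b) { mx::allocator::free(b); };
  return mx::array(buf, std::move(shape), mx::float32, del);
}
```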
51 mlx/array.h

@@ -15,7 +15,10 @@ namespace mlx::core {

// Forward declaration
class Primitive;
using deleter_t = std::function<void(allocator::Buffer)>;

using Deleter = std::function<void(allocator::Buffer)>;
using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;

class array {
/* An array is really a node in a graph. It contains a shared ArrayDesc

@@ -33,7 +36,7 @@ class array {

template <typename It>
array(
It data,
std::vector<int> shape,
Shape shape,
Dtype dtype =
TypeToDtype<typename std::iterator_traits<It>::value_type>());

@@ -49,15 +52,15 @@ class array {

template <typename T>
array(
std::initializer_list<T> data,
std::vector<int> shape,
Shape shape,
Dtype dtype = TypeToDtype<T>());

/* Build an array from a buffer */
array(
allocator::Buffer data,
std::vector<int> shape,
Shape shape,
Dtype dtype,
deleter_t deleter = allocator::free);
Deleter deleter = allocator::free);

/** Assignment to rvalue does not compile. */
array& operator=(const array& other) && = delete;

@@ -96,7 +99,7 @@ class array {

}

/** The shape of the array as a vector of integers. */
const std::vector<int>& shape() const {
const Shape& shape() const {
return array_desc_->shape;
}

@@ -105,12 +108,12 @@ class array {

*
* This function supports negative indexing and provides
* bounds checking. */
int shape(int dim) const {
auto shape(int dim) const {
return shape().at(dim < 0 ? dim + ndim() : dim);
}

/** The strides of the array. */
const std::vector<size_t>& strides() const {
const Strides& strides() const {
return array_desc_->strides;
}

@@ -119,7 +122,7 @@ class array {

*
* This function supports negative indexing and provides
* bounds checking. */
size_t strides(int dim) const {
auto strides(int dim) const {
return strides().at(dim < 0 ? dim + ndim() : dim);
}

@@ -184,13 +187,13 @@ class array {

*/

array(
std::vector<int> shape,
Shape shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);

static std::vector<array> make_arrays(
std::vector<std::vector<int>> shapes,
std::vector<Shape> shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs);

@@ -207,8 +210,8 @@ class array {

struct Data {
allocator::Buffer buffer;
deleter_t d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
Deleter d;
Data(allocator::Buffer buffer, Deleter d = allocator::free)
: buffer(buffer), d(d) {}
// Not copyable
Data(const Data& d) = delete;

@@ -397,18 +400,18 @@ class array {

// Check if the array is a tracer array
bool is_tracer() const;

void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
void set_data(allocator::Buffer buffer, Deleter d = allocator::free);

void set_data(
allocator::Buffer buffer,
size_t data_size,
std::vector<size_t> strides,
Strides strides,
Flags flags,
deleter_t d = allocator::free);
Deleter d = allocator::free);

void copy_shared_buffer(
const array& other,
const std::vector<size_t>& strides,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset = 0);

@@ -417,7 +420,7 @@ class array {

void move_shared_buffer(
array other,
const std::vector<size_t>& strides,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset = 0);

@@ -436,8 +439,8 @@ class array {

void init(const It src);

struct ArrayDesc {
std::vector<int> shape;
std::vector<size_t> strides;
Shape shape;
Strides strides;
size_t size;
Dtype dtype;
std::shared_ptr<Primitive> primitive;

@@ -471,10 +474,10 @@ class array {

// The arrays position in the output list
uint32_t position{0};

explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
explicit ArrayDesc(Shape shape, Dtype dtype);

explicit ArrayDesc(
std::vector<int> shape,
Shape shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);

@@ -502,7 +505,7 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)

template <typename It>
array::array(
It data,
std::vector<int> shape,
Shape shape,
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
init(data);

@@ -521,7 +524,7 @@ array::array(

template <typename T>
array::array(
std::initializer_list<T> data,
std::vector<int> shape,
Shape shape,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
if (data.size() != size()) {
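With the retyped accessors above, `shape()` returns `Shape` (32-bit extents), `strides()` returns `Strides` (signed 64-bit), and the per-dimension overloads deduce their return types. A small sketch of reading them, assuming nothing beyond the declarations in this header:

```cpp
#include "mlx/mlx.h"

namespace mx = mlx::core;

void inspect(const mx::array& a) {
  const mx::Shape& shape = a.shape();       // std::vector<int32_t>
  const mx::Strides& strides = a.strides(); // std::vector<int64_t>
  // The per-dimension overloads are bounds-checked and accept negative indices.
  auto last_extent = a.shape(-1);
  auto last_stride = a.strides(-1);
  (void)shape;
  (void)strides;
  (void)last_extent;
  (void)last_stride;
}
```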
@@ -43,6 +43,7 @@ DEFAULT(NumberOfElements)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(ExpandDims)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Gather)
|
||||
@@ -65,7 +66,6 @@ DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT_MULTI(QRF)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reshape)
|
||||
DEFAULT(Remainder)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scatter)
|
||||
@@ -76,6 +76,7 @@ DEFAULT(Slice)
|
||||
DEFAULT(SliceUpdate)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT(Squeeze)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT_MULTI(SVD)
|
||||
DEFAULT(Transpose)
|
||||
|
@@ -5,13 +5,21 @@ else()
|
||||
set(COMPILER ${CMAKE_CXX_COMPILER})
|
||||
endif()
|
||||
|
||||
if(MSVC)
|
||||
set(SHELL_EXT ps1)
|
||||
set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)
|
||||
else()
|
||||
set(SHELL_EXT sh)
|
||||
set(SHELL_CMD /bin/bash)
|
||||
endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT compiled_preamble.cpp
|
||||
COMMAND
|
||||
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||
${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
|
||||
${PROJECT_SOURCE_DIR} ${CLANG}
|
||||
DEPENDS make_compiled_preamble.sh
|
||||
${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
|
||||
DEPENDS make_compiled_preamble.${SHELL_EXT}
|
||||
compiled_preamble.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
|
||||
|
@@ -13,8 +13,8 @@ template <typename InT, typename OpT>
|
||||
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||
auto axis_size = in.shape()[axis];
|
||||
auto axis_stride = in.strides()[axis];
|
||||
std::vector<size_t> strides = in.strides();
|
||||
std::vector<int> shape = in.shape();
|
||||
Strides strides = in.strides();
|
||||
Shape shape = in.shape();
|
||||
strides.erase(strides.begin() + axis);
|
||||
shape.erase(shape.begin() + axis);
|
||||
for (uint32_t i = 0; i < out.size(); ++i) {
|
||||
|
@@ -178,10 +178,10 @@ void binary_op_dims(
|
||||
const T* b,
|
||||
U* out,
|
||||
Op op,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& a_strides,
|
||||
const std::vector<size_t>& b_strides,
|
||||
const std::vector<size_t>& out_strides,
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
const Strides& out_strides,
|
||||
int axis) {
|
||||
auto stride_a = a_strides[axis];
|
||||
auto stride_b = b_strides[axis];
|
||||
@@ -212,10 +212,10 @@ void binary_op_dispatch_dims(
|
||||
array& out,
|
||||
Op op,
|
||||
int dim,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& a_strides,
|
||||
const std::vector<size_t>& b_strides,
|
||||
const std::vector<size_t>& out_strides) {
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
const Strides& out_strides) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
@@ -258,10 +258,10 @@ void binary_op_dispatch_dims(
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
|
||||
ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
|
||||
size_t stride = out_strides[dim - 4];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
ContiguousIterator a_it(shape, a_strides, dim - 3);
|
||||
ContiguousIterator b_it(shape, b_strides, dim - 3);
|
||||
auto stride = out_strides[dim - 4];
|
||||
for (int64_t elem = 0; elem < a.size(); elem += stride) {
|
||||
binary_op_dims<T, U, Op, 3, Strided>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
@@ -327,7 +327,7 @@ void binary_op(
|
||||
const auto& strides = new_strides[2];
|
||||
|
||||
// Get the left-most dim such that the array is row contiguous after
|
||||
auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) {
|
||||
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
|
||||
}
|
||||
@@ -337,7 +337,7 @@ void binary_op(
|
||||
auto b_rc_dim = leftmost_rc_dim(b_strides);
|
||||
|
||||
// Get the left-most dim such that the array is a broadcasted "scalar" after
|
||||
auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) {
|
||||
auto leftmost_s_dim = [](const auto& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == 0; d--) {
|
||||
}
|
||||
|
@@ -16,10 +16,10 @@ void binary_op_dims(
|
||||
U* out_a,
|
||||
U* out_b,
|
||||
Op op,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& a_strides,
|
||||
const std::vector<size_t>& b_strides,
|
||||
const std::vector<size_t>& out_strides,
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
const Strides& out_strides,
|
||||
int axis) {
|
||||
auto stride_a = a_strides[axis];
|
||||
auto stride_b = b_strides[axis];
|
||||
@@ -96,9 +96,9 @@ void binary_op_dispatch_dims(
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
|
||||
size_t stride = out_strides[ndim - 3];
|
||||
ContiguousIterator a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator b_it(shape, b_strides, ndim - 2);
|
||||
auto stride = out_strides[ndim - 3];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
binary_op_dims<T, U, Op, 2>(
|
||||
a_ptr + a_it.loc,
|
||||
|
@@ -49,7 +49,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
std::vector<size_t> strides(out.ndim(), 0);
|
||||
Strides strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
@@ -85,6 +85,16 @@ void Depends::eval(
|
||||
}
|
||||
}
|
||||
|
||||
void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
auto strides = in.strides();
|
||||
for (auto ax : axes_) {
|
||||
strides.insert(strides.begin() + ax, 1);
|
||||
}
|
||||
move_or_copy(in, out, strides, in.flags(), in.data_size());
|
||||
}
|
||||
|
||||
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
@@ -141,9 +151,7 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
|
||||
// Special case for empty arrays or row contiguous arrays
|
||||
if (in.size() == 0 || in.flags().row_contiguous) {
|
||||
return {false, out.strides()};
|
||||
@@ -151,8 +159,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
|
||||
|
||||
// Special case for scalars
|
||||
if (in.ndim() == 0) {
|
||||
std::vector<size_t> out_strides(out.ndim(), 0);
|
||||
return {false, out_strides};
|
||||
return {false, Strides(out.ndim(), 0)};
|
||||
}
|
||||
|
||||
// Firstly let's collapse all the contiguous dimensions of the input
|
||||
@@ -160,7 +167,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
|
||||
|
||||
// If shapes fit exactly in the contiguous dims then no copy is necessary so
|
||||
// let's check.
|
||||
std::vector<size_t> out_strides;
|
||||
Strides out_strides;
|
||||
bool copy_necessary = false;
|
||||
int j = 0;
|
||||
for (int i = 0; i < out.ndim(); i++) {
|
||||
@@ -181,9 +188,9 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
|
||||
return {copy_necessary, out_strides};
|
||||
}
|
||||
|
||||
void Reshape::shared_buffer_reshape(
|
||||
void shared_buffer_reshape(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
const Strides& out_strides,
|
||||
array& out) {
|
||||
auto flags = in.flags();
|
||||
if (flags.row_contiguous) {
|
||||
@@ -249,16 +256,18 @@ void Split::eval(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices_[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides_[i];
|
||||
void Squeeze::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
Strides strides;
|
||||
for (int i = 0, j = 0; i < in.ndim(); ++i) {
|
||||
if (j < axes_.size() && i == axes_[j]) {
|
||||
j++;
|
||||
} else {
|
||||
strides.push_back(in.strides(i));
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_tuple(data_offset, inp_strides);
|
||||
move_or_copy(in, out, strides, in.flags(), in.data_size());
|
||||
}
|
||||
|
||||
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -268,7 +277,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
std::vector<size_t> out_strides(out.ndim());
|
||||
Strides out_strides(out.ndim());
|
||||
auto& in = inputs[0];
|
||||
for (int ax = 0; ax < axes_.size(); ++ax) {
|
||||
out_strides[ax] = in.strides()[axes_[ax]];
|
||||
@@ -285,8 +294,8 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
// true, they stay true)
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous && in.data_size() == in.size()) {
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
int64_t f_stride = 1;
|
||||
int64_t b_stride = 1;
|
||||
flags.col_contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
|
||||
|
@@ -165,7 +165,7 @@ void compiled_allocate_outputs(
|
||||
bool move_buffers /* = false */) {
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
std::vector<size_t> strides;
|
||||
Strides strides;
|
||||
size_t data_size;
|
||||
array::Flags flags;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
|
@@ -746,9 +746,9 @@ void explicit_gemm_conv_1D_cpu(
|
||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
|
||||
|
||||
// Make strided view
|
||||
std::vector<int> strided_shape = {N, oH, wH, C};
|
||||
Shape strided_shape = {N, oH, wH, C};
|
||||
|
||||
std::vector<size_t> strided_strides = {
|
||||
Strides strided_strides = {
|
||||
in_padded.strides()[0],
|
||||
in_padded.strides()[1] * wt_strides[0],
|
||||
in_padded.strides()[1],
|
||||
@@ -865,9 +865,9 @@ void explicit_gemm_conv_2D_cpu(
|
||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
|
||||
|
||||
// Make strided view
|
||||
std::vector<int> strided_shape = {N, oH, oW, wH, wW, C};
|
||||
Shape strided_shape = {N, oH, oW, wH, wW, C};
|
||||
|
||||
std::vector<size_t> strided_strides = {
|
||||
Strides strided_strides = {
|
||||
in_padded.strides()[0],
|
||||
in_padded.strides()[1] * wt_strides[0],
|
||||
in_padded.strides()[2] * wt_strides[1],
|
||||
@@ -974,7 +974,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
|
||||
|
||||
// Make strided view
|
||||
std::vector<int> strided_shape(oDim.size() + wDim.size() + 2);
|
||||
Shape strided_shape(oDim.size() + wDim.size() + 2);
|
||||
strided_shape.front() = N;
|
||||
for (size_t i = 0; i < oDim.size(); i++) {
|
||||
strided_shape[i + 1] = oDim[i];
|
||||
@@ -984,7 +984,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
}
|
||||
strided_shape.back() = C;
|
||||
|
||||
std::vector<size_t> strided_strides(in.shape().size() * 2 - 2);
|
||||
Strides strided_strides(in.shape().size() * 2 - 2);
|
||||
strided_strides[0] = in_padded.strides()[0];
|
||||
for (size_t i = 0; i < wt_strides.size(); i++) {
|
||||
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
|
||||
@@ -1000,7 +1000,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||
|
||||
// Materialize strided view
|
||||
std::vector<int> strided_reshape = {N, C};
|
||||
Shape strided_reshape = {N, C};
|
||||
for (const auto& o : oDim) {
|
||||
strided_reshape[0] *= o;
|
||||
}
|
||||
|
@@ -26,13 +26,13 @@ void copy_vector(const array& src, array& dst) {
|
||||
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename StrideT, int D>
|
||||
template <typename SrcT, typename DstT, int D>
|
||||
inline void copy_dims(
|
||||
const SrcT* src,
|
||||
DstT* dst,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<StrideT>& i_strides,
|
||||
const std::vector<StrideT>& o_strides,
|
||||
const Shape& shape,
|
||||
const Strides& i_strides,
|
||||
const Strides& o_strides,
|
||||
int axis) {
|
||||
auto stride_src = i_strides[axis];
|
||||
auto stride_dst = o_strides[axis];
|
||||
@@ -40,7 +40,7 @@ inline void copy_dims(
|
||||
|
||||
for (int i = 0; i < N; i++) {
|
||||
if constexpr (D > 1) {
|
||||
copy_dims<SrcT, DstT, StrideT, D - 1>(
|
||||
copy_dims<SrcT, DstT, D - 1>(
|
||||
src, dst, shape, i_strides, o_strides, axis + 1);
|
||||
} else {
|
||||
*dst = static_cast<DstT>(*src);
|
||||
@@ -50,13 +50,13 @@ inline void copy_dims(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename StrideT>
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<StrideT>& i_strides,
|
||||
const std::vector<StrideT>& o_strides,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset) {
|
||||
if (data_shape.empty()) {
|
||||
@@ -65,30 +65,30 @@ void copy_general_general(
|
||||
*dst_ptr = val;
|
||||
return;
|
||||
}
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
data_shape, std::vector<std::vector<StrideT>>{i_strides, o_strides});
|
||||
auto [shape, strides] =
|
||||
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
|
||||
auto src_ptr = src.data<SrcT>() + i_offset;
|
||||
auto dst_ptr = dst.data<DstT>() + o_offset;
|
||||
int ndim = shape.size();
|
||||
if (ndim == 1) {
|
||||
copy_dims<SrcT, DstT, StrideT, 1>(
|
||||
copy_dims<SrcT, DstT, 1>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
return;
|
||||
} else if (ndim == 2) {
|
||||
copy_dims<SrcT, DstT, StrideT, 2>(
|
||||
copy_dims<SrcT, DstT, 2>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
return;
|
||||
} else if (ndim == 3) {
|
||||
copy_dims<SrcT, DstT, StrideT, 3>(
|
||||
copy_dims<SrcT, DstT, 3>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
return;
|
||||
}
|
||||
ContiguousIterator<StrideT> in(shape, strides[0], ndim - 3);
|
||||
ContiguousIterator<StrideT> out(shape, strides[1], ndim - 3);
|
||||
StrideT stride = std::accumulate(
|
||||
shape.end() - 3, shape.end(), 1, std::multiplies<StrideT>());
|
||||
for (StrideT elem = 0; elem < src.size(); elem += stride) {
|
||||
copy_dims<SrcT, DstT, StrideT, 3>(
|
||||
ContiguousIterator in(shape, strides[0], ndim - 3);
|
||||
ContiguousIterator out(shape, strides[1], ndim - 3);
|
||||
auto stride = std::accumulate(
|
||||
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
|
||||
for (int64_t elem = 0; elem < src.size(); elem += stride) {
|
||||
copy_dims<SrcT, DstT, 3>(
|
||||
src_ptr + in.loc,
|
||||
dst_ptr + out.loc,
|
||||
shape,
|
||||
@@ -102,37 +102,37 @@ void copy_general_general(
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general_general(const array& src, array& dst) {
|
||||
copy_general_general<SrcT, DstT, size_t>(
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename StrideT>
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<StrideT>& i_strides,
|
||||
const std::vector<StrideT>&,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides&,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset) {
|
||||
copy_general_general<SrcT, DstT, StrideT>(
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src,
|
||||
dst,
|
||||
data_shape,
|
||||
i_strides,
|
||||
make_contiguous_strides<StrideT>(data_shape),
|
||||
make_contiguous_strides(data_shape),
|
||||
i_offset,
|
||||
o_offset);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general(const array& src, array& dst) {
|
||||
copy_general_general<SrcT, DstT, size_t>(
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src,
|
||||
dst,
|
||||
src.shape(),
|
||||
src.strides(),
|
||||
make_contiguous_strides<size_t>(src.shape()),
|
||||
make_contiguous_strides(src.shape()),
|
||||
0,
|
||||
0);
|
||||
}
|
||||
@@ -282,13 +282,12 @@ void copy(const array& src, array& dst, CopyType ctype) {
|
||||
copy_inplace(src, dst, ctype);
|
||||
}
|
||||
|
||||
template <typename StrideT>
|
||||
void copy_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<StrideT>& i_strides,
|
||||
const std::vector<StrideT>& o_strides,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype) {
|
||||
@@ -311,24 +310,4 @@ void copy_inplace(
|
||||
}
|
||||
}
|
||||
|
||||
template void copy_inplace<size_t>(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<size_t>& i_strides,
|
||||
const std::vector<size_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype);
|
||||
|
||||
template void copy_inplace<int64_t>(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<int64_t>& i_strides,
|
||||
const std::vector<int64_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -26,13 +26,12 @@ enum class CopyType {
|
||||
void copy(const array& src, array& dst, CopyType ctype);
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype);
|
||||
|
||||
template <typename stride_t>
|
||||
void copy_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype);
|
||||
|
@@ -57,6 +57,7 @@ DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(Exp)
|
||||
DEFAULT(ExpandDims)
|
||||
DEFAULT(Expm1)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
@@ -86,7 +87,6 @@ DEFAULT_MULTI(QRF)
|
||||
DEFAULT(QuantizedMatmul)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reduce)
|
||||
DEFAULT(Reshape)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scan)
|
||||
DEFAULT(Scatter)
|
||||
@@ -101,6 +101,7 @@ DEFAULT(Softmax)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Square)
|
||||
DEFAULT(Squeeze)
|
||||
DEFAULT(Sqrt)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT(Subtract)
|
||||
@@ -130,7 +131,7 @@ inline void matmul_common_general(
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
size_t stx = arr.shape(-1);
|
||||
stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
|
@@ -32,7 +32,7 @@ void gather(
|
||||
const std::vector<array>& inds,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& slice_sizes) {
|
||||
const Shape& slice_sizes) {
|
||||
// If the array is row contiguous then we can do a contiguous copy given
|
||||
// two conditions on the slice size:
|
||||
// - Any number of leading ones in the slice sizes are allowed
|
||||
@@ -80,11 +80,10 @@ void gather(
|
||||
T* dst_ptr = out.data<T>();
|
||||
size_t out_idx = 0;
|
||||
|
||||
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
|
||||
ContiguousIterator<size_t> src_it;
|
||||
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
|
||||
ContiguousIterator src_it;
|
||||
if (!can_copy && src.ndim() > 0) {
|
||||
src_it = std::move(
|
||||
ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
|
||||
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
|
||||
}
|
||||
for (int idx = 0; idx < ind_size; idx++) {
|
||||
size_t src_idx = 0;
|
||||
@@ -119,7 +118,7 @@ void dispatch_gather(
|
||||
const std::vector<array>& inds,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& size) {
|
||||
const Shape& size) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
gather<bool, IdxT>(src, inds, out, axes, size);
|
||||
@@ -223,16 +222,16 @@ void scatter(
|
||||
auto inds_ndim = updates.ndim() - out.ndim();
|
||||
size_t n_updates = nind ? inds[0].size() : 1;
|
||||
|
||||
std::vector<int> update_shape(
|
||||
Shape update_shape(
|
||||
updates.shape().begin() + inds_ndim, updates.shape().end());
|
||||
size_t update_size = 1;
|
||||
for (auto us : update_shape) {
|
||||
update_size *= us;
|
||||
}
|
||||
|
||||
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
|
||||
ContiguousIterator<size_t> update_it(updates);
|
||||
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
|
||||
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
|
||||
ContiguousIterator update_it(updates);
|
||||
ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
|
||||
|
||||
for (int i = 0; i < n_updates; ++i) {
|
||||
size_t out_offset = 0;
|
||||
|
@@ -2,6 +2,15 @@

#pragma once

// Required for Visual Studio.
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
#ifdef _MSC_VER
#include <complex>
#define LAPACK_COMPLEX_CUSTOM
#define lapack_complex_float std::complex<float>
#define lapack_complex_double std::complex<double>
#endif

#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
38
mlx/backend/common/make_compiled_preamble.ps1
Normal file
38
mlx/backend/common/make_compiled_preamble.ps1
Normal file
@@ -0,0 +1,38 @@
|
||||
# This script generates a C++ function that provides the CPU
|
||||
# code for use with kernel generation.
|
||||
#
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
$OUTPUT_FILE = $args[0]
|
||||
$CL = $args[1]
|
||||
$SRCDIR = $args[2]
|
||||
|
||||
# Get command result as array.
|
||||
$CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/compiled_preamble.h"
|
||||
# Remove empty lines.
|
||||
# Otherwise there will be too much empty lines making the result unreadable.
|
||||
$CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }
|
||||
# Concatenate to string.
|
||||
$CONTENT = $CONTENT -join '`n'
|
||||
|
||||
# Append extra content.
|
||||
$CONTENT = @"
|
||||
$($CONTENT)
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::detail;
|
||||
"@
|
||||
|
||||
# Convert each char to ASCII code.
|
||||
# Unlike the unix script that outputs string literal directly, the output from
|
||||
# MSVC is way too large to be embedded as string and compilation will fail, so
|
||||
# we store it as static array instead.
|
||||
$CHARCODES = ([System.Text.Encoding]::ASCII.GetBytes($CONTENT) -join ', ') + ', 0'
|
||||
|
||||
$OUTPUT = @"
|
||||
const char* get_kernel_preamble() {
|
||||
static char preamble[] = { $CHARCODES };
|
||||
return preamble;
|
||||
}
|
||||
"@
|
||||
|
||||
Set-Content -Path $OUTPUT_FILE -Value $OUTPUT
|
@@ -10,15 +10,16 @@ OUTPUT_FILE=$1
|
||||
GCC=$2
|
||||
SRCDIR=$3
|
||||
CLANG=$4
|
||||
ARCH=$5
|
||||
|
||||
if [ "$CLANG" = "TRUE" ]; then
|
||||
read -r -d '' INCLUDES <<- EOM
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
EOM
|
||||
CC_FLAGS=""
|
||||
CC_FLAGS="-arch ${ARCH}"
|
||||
else
|
||||
CC_FLAGS="-std=c++17"
|
||||
fi
|
||||
|
@@ -19,10 +19,10 @@ inline void mask_matrix(
|
||||
int block_size,
|
||||
const int X,
|
||||
const int Y,
|
||||
const size_t X_data_str,
|
||||
const size_t Y_data_str,
|
||||
const size_t X_mask_str,
|
||||
const size_t Y_mask_str,
|
||||
const int64_t X_data_str,
|
||||
const int64_t Y_data_str,
|
||||
const int64_t X_mask_str,
|
||||
const int64_t Y_mask_str,
|
||||
const size_t mask_offset) {
|
||||
int tX = (X + block_size - 1) / block_size;
|
||||
int tY = (Y + block_size - 1) / block_size;
|
||||
@@ -84,7 +84,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
size_t stx = arr.shape(-1);
|
||||
int64_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
@@ -117,13 +117,13 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
int Y,
|
||||
size_t X_data_str,
|
||||
size_t Y_data_str) {
|
||||
size_t mask_offset = elem_to_loc(
|
||||
auto mask_offset = elem_to_loc(
|
||||
mask.shape(-1) * mask.shape(-2) * batch_idx,
|
||||
mask.shape(),
|
||||
mask.strides());
|
||||
|
||||
size_t X_mask_str = mask.strides()[mask.ndim() - 2];
|
||||
size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
|
||||
auto X_mask_str = mask.strides()[mask.ndim() - 2];
|
||||
auto Y_mask_str = mask.strides()[mask.ndim() - 1];
|
||||
|
||||
if (mask.dtype() == bool_) {
|
||||
return mask_matrix(
|
||||
@@ -230,7 +230,7 @@ void GatherMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
size_t stx = arr.shape(-1);
|
||||
int64_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
@@ -262,13 +262,13 @@ void GatherMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
auto& lhs_indices = inputs[2];
|
||||
auto& rhs_indices = inputs[3];
|
||||
|
||||
std::vector<int> batch_shape = get_batch_dims(out.shape());
|
||||
auto batch_shape = get_batch_dims(out.shape());
|
||||
int batch_ndim = batch_shape.size();
|
||||
|
||||
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
|
||||
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
|
||||
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
|
||||
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
|
||||
auto batch_shape_A = get_batch_dims(a.shape());
|
||||
auto batch_strides_A = get_batch_dims(a.strides());
|
||||
auto batch_shape_B = get_batch_dims(b.shape());
|
||||
auto batch_strides_B = get_batch_dims(b.strides());
|
||||
|
||||
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
|
||||
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
|
||||
|
@@ -500,7 +500,12 @@ struct Equal {

struct NaNEqual {
template <typename T>
bool operator()(T x, T y) {
return x == y || (std::isnan(x) && std::isnan(y));
if constexpr (std::is_integral_v<T>) {
// isnan always returns false for integers, and MSVC refuses to compile.
return x == y;
} else {
return x == y || (std::isnan(x) && std::isnan(y));
}
}
};
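The `if constexpr` branch above exists because, as the comment notes, `isnan` is always false for integers and MSVC refuses to compile that instantiation, so integral types fall back to plain equality. A standalone sketch of the same pattern:

```cpp
#include <cmath>
#include <type_traits>

// Integral types can never hold NaN, so they compare directly; floating-point
// types additionally treat NaN == NaN as equal, matching NaNEqual above.
template <typename T>
bool nan_equal(T x, T y) {
  if constexpr (std::is_integral_v<T>) {
    return x == y;
  } else {
    return x == y || (std::isnan(x) && std::isnan(y));
  }
}
```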
@@ -19,6 +19,16 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void reshape(const array& in, array& out) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
copy_inplace(in, out, CopyType::General);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Abs::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@@ -258,6 +268,14 @@ void Expm1::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Flatten::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
}
|
||||
|
||||
void Unflatten::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
}
|
||||
|
||||
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@@ -417,18 +435,8 @@ void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
|
||||
}
|
||||
|
||||
void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
copy_inplace(in, out, CopyType::General);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
}
|
||||
|
||||
void Round::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -498,34 +506,17 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [copy_needed, data_offset, inp_strides] =
|
||||
prepare_slice(in, start_indices_, strides_);
|
||||
|
||||
// Do copy if needed
|
||||
if (copy_needed) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
|
||||
copy_inplace<int64_t>(
|
||||
/* const array& src = */ in,
|
||||
/* array& dst = */ out,
|
||||
/* const std::vector<int>& data_shape = */ out.shape(),
|
||||
/* const std::vector<stride_t>& i_strides = */ inp_strides,
|
||||
/* const std::vector<stride_t>& o_strides = */ ostrides,
|
||||
/* int64_t i_offset = */ data_offset,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::General);
|
||||
} else {
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < end_indices_.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < end_indices_.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
|
||||
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
Strides ostrides{inp_strides.begin(), inp_strides.end()};
|
||||
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
|
||||
}
|
||||
|
||||
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -550,11 +541,11 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, out_strides] = prepare_slice(out);
|
||||
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
|
||||
|
||||
// Do copy
|
||||
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
|
||||
copy_inplace<int64_t>(
|
||||
Strides upd_strides{upd.strides().begin(), upd.strides().end()};
|
||||
copy_inplace(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
/* const std::vector<int>& data_shape = */ upd.shape(),
|
||||
|
@@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r) {
|
||||
// Copy the input to be column contiguous
|
||||
flags.col_contiguous = num_matrices == 1;
|
||||
flags.row_contiguous = false;
|
||||
std::vector<size_t> strides = in.strides();
|
||||
auto strides = in.strides();
|
||||
strides[in.ndim() - 2] = 1;
|
||||
strides[in.ndim() - 1] = M;
|
||||
in.set_data(
|
||||
|
@@ -174,19 +174,19 @@ void reduce_dispatch_min_max(

void nd_loop(
    std::function<void(int)> callback,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
    const Shape& shape,
    const Strides& strides) {
  std::function<void(int, int)> loop_inner;
  loop_inner = [&](int dim, int offset) {
    if (dim < shape.size() - 1) {
      int size = shape[dim];
      size_t stride = strides[dim];
      auto size = shape[dim];
      auto stride = strides[dim];
      for (int i = 0; i < size; i++) {
        loop_inner(dim + 1, offset + i * stride);
      }
    } else {
      int size = shape[dim];
      size_t stride = strides[dim];
      auto size = shape[dim];
      auto stride = strides[dim];
      for (int i = 0; i < size; i++) {
        callback(offset + i * stride);
      }
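For reference, a minimal standalone sketch of the strided N-D traversal that nd_loop implements. The Shape/Strides aliases, the helper name nd_loop_sketch, and the example shape below are assumptions for illustration, not the library's definitions:

// Walk every element location implied by a shape/strides pair, depth first.
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

using Shape = std::vector<int32_t>;   // assumed alias, mirroring the diff
using Strides = std::vector<int64_t>; // assumed alias, mirroring the diff

void nd_loop_sketch(
    const std::function<void(int64_t)>& callback,
    const Shape& shape,
    const Strides& strides) {
  std::function<void(int, int64_t)> loop_inner;
  loop_inner = [&](int dim, int64_t offset) {
    for (int i = 0; i < shape[dim]; ++i) {
      int64_t loc = offset + i * strides[dim];
      if (dim + 1 < static_cast<int>(shape.size())) {
        loop_inner(dim + 1, loc); // recurse into the next axis
      } else {
        callback(loc); // innermost axis: emit the storage offset
      }
    }
  };
  if (!shape.empty()) {
    loop_inner(0, 0);
  }
}

int main() {
  // A 2x3 row-major array: visits offsets 0 1 2 3 4 5 in order.
  nd_loop_sketch([](int64_t loc) { std::cout << loc << ' '; }, {2, 3}, {3, 1});
  std::cout << '\n';
}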
@@ -38,13 +38,10 @@ enum ReductionOpType {

struct ReductionPlan {
  ReductionOpType type;
  std::vector<int> shape;
  std::vector<size_t> strides;
  Shape shape;
  Strides strides;

  ReductionPlan(
      ReductionOpType type_,
      std::vector<int> shape_,
      std::vector<size_t> strides_)
  ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)
      : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
  ReductionPlan(ReductionOpType type_) : type(type_) {}
};
@@ -55,10 +52,10 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
// Should this be in utils?
void nd_loop(
    std::function<void(int)> callback,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides);
    const Shape& shape,
    const Strides& strides);

std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
std::pair<Shape, Strides> shapes_without_reduction_axes(
    const array& x,
    const std::vector<int>& axes);
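The recurring substitution throughout these hunks is std::vector&lt;int&gt; to Shape and std::vector&lt;size_t&gt; to Strides, with strides becoming signed 64-bit. A minimal sketch of what such aliases look like and why signed strides matter; the alias definitions below are assumptions for illustration, not copied from the library:

// Illustrative aliases only; the real definitions live in the MLX headers.
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;

int main() {
  // A reversed view of a length-5 row (e.g. x[::-1]) needs a negative stride,
  // which is why strides are signed 64-bit integers rather than size_t.
  Shape shape = {5};
  Strides strides = {-1};
  int64_t offset = 4; // start at the last element
  for (int i = 0; i < shape[0]; ++i) {
    std::cout << offset + i * strides[0] << ' '; // prints 4 3 2 1 0
  }
  std::cout << '\n';
}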
@@ -113,9 +110,6 @@ void reduction_op(
    return;
  }

  std::vector<int> shape;
  std::vector<size_t> strides;

  if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
    int reduction_size = plan.shape[0];
    const T* x_ptr = x.data<T>();
@@ -135,7 +129,7 @@ void reduction_op(
    U* out_ptr = out.data<U>();
    // Unrolling the following loop (and implementing it in order for
    // ContiguousReduce) should hold extra performance boost.
    std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
    auto [shape, strides] = shapes_without_reduction_axes(x, axes);
    if (plan.shape.size() == 0) {
      for (int i = 0; i < out.size(); i++, out_ptr++) {
        int offset = elem_to_loc(i, shape, strides);
@@ -181,7 +175,7 @@ void reduction_op(
    plan.strides.pop_back();
    const T* x_ptr = x.data<T>();
    U* out_ptr = out.data<U>();
    std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
    auto [shape, strides] = shapes_without_reduction_axes(x, axes);
    if (plan.shape.size() == 0) {
      for (int i = 0; i < out.size(); i += reduction_stride) {
        int offset = elem_to_loc(i, shape, strides);
@@ -211,7 +205,7 @@ void reduction_op(
  if (plan.type == GeneralReduce) {
    const T* x_ptr = x.data<T>();
    U* out_ptr = out.data<U>();
    std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
    auto [shape, strides] = shapes_without_reduction_axes(x, axes);
    for (int i = 0; i < out.size(); i++, out_ptr++) {
      int offset = elem_to_loc(i, shape, strides);
      U val = init;
@@ -4,11 +4,11 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes) {
|
||||
std::vector<int> shape = x.shape();
|
||||
std::vector<size_t> strides = x.strides();
|
||||
auto shape = x.shape();
|
||||
auto strides = x.strides();
|
||||
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int a = axes[i];
|
||||
@@ -29,8 +29,8 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
// Row contiguous input so the output is row contiguous
|
||||
if (x.flags().row_contiguous) {
|
||||
// Merge consecutive axes
|
||||
std::vector<int> shape = {x.shape(axes[0])};
|
||||
std::vector<size_t> strides = {x.strides()[axes[0]]};
|
||||
Shape shape = {x.shape(axes[0])};
|
||||
Strides strides = {x.strides()[axes[0]]};
|
||||
for (int i = 1; i < axes.size(); i++) {
|
||||
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
|
||||
shape.back() *= x.shape(axes[i]);
|
||||
@@ -69,7 +69,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
|
||||
// Sort reduction axes by stride in order to merge them and figure out if we
|
||||
// have a contiguous reduction.
|
||||
std::vector<std::pair<int, size_t>> reductions;
|
||||
std::vector<std::pair<int, int64_t>> reductions;
|
||||
for (auto a : axes) {
|
||||
if (x.shape(a) > 1) {
|
||||
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
|
||||
@@ -93,8 +93,8 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
for (auto r : reductions) {
|
||||
shape.push_back(r.first);
|
||||
strides.push_back(r.second);
|
||||
@@ -109,15 +109,15 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
// Delegate to the general strided reduction op if the axes after
|
||||
// strides.back() are contiguous.
|
||||
if (strides.back() > 1) {
|
||||
int size = 1;
|
||||
int64_t size = 1;
|
||||
bool have_expand = false;
|
||||
for (int i = x.ndim() - 1; i >= 0; i--) {
|
||||
if (axes.back() == i) {
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t stride_i = x.strides()[i];
|
||||
int shape_i = x.shape(i);
|
||||
auto stride_i = x.strides()[i];
|
||||
auto shape_i = x.shape(i);
|
||||
if (stride_i == 0) {
|
||||
if (shape_i == 1) {
|
||||
continue;
|
||||
|
@@ -4,24 +4,22 @@

namespace mlx::core {

std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
std::tuple<int64_t, Strides> prepare_slice(
    const array& in,
    const std::vector<int>& start_indices,
    const std::vector<int>& strides) {
    const Shape& start_indices,
    const Shape& strides) {
  int64_t data_offset = 0;
  bool copy_needed = false;
  std::vector<int64_t> inp_strides(in.ndim(), 0);
  Strides inp_strides(in.ndim(), 0);
  for (int i = 0; i < in.ndim(); ++i) {
    data_offset += start_indices[i] * in.strides()[i];
    inp_strides[i] = in.strides()[i] * strides[i];
    copy_needed |= strides[i] < 0;
  }
  return std::make_tuple(copy_needed, data_offset, inp_strides);
  return std::make_tuple(data_offset, inp_strides);
}

void shared_buffer_slice(
    const array& in,
    const std::vector<size_t>& out_strides,
    const Strides& out_strides,
    size_t data_offset,
    size_t data_size,
    array& out) {
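A worked example of the offset/stride arithmetic prepare_slice performs, as a standalone sketch; the array shape and slice parameters are made up for illustration:

// For a row-major 4x6 array, the view a[1:4:2, 3:6:1] starts at
// offset 1*6 + 3 = 9 and has strides {2*6, 1*1} = {12, 1}.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> in_strides = {6, 1}; // row-major 4x6
  std::vector<int> start = {1, 3};
  std::vector<int> step = {2, 1};

  int64_t data_offset = 0;
  std::vector<int64_t> out_strides(in_strides.size());
  for (size_t i = 0; i < in_strides.size(); ++i) {
    data_offset += start[i] * in_strides[i];
    out_strides[i] = in_strides[i] * step[i];
  }
  std::cout << "offset=" << data_offset << " strides={" << out_strides[0]
            << "," << out_strides[1] << "}\n"; // offset=9 strides={12,1}
}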
@@ -6,14 +6,14 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
|
||||
std::tuple<int64_t, Strides> prepare_slice(
|
||||
const array& in,
|
||||
const std::vector<int>& start_indices,
|
||||
const std::vector<int>& strides);
|
||||
const Shape& start_indices,
|
||||
const Shape& strides);
|
||||
|
||||
void shared_buffer_slice(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
const Strides& out_strides,
|
||||
size_t data_offset,
|
||||
size_t data_size,
|
||||
array& out);
|
||||
|
@@ -25,7 +25,7 @@ struct StridedIterator {
|
||||
// Constructors
|
||||
StridedIterator() = default;
|
||||
|
||||
explicit StridedIterator(T* ptr, size_t stride, difference_type offset = 0)
|
||||
explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
|
||||
: ptr_(ptr + offset * stride), stride_(stride) {}
|
||||
|
||||
explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
|
||||
@@ -99,7 +99,7 @@ struct StridedIterator {
|
||||
}
|
||||
|
||||
private:
|
||||
size_t stride_;
|
||||
int64_t stride_;
|
||||
T* ptr_;
|
||||
};
|
||||
|
||||
@@ -120,11 +120,11 @@ void sort(const array& in, array& out, int axis) {
|
||||
auto remaining_strides = out.strides();
|
||||
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||
|
||||
size_t axis_stride = out.strides()[axis];
|
||||
int axis_size = out.shape(axis);
|
||||
auto axis_stride = out.strides()[axis];
|
||||
auto axis_size = out.shape(axis);
|
||||
|
||||
// Perform sorting in place
|
||||
ContiguousIterator<size_t> src_it(
|
||||
ContiguousIterator src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out.data<T>() + src_it.loc;
|
||||
@@ -158,14 +158,14 @@ void argsort(const array& in, array& out, int axis) {
|
||||
auto out_remaining_strides = out.strides();
|
||||
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
|
||||
|
||||
size_t in_stride = in.strides()[axis];
|
||||
size_t out_stride = out.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
auto in_stride = in.strides()[axis];
|
||||
auto out_stride = out.strides()[axis];
|
||||
auto axis_size = in.shape(axis);
|
||||
|
||||
// Perform sorting
|
||||
ContiguousIterator<size_t> in_it(
|
||||
ContiguousIterator in_it(
|
||||
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
|
||||
ContiguousIterator<size_t> out_it(
|
||||
ContiguousIterator out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in.data<T>() + in_it.loc;
|
||||
@@ -208,13 +208,13 @@ void partition(const array& in, array& out, int axis, int kth) {
|
||||
auto remaining_strides = in.strides();
|
||||
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||
|
||||
size_t axis_stride = in.strides()[axis];
|
||||
auto axis_stride = in.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
|
||||
kth = kth < 0 ? kth + axis_size : kth;
|
||||
|
||||
// Perform partition in place
|
||||
ContiguousIterator<size_t> src_it(
|
||||
ContiguousIterator src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out.data<T>() + src_it.loc;
|
||||
@@ -249,16 +249,16 @@ void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
auto out_remaining_strides = out.strides();
|
||||
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
|
||||
|
||||
size_t in_stride = in.strides()[axis];
|
||||
size_t out_stride = out.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
auto in_stride = in.strides()[axis];
|
||||
auto out_stride = out.strides()[axis];
|
||||
auto axis_size = in.shape(axis);
|
||||
|
||||
kth = kth < 0 ? kth + axis_size : kth;
|
||||
|
||||
// Perform partition
|
||||
ContiguousIterator<size_t> in_it(
|
||||
ContiguousIterator in_it(
|
||||
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
|
||||
ContiguousIterator<size_t> out_it(
|
||||
ContiguousIterator out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in.data<T>() + in_it.loc;
|
||||
|
@@ -78,11 +78,11 @@ void ternary_op_dims(
|
||||
const T3* c,
|
||||
U* out,
|
||||
Op op,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& a_strides,
|
||||
const std::vector<size_t>& b_strides,
|
||||
const std::vector<size_t>& c_strides,
|
||||
const std::vector<size_t>& out_strides,
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
const Strides& c_strides,
|
||||
const Strides& out_strides,
|
||||
int axis) {
|
||||
auto stride_a = a_strides[axis];
|
||||
auto stride_b = b_strides[axis];
|
||||
@@ -164,10 +164,10 @@ void ternary_op_dispatch_dims(
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
|
||||
ContiguousIterator<size_t> c_it(shape, c_strides, ndim - 2);
|
||||
size_t stride = out_strides[ndim - 3];
|
||||
ContiguousIterator a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator b_it(shape, b_strides, ndim - 2);
|
||||
ContiguousIterator c_it(shape, c_strides, ndim - 2);
|
||||
auto stride = out_strides[ndim - 3];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
ternary_op_dims<T1, T2, T3, U, Op, 2>(
|
||||
a_ptr + a_it.loc,
|
||||
|
@@ -15,7 +15,7 @@ void move_or_copy(const array& in, array& out) {
|
||||
void move_or_copy(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<size_t>& strides,
|
||||
const Strides& strides,
|
||||
array::Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
@@ -26,15 +26,13 @@ void move_or_copy(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename StrideT>
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>>
|
||||
collapse_contiguous_dims_impl(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<std::vector<StrideT>>& strides,
|
||||
StrideT size_cap) {
|
||||
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||
const Shape& shape,
|
||||
const std::vector<Strides>& strides,
|
||||
int64_t size_cap) {
|
||||
// Make a vector that has axes separated with -1. Collapse all axes between
|
||||
// -1.
|
||||
std::vector<int> to_collapse;
|
||||
Shape to_collapse;
|
||||
if (shape.size() > 0) {
|
||||
if (shape[0] != 1) {
|
||||
to_collapse.push_back(0);
|
||||
@@ -43,7 +41,7 @@ collapse_contiguous_dims_impl(
|
||||
for (int i = 1; i < shape.size(); i++) {
|
||||
bool contiguous = true;
|
||||
size *= shape[i];
|
||||
for (const std::vector<StrideT>& st : strides) {
|
||||
for (const auto& st : strides) {
|
||||
if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
|
||||
contiguous = false;
|
||||
size = shape[i];
|
||||
@@ -60,8 +58,8 @@ collapse_contiguous_dims_impl(
|
||||
to_collapse.push_back(-1);
|
||||
}
|
||||
|
||||
std::vector<int> out_shape;
|
||||
std::vector<std::vector<StrideT>> out_strides(strides.size());
|
||||
Shape out_shape;
|
||||
std::vector<Strides> out_strides(strides.size());
|
||||
for (int i = 0;;) {
|
||||
while (i < to_collapse.size() && to_collapse[i] == -1) {
|
||||
++i;
|
||||
@@ -76,7 +74,7 @@ collapse_contiguous_dims_impl(
|
||||
}
|
||||
out_shape.push_back(current_shape);
|
||||
for (int j = 0; j < strides.size(); j++) {
|
||||
const std::vector<StrideT>& st = strides[j];
|
||||
const auto& st = strides[j];
|
||||
out_strides[j].push_back(st[to_collapse[k - 1]]);
|
||||
}
|
||||
i = k + 1;
|
||||
@@ -91,29 +89,12 @@ collapse_contiguous_dims_impl(
|
||||
return std::make_tuple(out_shape, out_strides);
|
||||
}
|
||||
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
|
||||
collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<std::vector<int64_t>>& strides,
|
||||
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
|
||||
return collapse_contiguous_dims_impl(shape, strides, size_cap);
|
||||
}
|
||||
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<std::vector<size_t>>& strides,
|
||||
size_t size_cap /* = std::numeric_limits<int32>::max() */) {
|
||||
return collapse_contiguous_dims_impl(shape, strides, size_cap);
|
||||
}
|
||||
|
||||
template <typename StrideT>
|
||||
std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<StrideT>& strides,
|
||||
StrideT size_cap) {
|
||||
std::vector<int> collapsed_shape;
|
||||
std::vector<StrideT> collapsed_strides;
|
||||
std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
int64_t size_cap) {
|
||||
Shape collapsed_shape;
|
||||
Strides collapsed_strides;
|
||||
|
||||
if (shape.size() > 0) {
|
||||
collapsed_shape.push_back(shape[0]);
|
||||
@@ -123,7 +104,7 @@ std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
|
||||
continue;
|
||||
} else if (
|
||||
strides[i] * shape[i] != collapsed_strides.back() ||
|
||||
collapsed_shape.back() * static_cast<StrideT>(shape[i]) > size_cap) {
|
||||
collapsed_shape.back() * static_cast<int64_t>(shape[i]) > size_cap) {
|
||||
collapsed_shape.push_back(shape[i]);
|
||||
collapsed_strides.push_back(strides[i]);
|
||||
} else {
|
||||
@@ -136,25 +117,10 @@ std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
|
||||
return std::make_pair(collapsed_shape, collapsed_strides);
|
||||
}
|
||||
|
||||
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<int64_t>& strides,
|
||||
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
|
||||
return collapse_contiguous_dims_impl<int64_t>(shape, strides, size_cap);
|
||||
}
|
||||
|
||||
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides,
|
||||
size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
|
||||
return collapse_contiguous_dims_impl<size_t>(shape, strides, size_cap);
|
||||
}
|
||||
|
||||
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
|
||||
std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
const array& a,
|
||||
size_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
|
||||
return collapse_contiguous_dims_impl<size_t>(
|
||||
a.shape(), a.strides(), size_cap);
|
||||
int64_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
|
||||
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -8,12 +8,9 @@

namespace mlx::core {

template <typename StrideT>
inline StrideT elem_to_loc(
    int elem,
    const std::vector<int>& shape,
    const std::vector<StrideT>& strides) {
  StrideT loc = 0;
inline int64_t
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
  int64_t loc = 0;
  for (int i = shape.size() - 1; i >= 0; --i) {
    auto q_and_r = ldiv(elem, shape[i]);
    loc += q_and_r.rem * strides[i];
@@ -22,16 +19,15 @@ inline StrideT elem_to_loc(
  return loc;
}

inline size_t elem_to_loc(int elem, const array& a) {
inline int64_t elem_to_loc(int elem, const array& a) {
  if (a.flags().row_contiguous) {
    return elem;
  }
  return elem_to_loc(elem, a.shape(), a.strides());
}

template <typename StrideT>
std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) {
  std::vector<StrideT> strides(shape.size(), 1);
inline Strides make_contiguous_strides(const Shape& shape) {
  Strides strides(shape.size(), 1);
  for (int i = shape.size() - 1; i > 0; i--) {
    strides[i - 1] = strides[i] * shape[i];
  }
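A worked example of the mapping elem_to_loc performs, written as a standalone sketch that mirrors the loop above; the helper name and the example shape/strides are illustrative, not library code:

// elem_to_loc maps a logical (row-major) element index to a storage offset.
// For the transpose of a row-major 2x3 array, shape = {3, 2}, strides = {1, 3}:
// element 4 has logical coordinates (2, 0), so its storage offset is 2*1 + 0*3 = 2.
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <vector>

int64_t elem_to_loc_sketch(
    int elem,
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    auto q_and_r = std::ldiv(elem, shape[i]); // peel off the coordinate on axis i
    loc += q_and_r.rem * strides[i];
    elem = q_and_r.quot;
  }
  return loc;
}

int main() {
  std::cout << elem_to_loc_sketch(4, {3, 2}, {1, 3}) << '\n'; // prints 2
}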
@@ -44,22 +40,15 @@ std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) {
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
collapse_contiguous_dims(
    const std::vector<int>& shape,
    const std::vector<std::vector<int64_t>>& strides,
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
    const Shape& shape,
    const std::vector<Strides>& strides,
    int64_t size_cap = std::numeric_limits<int32_t>::max());
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
    const std::vector<int>& shape,
    const std::vector<std::vector<size_t>>& strides,
    size_t size_cap = std::numeric_limits<int32_t>::max());

inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
inline std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
    const std::vector<array>& xs,
    size_t size_cap = std::numeric_limits<int32_t>::max()) {
  std::vector<std::vector<size_t>> strides;
  std::vector<Strides> strides;
  for (auto& x : xs) {
    strides.emplace_back(x.strides());
  }
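A standalone sketch of what the single-array collapse does, following the logic visible in the utils.cpp hunk earlier in this diff: adjacent dims whose strides line up (stride[i] * shape[i] == stride[i-1]) are fused into one. The aliases, the helper name collapse_sketch, and the size_cap omission are assumptions made to keep the example short:

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using Shape = std::vector<int32_t>;   // illustrative aliases
using Strides = std::vector<int64_t>;

std::pair<Shape, Strides> collapse_sketch(const Shape& shape, const Strides& strides) {
  Shape cs;
  Strides cst;
  if (!shape.empty()) {
    cs.push_back(shape[0]);
    cst.push_back(strides[0]);
  }
  for (size_t i = 1; i < shape.size(); ++i) {
    if (shape[i] == 1) {
      continue; // size-1 dims never break contiguity
    } else if (strides[i] * shape[i] != cst.back()) {
      cs.push_back(shape[i]); // discontinuity: start a new collapsed dim
      cst.push_back(strides[i]);
    } else {
      cs.back() *= shape[i]; // contiguous with the previous dim: fuse
      cst.back() = strides[i];
    }
  }
  return {cs, cst};
}

int main() {
  // Row-major 2x3x4 collapses to a single dim of 24 elements with stride 1.
  auto [s1, t1] = collapse_sketch({2, 3, 4}, {12, 4, 1});
  std::cout << s1[0] << " " << t1[0] << '\n'; // 24 1
  // A padded outer stride keeps the first axis separate: shapes {2, 12}.
  auto [s2, t2] = collapse_sketch({2, 3, 4}, {24, 4, 1});
  std::cout << s2[0] << "x" << s2[1] << '\n'; // 2x12
}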
@@ -73,19 +62,14 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
|
||||
}
|
||||
|
||||
// The single array version of the above.
|
||||
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<int64_t>& strides,
|
||||
std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
||||
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides,
|
||||
size_t size_cap = std::numeric_limits<int32_t>::max());
|
||||
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
|
||||
std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
const array& a,
|
||||
size_t size_cap = std::numeric_limits<int32_t>::max());
|
||||
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
||||
|
||||
template <typename StrideT>
|
||||
struct ContiguousIterator {
|
||||
inline void step() {
|
||||
int dims = shape_.size();
|
||||
@@ -102,7 +86,7 @@ struct ContiguousIterator {
|
||||
loc += strides_[i];
|
||||
}
|
||||
|
||||
void seek(StrideT n) {
|
||||
void seek(int64_t n) {
|
||||
loc = 0;
|
||||
for (int i = shape_.size() - 1; i >= 0; --i) {
|
||||
auto q_and_r = ldiv(n, shape_[i]);
|
||||
@@ -128,32 +112,29 @@ struct ContiguousIterator {
|
||||
}
|
||||
|
||||
explicit ContiguousIterator(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<StrideT>& strides,
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
int dims)
|
||||
: shape_(shape.begin(), shape.begin() + dims),
|
||||
strides_(strides.begin(), strides.begin() + dims) {
|
||||
if (!shape_.empty()) {
|
||||
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
|
||||
pos_ = std::vector<int>(shape_.size(), 0);
|
||||
pos_ = Shape(shape_.size(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
StrideT loc{0};
|
||||
int64_t loc{0};
|
||||
|
||||
private:
|
||||
std::vector<int> shape_;
|
||||
std::vector<StrideT> strides_;
|
||||
std::vector<int> pos_;
|
||||
Shape shape_;
|
||||
Strides strides_;
|
||||
Shape pos_;
|
||||
};
|
||||
|
||||
template <typename StrideT>
|
||||
inline auto check_contiguity(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<StrideT>& strides) {
|
||||
inline auto check_contiguity(const Shape& shape, const Strides& strides) {
|
||||
size_t no_broadcast_data_size = 1;
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
int64_t f_stride = 1;
|
||||
int64_t b_stride = 1;
|
||||
bool is_row_contiguous = true;
|
||||
bool is_col_contiguous = true;
|
||||
|
||||
@@ -182,9 +163,15 @@ void move_or_copy(const array& in, array& out);
|
||||
void move_or_copy(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<size_t>& strides,
|
||||
const Strides& strides,
|
||||
array::Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
|
||||
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
|
||||
|
||||
void shared_buffer_reshape(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
array& out);
|
||||
} // namespace mlx::core
|
||||
|
@@ -75,19 +75,21 @@ void binary_op_gpu_inplace(
      auto [shape, strides] = collapse_contiguous_dims(a, b, out);
      return std::make_tuple(shape, strides[0], strides[1], strides[2]);
    } else {
      std::vector<size_t> e;
      return std::make_tuple(std::vector<int>{}, e, e, e);
      decltype(a.strides()) e{};
      return std::make_tuple(decltype(a.shape()){}, e, e, e);
    }
  };
  auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();

  bool large = out.data_size() > UINT32_MAX;
  bool large;
  auto ndim = shape.size();
  int work_per_thread;
  if (bopt == BinaryOpType::General) {
    large |= (a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX);
    large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
        out.size() > INT32_MAX;
    work_per_thread = large ? 4 : 2;
  } else {
    large = out.data_size() > UINT32_MAX;
    work_per_thread = 1;
  }
  std::string kernel_name =
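The pattern in this hunk, choosing 32-bit indexing for the common case and falling back to 64-bit only when an operand crosses the 32-bit limit, recurs across the Metal dispatch code below. A standalone sketch of that decision; the function name and thresholds here are illustrative, not the library's API:

// Choose the narrowest safe index width for a kernel launch.
#include <cstdint>
#include <iostream>
#include <limits>
#include <string>

std::string index_type_for(int64_t a_size, int64_t b_size, int64_t out_size) {
  constexpr int64_t cap = std::numeric_limits<int32_t>::max();
  bool large = a_size > cap || b_size > cap || out_size > cap;
  // Small launches index with 32-bit ints (cheaper on GPU); large ones need int64_t.
  return large ? "int64_t" : "int";
}

int main() {
  std::cout << index_type_for(1 << 20, 1 << 20, 1 << 20) << '\n';               // int
  std::cout << index_type_for(3'000'000'000LL, 1, 3'000'000'000LL) << '\n';     // int64_t
}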
@@ -67,7 +67,7 @@ inline void build_kernel(
|
||||
|
||||
if (add_indices) {
|
||||
os += fmt::format(
|
||||
" constant const size_t* in_strides [[buffer({0})]],\n", cnt++);
|
||||
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
|
||||
}
|
||||
|
||||
// Add the output arguments
|
||||
@@ -81,7 +81,7 @@ inline void build_kernel(
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os += fmt::format(
|
||||
" constant const size_t* output_strides [[buffer({0})]],\n", cnt++);
|
||||
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
|
||||
os += fmt::format(
|
||||
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
|
||||
}
|
||||
@@ -93,11 +93,11 @@ inline void build_kernel(
|
||||
os += " uint3 pos [[thread_position_in_grid]],\n";
|
||||
os += " uint3 grid [[threads_per_grid]]) {\n";
|
||||
|
||||
std::string idx_type = use_big_index ? "size_t" : "uint";
|
||||
std::string idx_type = use_big_index ? "int64_t" : "uint";
|
||||
if (contiguous && use_big_index) {
|
||||
// This is only used for contiguous kernels which don't have
|
||||
// a third grid dimension
|
||||
os += " size_t index = pos.x + grid.x * size_t(pos.y);\n";
|
||||
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
|
||||
} else if (work_per_thread > 1) {
|
||||
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
|
||||
os += fmt::format(
|
||||
@@ -144,20 +144,18 @@ inline void build_kernel(
|
||||
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
|
||||
if (ndim == 1) {
|
||||
int offset = i * ndim;
|
||||
os += fmt::format(
|
||||
"elem_to_loc_1<size_t, uint>(pos.x, in_strides[{0}]);\n", offset);
|
||||
os +=
|
||||
fmt::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
|
||||
} else if (ndim == 2) {
|
||||
int offset = i * ndim;
|
||||
os += fmt::format(
|
||||
"elem_to_loc_2<size_t, {0}>({{pos.x, pos.y}}, in_strides + {1});\n",
|
||||
"elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n",
|
||||
idx_type,
|
||||
offset);
|
||||
} else if (ndim == 3) {
|
||||
int offset = i * ndim;
|
||||
os += fmt::format(
|
||||
"elem_to_loc_3<size_t, {0}>(pos, in_strides + {1});\n",
|
||||
idx_type,
|
||||
offset);
|
||||
"elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset);
|
||||
} else if (!dynamic_dims) {
|
||||
int offset = (i + 1) * ndim;
|
||||
os += fmt::format(
|
||||
@@ -360,10 +358,10 @@ void Compiled::eval_gpu(
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
std::vector<std::vector<size_t>> initial_strides;
|
||||
std::vector<Strides> initial_strides;
|
||||
initial_strides.push_back(outputs[0].strides());
|
||||
std::vector<int> shape;
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
Shape shape;
|
||||
std::vector<Strides> strides;
|
||||
if (!contiguous) {
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
// Skip constants.
|
||||
@@ -378,7 +376,7 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
|
||||
// Broadcast the inputs to the output shape.
|
||||
std::vector<size_t> xstrides;
|
||||
Strides xstrides;
|
||||
int j = 0;
|
||||
for (; j < output_shape.size() - x.ndim(); j++) {
|
||||
if (output_shape[j] == 1) {
|
||||
@@ -440,7 +438,7 @@ void Compiled::eval_gpu(
|
||||
// Put the inputs in
|
||||
int cnt = 0;
|
||||
int stride_idx = 1; // idx 0 is the output strides
|
||||
std::vector<size_t> in_strides;
|
||||
Strides in_strides;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
continue;
|
||||
|
@@ -64,8 +64,8 @@ void explicit_gemm_conv_ND_gpu(
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
|
||||
// Reshape weight
|
||||
std::vector<int> wt_reshape{implicit_K, implicit_N};
|
||||
std::vector<size_t> wt_restride{1, static_cast<size_t>(implicit_K)};
|
||||
Shape wt_reshape{implicit_K, implicit_N};
|
||||
Strides wt_restride{1, implicit_K};
|
||||
array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
|
||||
auto wt_flags = wt.flags();
|
||||
wt_flags.row_contiguous = false;
|
||||
@@ -147,10 +147,7 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
array wt_view(
|
||||
{wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
|
||||
wt_view.copy_shared_buffer(
|
||||
wt,
|
||||
{wt.strides(0), 1, static_cast<size_t>(C_per_group)},
|
||||
wt.flags(),
|
||||
wt.size());
|
||||
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
|
||||
|
||||
// Materialize
|
||||
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
|
||||
|
@@ -43,13 +43,12 @@ void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& strides_in_pre,
|
||||
const std::vector<stride_t>& strides_out_pre,
|
||||
const Shape& data_shape,
|
||||
const Strides& strides_in_pre,
|
||||
const Strides& strides_out_pre,
|
||||
int64_t inp_offset,
|
||||
int64_t out_offset,
|
||||
CopyType ctype,
|
||||
@@ -68,8 +67,8 @@ void copy_gpu_inplace(
|
||||
/* size_cap = */ INT32_MAX);
|
||||
return std::make_tuple(shape, strides[0], strides[1]);
|
||||
} else {
|
||||
std::vector<stride_t> e;
|
||||
return std::make_tuple(std::vector<int>{}, e, e);
|
||||
Strides e{};
|
||||
return std::make_tuple(Shape{}, e, e);
|
||||
}
|
||||
};
|
||||
auto [shape, strides_in_, strides_out_] = maybe_collapse();
|
||||
@@ -124,8 +123,8 @@ void copy_gpu_inplace(
|
||||
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
|
||||
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
|
||||
Strides strides_in{strides_in_.begin(), strides_in_.end()};
|
||||
Strides strides_out{strides_out_.begin(), strides_out_.end()};
|
||||
if (ndim > 3) {
|
||||
compute_encoder.set_vector_bytes(shape, ndim, 2);
|
||||
}
|
||||
@@ -180,14 +179,13 @@ void copy_gpu_inplace(
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<int64_t>& istride,
|
||||
const Strides& istride,
|
||||
int64_t ioffset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
|
||||
in, out, in.shape(), istride, out.strides(), ioffset, 0, ctype, s);
|
||||
}
|
||||
|
||||
void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
|
@@ -8,13 +8,12 @@
|
||||
namespace mlx::core {
|
||||
|
||||
// Generic copy inplace
|
||||
template <typename stride_t>
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype,
|
||||
@@ -32,7 +31,7 @@ void copy_gpu_inplace(
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<int64_t>& istride,
|
||||
const Strides& istride,
|
||||
int64_t ioffset,
|
||||
CopyType ctype,
|
||||
const Stream& s);
|
||||
|
@@ -363,7 +363,7 @@ void multi_upload_bluestein_fft(
|
||||
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
|
||||
|
||||
// Broadcast w_q and w_k to the batch size
|
||||
std::vector<size_t> b_strides(in.ndim(), 0);
|
||||
Strides b_strides(in.ndim(), 0);
|
||||
b_strides[axis] = 1;
|
||||
array w_k_broadcast({}, complex64, nullptr, {});
|
||||
array w_q_broadcast({}, complex64, nullptr, {});
|
||||
@@ -386,8 +386,8 @@ void multi_upload_bluestein_fft(
|
||||
copies.push_back(slice_temp);
|
||||
copies.push_back(conj_temp);
|
||||
|
||||
std::vector<int> rstarts(in.ndim(), 0);
|
||||
std::vector<int> rstrides(in.ndim(), 1);
|
||||
Shape rstarts(in.ndim(), 0);
|
||||
Shape rstrides(in.ndim(), 1);
|
||||
rstarts[axis] = in.shape(axis) - back_offset;
|
||||
rstrides[axis] = -1;
|
||||
unary_op_gpu({in}, conj_temp, "Conjugate", s);
|
||||
@@ -431,19 +431,19 @@ void multi_upload_bluestein_fft(
|
||||
s);
|
||||
|
||||
int offset = plan.bluestein_n - (2 * n - 1);
|
||||
std::vector<int> starts(in.ndim(), 0);
|
||||
std::vector<int> strides(in.ndim(), 1);
|
||||
Shape starts(in.ndim(), 0);
|
||||
Shape strides(in.ndim(), 1);
|
||||
starts[axis] = plan.bluestein_n - offset - n;
|
||||
slice_gpu(pad_temp1, temp, starts, strides, s);
|
||||
|
||||
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
|
||||
|
||||
if (real && !inverse) {
|
||||
std::vector<int> rstarts(in.ndim(), 0);
|
||||
std::vector<int> rstrides(in.ndim(), 1);
|
||||
Shape rstarts(in.ndim(), 0);
|
||||
Shape rstrides(in.ndim(), 1);
|
||||
slice_gpu(temp1, out, rstarts, strides, s);
|
||||
} else if (real && inverse) {
|
||||
std::vector<size_t> b_strides(in.ndim(), 0);
|
||||
Strides b_strides(in.ndim(), 0);
|
||||
auto inv_n = array({1.0f / n}, {1}, float32);
|
||||
array temp_float(out.shape(), out.dtype(), nullptr, {});
|
||||
copies.push_back(temp_float);
|
||||
@@ -531,8 +531,8 @@ void fft_op(
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
std::vector<size_t> strides;
|
||||
size_t cur_stride = x.shape(axis);
|
||||
Strides strides;
|
||||
int64_t cur_stride = x.shape(axis);
|
||||
for (int a = 0; a < x.ndim(); a++) {
|
||||
if (a == axis) {
|
||||
strides.push_back(1);
|
||||
@@ -777,7 +777,7 @@ void nd_fft_op(
|
||||
// Mirror np.fft.(i)rfftn and perform a real transform
|
||||
// only on the final axis.
|
||||
bool step_real = (real && index == axes.size() - 1);
|
||||
int step_shape = inverse ? out.shape(axis) : in.shape(axis);
|
||||
auto step_shape = inverse ? out.shape(axis) : in.shape(axis);
|
||||
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
|
||||
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
|
||||
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
|
||||
|
@@ -53,9 +53,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
size_t ndim = src.ndim();
|
||||
|
||||
bool large_index = nidx && inputs[1].size() > UINT32_MAX;
|
||||
bool large_src = src.size() > UINT32_MAX;
|
||||
bool large_out = out.size() > UINT32_MAX;
|
||||
bool large_index = nidx && inputs[1].size() > INT32_MAX;
|
||||
bool large_src = src.size() > INT32_MAX;
|
||||
bool large_out = out.size() > INT32_MAX;
|
||||
bool large = large_index || large_src || large_out;
|
||||
|
||||
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||
@@ -65,7 +65,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
idx_type_name,
|
||||
nidx,
|
||||
idx_ndim,
|
||||
large ? "size_t" : "uint");
|
||||
large ? "int64_t" : "int");
|
||||
std::string lib_name = kernel_name;
|
||||
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
@@ -86,7 +86,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
idx_args,
|
||||
idx_arr,
|
||||
idx_ndim,
|
||||
large ? "size_t" : "uint");
|
||||
large ? "int64_t" : "int");
|
||||
return kernel_source;
|
||||
});
|
||||
|
||||
@@ -234,9 +234,9 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
break;
|
||||
}
|
||||
auto upd_contig = upd.flags().row_contiguous;
|
||||
bool large_out = out.size() > UINT32_MAX;
|
||||
bool large_idx = nidx && (inputs[1].size() > UINT32_MAX);
|
||||
bool large_upd = upd.size() > UINT32_MAX;
|
||||
bool large_out = out.size() > INT32_MAX;
|
||||
bool large_idx = nidx && (inputs[1].size() > INT32_MAX);
|
||||
bool large_upd = upd.size() > INT32_MAX;
|
||||
bool large = large_out || large_idx || large_upd;
|
||||
std::string kernel_name = fmt::format(
|
||||
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
|
||||
@@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nidx,
|
||||
upd_contig ? "updc_true" : "updc_false",
|
||||
nwork,
|
||||
large ? "size_t" : "uint");
|
||||
large ? "int64_t" : "int");
|
||||
std::string lib_name = kernel_name;
|
||||
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
@@ -290,7 +290,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
idx_arr,
|
||||
upd_contig,
|
||||
nwork,
|
||||
large ? "size_t" : "uint");
|
||||
large ? "int64_t" : "int");
|
||||
return kernel_source;
|
||||
});
|
||||
|
||||
@@ -312,8 +312,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
upd_size *= upd.shape(i);
|
||||
}
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
Shape idx_shapes;
|
||||
Strides idx_strides;
|
||||
// To access .data() use char instead of bool
|
||||
// bool is 1 byte in Metal so this is safe
|
||||
std::vector<char> idx_contigs;
|
||||
@@ -332,7 +332,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (upd_ndim == 0) {
|
||||
// Need placeholders so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
size_t stride_ = 0;
|
||||
int64_t stride_ = 0;
|
||||
compute_encoder.set_bytes(shape_, 3);
|
||||
compute_encoder.set_bytes(stride_, 4);
|
||||
} else {
|
||||
@@ -347,7 +347,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out_ndim == 0) {
|
||||
// Need placeholders so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
size_t stride_ = 0;
|
||||
int64_t stride_ = 0;
|
||||
compute_encoder.set_bytes(shape_, 7);
|
||||
compute_encoder.set_bytes(stride_, 8);
|
||||
} else {
|
||||
|
@@ -11,13 +11,13 @@ gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
||||
const device {outm_t}* out_mask [[buffer(20)]],
|
||||
const device {opm_t}* mat_mask [[buffer(21)]],
|
||||
const device {opm_t}* vec_mask [[buffer(22)]],
|
||||
const constant int* mask_strides [[buffer(23)]],
|
||||
const constant size_t* mask_batch_strides [[buffer(24)]],
|
||||
const constant int64_t* mask_batch_strides [[buffer(24)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
|
@@ -5,12 +5,12 @@ constexpr std::string_view gather_kernels = R"(
|
||||
const device {1}* src [[buffer(0)]],
|
||||
device {1}* out [[buffer(1)]],
|
||||
const constant int* src_shape [[buffer(2)]],
|
||||
const constant size_t* src_strides [[buffer(3)]],
|
||||
const constant int64_t* src_strides [[buffer(3)]],
|
||||
const constant size_t& src_ndim [[buffer(4)]],
|
||||
const constant int* slice_sizes [[buffer(5)]],
|
||||
const constant int* axes [[buffer(6)]],
|
||||
const constant int* idx_shapes [[buffer(7)]],
|
||||
const constant size_t* idx_strides [[buffer(8)]],
|
||||
const constant int64_t* idx_strides [[buffer(8)]],
|
||||
const constant bool* idx_contigs [[buffer(9)]],
|
||||
const constant int& idx_ndim [[buffer(10)]],
|
||||
{4}
|
||||
@@ -38,15 +38,15 @@ constexpr std::string_view scatter_kernels = R"(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* upd_shape [[buffer(3)]],
|
||||
const constant size_t* upd_strides [[buffer(4)]],
|
||||
const constant int64_t* upd_strides [[buffer(4)]],
|
||||
const constant size_t& upd_ndim [[buffer(5)]],
|
||||
const constant size_t& upd_size [[buffer(6)]],
|
||||
const constant int* out_shape [[buffer(7)]],
|
||||
const constant size_t* out_strides [[buffer(8)]],
|
||||
const constant int64_t* out_strides [[buffer(8)]],
|
||||
const constant size_t& out_ndim [[buffer(9)]],
|
||||
const constant int* axes [[buffer(10)]],
|
||||
const constant int* idx_shapes [[buffer(11)]],
|
||||
const constant size_t* idx_strides [[buffer(12)]],
|
||||
const constant int64_t* idx_strides [[buffer(12)]],
|
||||
const constant bool* idx_contigs [[buffer(13)]],
|
||||
const constant int& idx_ndim [[buffer(14)]],
|
||||
const constant size_t& idx_size [[buffer(15)]],
|
||||
|
@@ -10,12 +10,12 @@ template [[host_name("{name}")]]
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
const constant int64_t* batch_strides [[buffer(7)]],
|
||||
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
|
||||
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
|
||||
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
|
||||
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
|
||||
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
||||
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
||||
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
@@ -43,7 +43,7 @@ block_masked_gemm<
|
||||
device {itype}* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
const constant int64_t* batch_strides [[buffer(7)]],
|
||||
const device {outmasktype}* out_mask [[buffer(10)]],
|
||||
const device {opmasktype}* lhs_mask [[buffer(11)]],
|
||||
const device {opmasktype}* rhs_mask [[buffer(12)]],
|
||||
|
@@ -52,7 +52,7 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
kernel_source +=
|
||||
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
|
||||
kernel_source += get_template_definition(
|
||||
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint");
|
||||
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
|
||||
return kernel_source;
|
||||
@@ -74,7 +74,7 @@ void append_binary_kernels(
|
||||
{"vs2", "binary_vs2"},
|
||||
{"sv2", "binary_sv2"},
|
||||
{"vv2", "binary_vv2"},
|
||||
{"g1", "binary_g_nd1"},
|
||||
{"g1large", "binary_g_nd1"},
|
||||
{"g2large", "binary_g_nd2"},
|
||||
{"g3large", "binary_g_nd3"},
|
||||
}};
|
||||
@@ -86,11 +86,13 @@ void append_binary_kernels(
|
||||
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
|
||||
}
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint");
|
||||
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint");
|
||||
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint");
|
||||
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
|
||||
}
|
||||
@@ -141,7 +143,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
|
||||
{"v", "ternary_v"},
|
||||
{"v2", "ternary_v2"},
|
||||
{"g1", "ternary_g_nd1"},
|
||||
{"g1large", "ternary_g_nd1"},
|
||||
{"g2large", "ternary_g_nd2"},
|
||||
{"g3large", "ternary_g_nd3"},
|
||||
}};
|
||||
@@ -150,11 +152,13 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
get_template_definition(name + "_" + lib_name, func, t_str, op);
|
||||
}
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint");
|
||||
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint");
|
||||
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint");
|
||||
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "ternary_g", t_str, op, 4);
|
||||
return kernel_source;
|
||||
@@ -178,7 +182,7 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
kernel_source +=
|
||||
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type);
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
@@ -186,19 +190,23 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type);
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g1large_" + lib_name, "copy_g_nd1", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g2large_" + lib_name, "copy_g_nd2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
|
||||
kernel_source += get_template_definition(
|
||||
"gg1large_" + lib_name, "copy_gg_nd1", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
|
@@ -75,10 +75,10 @@ template <typename T, typename Op, int N_READS = 4>
|
||||
const device T* in [[buffer(0)]],
|
||||
device uint32_t* out [[buffer(1)]],
|
||||
const constant int* shape [[buffer(2)]],
|
||||
const constant size_t* in_strides [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant int64_t* in_strides [[buffer(3)]],
|
||||
const constant int64_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& ndim [[buffer(5)]],
|
||||
const constant size_t& axis_stride [[buffer(6)]],
|
||||
const constant int64_t& axis_stride [[buffer(6)]],
|
||||
const constant size_t& axis_size [[buffer(7)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
|
@@ -43,7 +43,7 @@ template <typename T, typename U, typename Op>
|
||||
device U* c,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
c[offset] = Op()(a[0], b[offset]);
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ template <typename T, typename U, typename Op>
|
||||
device U* c,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
c[offset] = Op()(a[offset], b[0]);
|
||||
}
|
||||
|
||||
@@ -65,49 +65,49 @@ template <typename T, typename U, typename Op>
|
||||
device U* c,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
c[offset] = Op()(a[offset], b[offset]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
[[kernel]] void binary_g_nd1(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
constant const int64_t& a_stride,
|
||||
constant const int64_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_stride);
|
||||
auto a_idx = elem_to_loc_1<IdxT>(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1<IdxT>(index, b_stride);
|
||||
c[index] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
[[kernel]] void binary_g_nd2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
constant const int64_t a_strides[2],
|
||||
constant const int64_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
|
||||
auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
[[kernel]] void binary_g_nd3(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
constant const int64_t a_strides[3],
|
||||
constant const int64_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
|
||||
auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
@@ -117,18 +117,18 @@ template <
|
||||
typename U,
|
||||
typename Op,
|
||||
int N = 1,
|
||||
typename IdxT = size_t>
|
||||
typename IdxT = int64_t>
|
||||
[[kernel]] void binary_g(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int64_t* a_strides,
|
||||
constant const int64_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
|
||||
auto idx = elem_to_loc_2_nd<IdxT>(
|
||||
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
|
||||
auto xshape = shape[ndim - 1];
|
||||
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
@@ -9,21 +9,21 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"

#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)

#define instantiate_binary_integer(op) \
@@ -56,7 +56,7 @@ template <typename T, typename U, typename Op>
|
||||
device U* d,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
auto out = Op()(a[0], b[offset]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
@@ -70,7 +70,7 @@ template <typename T, typename U, typename Op>
|
||||
device U* d,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
auto out = Op()(a[offset], b[0]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
@@ -84,58 +84,58 @@ template <typename T, typename U, typename Op>
|
||||
device U* d,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
auto out = Op()(a[offset], b[offset]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
[[kernel]] void binary_g_nd1(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
constant const int64_t& a_stride,
|
||||
constant const int64_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_stride);
|
||||
auto a_idx = elem_to_loc_1<IdxT>(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1<IdxT>(index, b_stride);
|
||||
auto out = Op()(a[a_idx], b[b_idx]);
|
||||
c[index] = out[0];
|
||||
d[index] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
[[kernel]] void binary_g_nd2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
constant const int64_t a_strides[2],
|
||||
constant const int64_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
|
||||
auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
|
||||
auto out = Op()(a[a_idx], b[b_idx]);
|
||||
c[out_idx] = out[0];
|
||||
d[out_idx] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
[[kernel]] void binary_g_nd3(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
constant const int64_t a_strides[3],
|
||||
constant const int64_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
|
||||
auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);
|
||||
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
auto out = Op()(a[a_idx], b[b_idx]);
|
||||
c[out_idx] = out[0];
|
||||
@@ -147,19 +147,19 @@ template <
|
||||
typename U,
|
||||
typename Op,
|
||||
int N = 1,
|
||||
typename IdxT = size_t>
|
||||
typename IdxT = int64_t>
|
||||
[[kernel]] void binary_g(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int64_t* a_strides,
|
||||
constant const int64_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
|
||||
auto idx = elem_to_loc_2_nd<IdxT>(
|
||||
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
|
||||
auto xshape = shape[ndim - 1];
|
||||
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
|
@@ -7,21 +7,21 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary_two.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
||||
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
|
||||
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_float(op) \
|
||||
|
@@ -22,7 +22,7 @@ template <typename T, typename U>
|
||||
device U* dst [[buffer(1)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
dst[offset] = static_cast<U>(src[0]);
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ template <typename T, typename U>
|
||||
device U* dst [[buffer(1)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
dst[offset] = static_cast<U>(src[offset]);
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ template <typename T, typename U, typename IdxT = int64_t>
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
|
||||
auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
|
||||
dst[index] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ template <typename T, typename U, typename IdxT = int64_t>
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
|
||||
auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
|
||||
IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
@@ -65,7 +65,7 @@ template <typename T, typename U, typename IdxT = int64_t>
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
|
||||
auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
|
||||
IdxT dst_idx =
|
||||
index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
@@ -80,7 +80,7 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc<int64_t, IdxT>(
|
||||
auto src_idx = elem_to_loc<IdxT>(
|
||||
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
|
||||
if (N == 1) {
|
||||
IdxT dst_idx =
|
||||
@@ -104,8 +104,8 @@ template <typename T, typename U, typename IdxT = int64_t>
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
constant const int64_t& dst_stride [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1<int64_t, IdxT>(index, dst_stride);
|
||||
auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
@@ -116,8 +116,8 @@ template <typename T, typename U, typename IdxT = int64_t>
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint2 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_2<int64_t, IdxT>(index, dst_strides);
|
||||
auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
@@ -128,8 +128,8 @@ template <typename T, typename U, typename IdxT = int64_t>
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_3<int64_t, IdxT>(index, dst_strides);
|
||||
auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
@@ -142,7 +142,7 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd<int64_t, IdxT>(
|
||||
auto idx = elem_to_loc_2_nd<IdxT>(
|
||||
{N * index.x, index.y, index.z},
|
||||
src_shape,
|
||||
src_strides,
|
||||
|
@@ -9,7 +9,7 @@ METAL_FUNC void gather_impl(
const device T* src [[buffer(0)]],
device T* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant int64_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
@@ -27,7 +27,7 @@ METAL_FUNC void gather_impl(
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
idx_loc += indices.row_contiguous[i]
? index.y
: elem_to_loc<size_t, LocT>(
: elem_to_loc<LocT>(
index.y,
&indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1],
@@ -39,7 +39,7 @@ METAL_FUNC void gather_impl(
}

auto src_offset =
elem_to_loc<size_t, LocT>(index.z, slice_sizes, src_strides, src_ndim);
elem_to_loc<LocT>(index.z, slice_sizes, src_strides, src_ndim);

LocT out_idx = index.z;
if (IDX_NDIM == 1) {
@@ -436,9 +436,9 @@ template <
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]],
|
||||
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant int64_t* bias_batch_stride [[buffer(13)]],
|
||||
const constant int& bias_stride [[buffer(14)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
@@ -486,31 +486,21 @@ template <
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
#define instantiate_gemv_helper( \
|
||||
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
|
||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
|
||||
"_tm" #tm "_tn" #tn "_nc" #nc \
|
||||
"_axpby" #axpby)]] [[kernel]] void \
|
||||
gemv<itype, bm, bn, sm, sn, tm, tn, nc, axpby>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]], \
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]], \
|
||||
const constant int& bias_stride [[buffer(14)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
#define instantiate_gemv_helper( \
|
||||
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
|
||||
instantiate_kernel( \
|
||||
"gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
|
||||
"_tn" #tn "_nc" #nc "_axpby" #axpby, \
|
||||
gemv, \
|
||||
itype, \
|
||||
bm, \
|
||||
bn, \
|
||||
sm, \
|
||||
sn, \
|
||||
tm, \
|
||||
tn, \
|
||||
nc, \
|
||||
axpby)
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
|
||||
@@ -549,13 +539,13 @@ template <
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant size_t* index_batch_strides [[buffer(11)]],
|
||||
const constant int64_t* index_batch_strides [[buffer(11)]],
|
||||
const constant int& vector_batch_ndim [[buffer(12)]],
|
||||
const constant int* vector_batch_shape [[buffer(13)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(14)]],
|
||||
const constant int64_t* vector_batch_stride [[buffer(14)]],
|
||||
const constant int& matrix_batch_ndim [[buffer(15)]],
|
||||
const constant int* matrix_batch_shape [[buffer(16)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(17)]],
|
||||
const constant int64_t* matrix_batch_stride [[buffer(17)]],
|
||||
const constant uint32_t* vec_indices [[buffer(18)]],
|
||||
const constant uint32_t* mat_indices [[buffer(19)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -571,8 +561,8 @@ template <
|
||||
|
||||
// Update batch offsets
|
||||
if (batch_ndim > 1) {
|
||||
const constant size_t* veci_bstrides = index_batch_strides;
|
||||
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;
|
||||
const constant auto* veci_bstrides = index_batch_strides;
|
||||
const constant auto* mati_bstrides = index_batch_strides + batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
|
||||
@@ -619,37 +609,14 @@ template <
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
|
||||
template [[host_name("gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
|
||||
"_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \
|
||||
gemv_gather<itype, bm, bn, sm, sn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* index_batch_strides [[buffer(11)]], \
|
||||
const constant int& vector_batch_ndim [[buffer(12)]], \
|
||||
const constant int* vector_batch_shape [[buffer(13)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(14)]], \
|
||||
const constant int& matrix_batch_ndim [[buffer(15)]], \
|
||||
const constant int* matrix_batch_shape [[buffer(16)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(17)]], \
|
||||
const constant uint32_t* vec_indices [[buffer(18)]], \
|
||||
const constant uint32_t* mat_indices [[buffer(19)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_bs_blocks(name, itype) \
|
||||
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
|
||||
instantiate_kernel( \
|
||||
"gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
|
||||
"_sn" #sn "_tm" #tm "_tn" #tn, \
|
||||
gemv_gather, itype, bm, bn, sm, sn, tm, tn)
|
||||
|
||||
#define instantiate_gemv_bs_blocks(name, itype) \
|
||||
instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \
|
||||
instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \
|
||||
instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on
|
||||
@@ -684,9 +651,9 @@ template <
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]],
|
||||
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant int64_t* bias_batch_stride [[buffer(13)]],
|
||||
const constant int& bias_stride [[buffer(14)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
@@ -734,33 +701,14 @@ template <
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
#define instantiate_gemv_t_helper( \
|
||||
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
|
||||
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
|
||||
"_tm" #tm "_tn" #tn "_nc" #nc \
|
||||
"_axpby" #axpby)]] [[kernel]] void \
|
||||
gemv_t<itype, bm, bn, sm, sn, tm, tn, nc, axpby>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]], \
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]], \
|
||||
const constant int& bias_stride [[buffer(14)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_t_helper( \
|
||||
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
|
||||
instantiate_kernel( \
|
||||
"gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
|
||||
"_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby, \
|
||||
gemv_t, itype, bm, bn, sm, sn, tm, tn, nc, axpby)
|
||||
|
||||
#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
|
||||
@@ -800,13 +748,13 @@ template <
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant size_t* index_batch_strides [[buffer(11)]],
|
||||
const constant int64_t* index_batch_strides [[buffer(11)]],
|
||||
const constant int& vector_batch_ndim [[buffer(12)]],
|
||||
const constant int* vector_batch_shape [[buffer(13)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(14)]],
|
||||
const constant int64_t* vector_batch_stride [[buffer(14)]],
|
||||
const constant int& matrix_batch_ndim [[buffer(15)]],
|
||||
const constant int* matrix_batch_shape [[buffer(16)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(17)]],
|
||||
const constant int64_t* matrix_batch_stride [[buffer(17)]],
|
||||
const constant uint32_t* vec_indices [[buffer(18)]],
|
||||
const constant uint32_t* mat_indices [[buffer(19)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -822,8 +770,8 @@ template <
|
||||
|
||||
// Update batch offsets
|
||||
if (batch_ndim > 1) {
|
||||
const constant size_t* veci_bstrides = index_batch_strides;
|
||||
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;
|
||||
const constant auto* veci_bstrides = index_batch_strides;
|
||||
const constant auto* mati_bstrides = index_batch_strides + batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
|
||||
@@ -870,36 +818,14 @@ template <
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
|
||||
template [[host_name("gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
|
||||
"_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \
|
||||
gemv_t_gather<itype, bm, bn, sm, sn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* index_batch_strides [[buffer(11)]], \
|
||||
const constant int& vector_batch_ndim [[buffer(12)]], \
|
||||
const constant int* vector_batch_shape [[buffer(13)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(14)]], \
|
||||
const constant int& matrix_batch_ndim [[buffer(15)]], \
|
||||
const constant int* matrix_batch_shape [[buffer(16)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(17)]], \
|
||||
const constant uint32_t* vec_indices [[buffer(18)]], \
|
||||
const constant uint32_t* mat_indices [[buffer(19)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_t_bs_helper( \
|
||||
nm, itype, bm, bn, sm, sn, tm, tn) \
|
||||
instantiate_kernel( \
|
||||
"gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
|
||||
"_sn" #sn "_tm" #tm "_tn" #tn, \
|
||||
gemv_t_gather, itype, bm, bn, sm, sn, tm, tn)
|
||||
|
||||
#define instantiate_gemv_t_bs_blocks(name, itype) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 1) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 4) \
|
||||
|
@@ -642,13 +642,13 @@ template <
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
||||
const device out_mask_t* out_mask [[buffer(20)]],
|
||||
const device op_mask_t* mat_mask [[buffer(21)]],
|
||||
const device op_mask_t* vec_mask [[buffer(22)]],
|
||||
const constant int* mask_strides [[buffer(23)]],
|
||||
const constant size_t* mask_batch_strides [[buffer(24)]],
|
||||
const constant int64_t* mask_batch_strides [[buffer(24)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
@@ -673,8 +673,8 @@ template <
|
||||
}
|
||||
|
||||
if (has_operand_mask) {
|
||||
const constant size_t* mask_strides_mat = mask_batch_strides;
|
||||
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
|
||||
const constant auto* mask_strides_mat = mask_batch_strides;
|
||||
const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
|
||||
@@ -742,13 +742,13 @@ template <
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
||||
const device out_mask_t* out_mask [[buffer(20)]],
|
||||
const device op_mask_t* mat_mask [[buffer(21)]],
|
||||
const device op_mask_t* vec_mask [[buffer(22)]],
|
||||
const constant int* mask_strides [[buffer(23)]],
|
||||
const constant size_t* mask_batch_strides [[buffer(24)]],
|
||||
const constant int64_t* mask_batch_strides [[buffer(24)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
@@ -773,8 +773,8 @@ template <
|
||||
}
|
||||
|
||||
if (has_operand_mask) {
|
||||
const constant size_t* mask_strides_mat = mask_batch_strides;
|
||||
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
|
||||
const constant auto* mask_strides_mat = mask_batch_strides;
|
||||
const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
|
||||
|
@@ -10,29 +10,11 @@
|
||||
|
||||
#define instantiate_gemv_helper( \
|
||||
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
|
||||
template [[host_name("gemv_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
|
||||
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
|
||||
"_tn" #tn "_nc" #nc)]] [[kernel]] void \
|
||||
gemv_masked<itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]], \
|
||||
const device outm_t* out_mask [[buffer(20)]], \
|
||||
const device opm_t* mat_mask [[buffer(21)]], \
|
||||
const device opm_t* vec_mask [[buffer(22)]], \
|
||||
const constant int* mask_strides [[buffer(23)]], \
|
||||
const constant size_t* mask_batch_strides [[buffer(24)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
instantiate_kernel( \
|
||||
"gemv_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
|
||||
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
|
||||
"_tn" #tn "_nc" #nc, \
|
||||
gemv_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc)
|
||||
|
||||
#define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
|
||||
instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
|
||||
@@ -61,29 +43,11 @@ instantiate_gemv_blocks(bfloat16, bfloat16_t);
|
||||
|
||||
#define instantiate_gemv_t_helper( \
|
||||
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
|
||||
template [[host_name("gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
|
||||
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
|
||||
"_tn" #tn "_nc" #nc)]] [[kernel]] void \
|
||||
gemv_t_masked<itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]], \
|
||||
const device outm_t* out_mask [[buffer(20)]], \
|
||||
const device opm_t* mat_mask [[buffer(21)]], \
|
||||
const device opm_t* vec_mask [[buffer(22)]], \
|
||||
const constant int* mask_strides [[buffer(23)]], \
|
||||
const constant size_t* mask_batch_strides [[buffer(24)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
instantiate_kernel( \
|
||||
"gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
|
||||
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
|
||||
"_tn" #tn "_nc" #nc, \
|
||||
gemv_t_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc)
|
||||
|
||||
#define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
|
||||
instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
|
||||
|
@@ -8,7 +8,7 @@ template <typename IdxT, int NIDX>
struct Indices {
const array<const device IdxT*, NIDX> buffers;
const constant int* shapes;
const constant size_t* strides;
const constant int64_t* strides;
const constant bool* row_contiguous;
const int ndim;
};
@@ -854,15 +854,17 @@ METAL_FUNC void qvm_impl(
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int num_simdgroups = 2;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
constexpr int tn = 32 / pack_factor;
constexpr int block_size = SIMD_SIZE;

const device uint8_t* ws = (const device uint8_t*)w;
using W_T =
typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
const device W_T* ws = (const device W_T*)w;

typedef float U;
typedef struct {
uint8_t wi[tn * bytes_per_pack];
W_T wi[tn * bytes_per_pack];
} vec_w;

thread vec_w w_local;
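As a side note on the constants in this hunk (a sketch, not code from the change): for power-of-two bit widths the weights are now consumed as whole 32-bit words, so bytes_per_pack counts one W_T word, while 3- and 6-bit weights still travel in 3-byte groups. Restated in plain C++, measured in bytes:

// Assumption-labeled helpers restating the packing arithmetic used above.
constexpr int pack_factor(int bits) {
  return bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; // values per pack
}
constexpr int pack_bytes(int bits) {
  return (bits & (bits - 1)) == 0 ? 4 : 3; // one uint32_t word vs a 3-byte group
}
static_assert(pack_factor(3) * 3 == 24 && pack_factor(6) * 6 == 24,
              "3-byte groups hold 24 bits of packed values");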
@@ -1217,12 +1219,12 @@ METAL_FUNC void adjust_matrix_offsets(
int output_stride,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant size_t* x_strides,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant size_t* w_strides,
const constant size_t* s_strides,
const constant size_t* b_strides,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
const constant int64_t* b_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx = tid.z;
@@ -1246,6 +1248,41 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride;
}

template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
const device uint32_t*& w,
const device T*& scales,
device T*& y,
int output_stride,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx = tid.z;
uint32_t w_idx = tid.z;
if (x_batch_ndims == 1) {
x += x_idx * x_strides[0];
} else {
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
}
if (w_batch_ndims == 1) {
w += w_idx * w_strides[0];
scales += w_idx * s_strides[0];
} else {
ulong2 idx = elem_to_loc_broadcast(
w_idx, w_shape, w_strides, s_strides, w_batch_ndims);
w += idx.x;
scales += idx.y;
}
y += tid.z * output_stride;
}
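The new overload leans on elem_to_loc_broadcast to keep the broadcasted w and scales batches in lockstep. A rough host-side rendering of that contract (names and types here are assumptions, not the Metal implementation):

#include <cstdint>
// Walks one shared batch shape and accumulates two offsets, one per strides
// array, so two broadcast operands are advanced consistently.
struct Loc2 { int64_t x, y; };
inline Loc2 elem_to_loc_broadcast_ref(int64_t elem, const int* shape,
    const int64_t* a_strides, const int64_t* b_strides, int ndim) {
  Loc2 loc{0, 0};
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    int64_t q = elem % shape[i];
    loc.x += q * a_strides[i];
    loc.y += q * b_strides[i];
    elem /= shape[i];
  }
  return loc;
}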
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void adjust_matrix_offsets(
|
||||
const device T*& x,
|
||||
@@ -1258,16 +1295,16 @@ METAL_FUNC void adjust_matrix_offsets(
|
||||
int output_stride,
|
||||
const constant int& batch_ndims,
|
||||
const constant int* batch_shape,
|
||||
const constant size_t* lhs_strides,
|
||||
const constant size_t* rhs_strides,
|
||||
const constant int64_t* lhs_strides,
|
||||
const constant int64_t* rhs_strides,
|
||||
const constant int& x_batch_ndims,
|
||||
const constant int* x_shape,
|
||||
const constant size_t* x_strides,
|
||||
const constant int64_t* x_strides,
|
||||
const constant int& w_batch_ndims,
|
||||
const constant int* w_shape,
|
||||
const constant size_t* w_strides,
|
||||
const constant size_t* s_strides,
|
||||
const constant size_t* b_strides,
|
||||
const constant int64_t* w_strides,
|
||||
const constant int64_t* s_strides,
|
||||
const constant int64_t* b_strides,
|
||||
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||
// Set the input/output matrices
|
||||
uint32_t x_idx;
|
||||
@@ -1311,12 +1348,12 @@ template <typename T, int group_size, int bits, int D, bool batched>
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant size_t* x_strides [[buffer(9)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant size_t* w_strides [[buffer(12)]],
|
||||
const constant size_t* s_strides [[buffer(13)]],
|
||||
const constant size_t* b_strides [[buffer(14)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
||||
uint quad_lid [[thread_index_in_quadgroup]]) {
|
||||
@@ -1362,12 +1399,12 @@ template <typename T, int group_size, int bits, bool batched>
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant size_t* x_strides [[buffer(9)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant size_t* w_strides [[buffer(12)]],
|
||||
const constant size_t* s_strides [[buffer(13)]],
|
||||
const constant size_t* b_strides [[buffer(14)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
@@ -1413,12 +1450,12 @@ template <typename T, const int group_size, const int bits, bool batched>
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant size_t* x_strides [[buffer(9)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant size_t* w_strides [[buffer(12)]],
|
||||
const constant size_t* s_strides [[buffer(13)]],
|
||||
const constant size_t* b_strides [[buffer(14)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
@@ -1464,12 +1501,12 @@ template <typename T, const int group_size, const int bits, bool batched>
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant size_t* x_strides [[buffer(9)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant size_t* w_strides [[buffer(12)]],
|
||||
const constant size_t* s_strides [[buffer(13)]],
|
||||
const constant size_t* b_strides [[buffer(14)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
@@ -1515,12 +1552,12 @@ template <typename T, const int group_size, const int bits, int split_k = 32>
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant size_t* x_strides [[buffer(9)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant size_t* w_strides [[buffer(12)]],
|
||||
const constant size_t* s_strides [[buffer(13)]],
|
||||
const constant size_t* b_strides [[buffer(14)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
const constant int& final_block_size [[buffer(15)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
@@ -1579,12 +1616,12 @@ template <
|
||||
const constant int& M [[buffer(7)]],
|
||||
const constant int& x_batch_ndims [[buffer(8)]],
|
||||
const constant int* x_shape [[buffer(9)]],
|
||||
const constant size_t* x_strides [[buffer(10)]],
|
||||
const constant int64_t* x_strides [[buffer(10)]],
|
||||
const constant int& w_batch_ndims [[buffer(11)]],
|
||||
const constant int* w_shape [[buffer(12)]],
|
||||
const constant size_t* w_strides [[buffer(13)]],
|
||||
const constant size_t* s_strides [[buffer(14)]],
|
||||
const constant size_t* b_strides [[buffer(15)]],
|
||||
const constant int64_t* w_strides [[buffer(13)]],
|
||||
const constant int64_t* s_strides [[buffer(14)]],
|
||||
const constant int64_t* b_strides [[buffer(15)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
@@ -1637,12 +1674,12 @@ template <
|
||||
const constant int& M [[buffer(7)]],
|
||||
const constant int& x_batch_ndims [[buffer(8)]],
|
||||
const constant int* x_shape [[buffer(9)]],
|
||||
const constant size_t* x_strides [[buffer(10)]],
|
||||
const constant int64_t* x_strides [[buffer(10)]],
|
||||
const constant int& w_batch_ndims [[buffer(11)]],
|
||||
const constant int* w_shape [[buffer(12)]],
|
||||
const constant size_t* w_strides [[buffer(13)]],
|
||||
const constant size_t* s_strides [[buffer(14)]],
|
||||
const constant size_t* b_strides [[buffer(15)]],
|
||||
const constant int64_t* w_strides [[buffer(13)]],
|
||||
const constant int64_t* s_strides [[buffer(14)]],
|
||||
const constant int64_t* b_strides [[buffer(15)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
@@ -1689,18 +1726,18 @@ template <typename T, int group_size, int bits>
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant size_t* x_strides [[buffer(9)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant size_t* w_strides [[buffer(12)]],
|
||||
const constant size_t* s_strides [[buffer(13)]],
|
||||
const constant size_t* b_strides [[buffer(14)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
const constant int& batch_ndims [[buffer(15)]],
|
||||
const constant int* batch_shape [[buffer(16)]],
|
||||
const device uint32_t* lhs_indices [[buffer(17)]],
|
||||
const device uint32_t* rhs_indices [[buffer(18)]],
|
||||
const constant size_t* lhs_strides [[buffer(19)]],
|
||||
const constant size_t* rhs_strides [[buffer(20)]],
|
||||
const constant int64_t* lhs_strides [[buffer(19)]],
|
||||
const constant int64_t* rhs_strides [[buffer(20)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
@@ -1750,18 +1787,18 @@ template <typename T, int group_size, int bits>
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant size_t* x_strides [[buffer(9)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant size_t* w_strides [[buffer(12)]],
|
||||
const constant size_t* s_strides [[buffer(13)]],
|
||||
const constant size_t* b_strides [[buffer(14)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
const constant int& batch_ndims [[buffer(15)]],
|
||||
const constant int* batch_shape [[buffer(16)]],
|
||||
const device uint32_t* lhs_indices [[buffer(17)]],
|
||||
const device uint32_t* rhs_indices [[buffer(18)]],
|
||||
const constant size_t* lhs_strides [[buffer(19)]],
|
||||
const constant size_t* rhs_strides [[buffer(20)]],
|
||||
const constant int64_t* lhs_strides [[buffer(19)]],
|
||||
const constant int64_t* rhs_strides [[buffer(20)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
@@ -1811,18 +1848,18 @@ template <typename T, int group_size, int bits>
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant size_t* x_strides [[buffer(9)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant size_t* w_strides [[buffer(12)]],
|
||||
const constant size_t* s_strides [[buffer(13)]],
|
||||
const constant size_t* b_strides [[buffer(14)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
const constant int& batch_ndims [[buffer(15)]],
|
||||
const constant int* batch_shape [[buffer(16)]],
|
||||
const device uint32_t* lhs_indices [[buffer(17)]],
|
||||
const device uint32_t* rhs_indices [[buffer(18)]],
|
||||
const constant size_t* lhs_strides [[buffer(19)]],
|
||||
const constant size_t* rhs_strides [[buffer(20)]],
|
||||
const constant int64_t* lhs_strides [[buffer(19)]],
|
||||
const constant int64_t* rhs_strides [[buffer(20)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
@@ -1880,18 +1917,18 @@ template <
|
||||
const constant int& M [[buffer(7)]],
|
||||
const constant int& x_batch_ndims [[buffer(8)]],
|
||||
const constant int* x_shape [[buffer(9)]],
|
||||
const constant size_t* x_strides [[buffer(10)]],
|
||||
const constant int64_t* x_strides [[buffer(10)]],
|
||||
const constant int& w_batch_ndims [[buffer(11)]],
|
||||
const constant int* w_shape [[buffer(12)]],
|
||||
const constant size_t* w_strides [[buffer(13)]],
|
||||
const constant size_t* s_strides [[buffer(14)]],
|
||||
const constant size_t* b_strides [[buffer(15)]],
|
||||
const constant int64_t* w_strides [[buffer(13)]],
|
||||
const constant int64_t* s_strides [[buffer(14)]],
|
||||
const constant int64_t* b_strides [[buffer(15)]],
|
||||
const constant int& batch_ndims [[buffer(16)]],
|
||||
const constant int* batch_shape [[buffer(17)]],
|
||||
const device uint32_t* lhs_indices [[buffer(18)]],
|
||||
const device uint32_t* rhs_indices [[buffer(19)]],
|
||||
const constant size_t* lhs_strides [[buffer(20)]],
|
||||
const constant size_t* rhs_strides [[buffer(21)]],
|
||||
const constant int64_t* lhs_strides [[buffer(20)]],
|
||||
const constant int64_t* rhs_strides [[buffer(21)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
@@ -1947,18 +1984,18 @@ template <
|
||||
const constant int& M [[buffer(7)]],
|
||||
const constant int& x_batch_ndims [[buffer(8)]],
|
||||
const constant int* x_shape [[buffer(9)]],
|
||||
const constant size_t* x_strides [[buffer(10)]],
|
||||
const constant int64_t* x_strides [[buffer(10)]],
|
||||
const constant int& w_batch_ndims [[buffer(11)]],
|
||||
const constant int* w_shape [[buffer(12)]],
|
||||
const constant size_t* w_strides [[buffer(13)]],
|
||||
const constant size_t* s_strides [[buffer(14)]],
|
||||
const constant size_t* b_strides [[buffer(15)]],
|
||||
const constant int64_t* w_strides [[buffer(13)]],
|
||||
const constant int64_t* s_strides [[buffer(14)]],
|
||||
const constant int64_t* b_strides [[buffer(15)]],
|
||||
const constant int& batch_ndims [[buffer(16)]],
|
||||
const constant int* batch_shape [[buffer(17)]],
|
||||
const device uint32_t* lhs_indices [[buffer(18)]],
|
||||
const device uint32_t* rhs_indices [[buffer(19)]],
|
||||
const constant size_t* lhs_strides [[buffer(20)]],
|
||||
const constant size_t* rhs_strides [[buffer(21)]],
|
||||
const constant int64_t* lhs_strides [[buffer(20)]],
|
||||
const constant int64_t* rhs_strides [[buffer(21)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
@@ -2147,3 +2184,666 @@ template <typename T, const int group_size, const int bits>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, int bits>
inline vec<U, 4> partial_qdot_vec(const thread U* x, vec<uint32_t, 4> w) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");

vec<U, 4> accum = 0;

if (bits == 2) {
for (int i = 0; i < 4; i++) {
auto ws = as_type<vec<uint8_t, 4>>(w[i]);
for (int j = 0; j < 4; j++) {
accum[i] +=
(x[4 * j + 0] * (ws[j] & 0x03) + x[4 * j + 1] * (ws[j] & 0x0c) +
x[4 * j + 2] * (ws[j] & 0x30) + x[4 * j + 3] * (ws[j] & 0xc0));
}
}
}

else if (bits == 4) {
for (int i = 0; i < 4; i++) {
auto ws = as_type<vec<uint16_t, 2>>(w[i]);
for (int j = 0; j < 2; j++) {
accum[i] +=
(x[4 * j + 0] * (ws[j] & 0x000f) + x[4 * j + 1] * (ws[j] & 0x00f0) +
x[4 * j + 2] * (ws[j] & 0x0f00) + x[4 * j + 3] * (ws[j] & 0xf000));
}
}
}

else if (bits == 8) {
for (int i = 0; i < 4; i++) {
auto ws = as_type<vec<uint8_t, 4>>(w[i]);
for (int j = 0; j < 4; j++) {
accum[i] += x[j] * ws[j];
}
}
}

return accum;
}
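Note that partial_qdot_vec masks each packed field without shifting it down. That is only exact if the activations were pre-divided by each field's place value when load_vector filled x_thread, the same convention the other quantized kernels in this file appear to use, so the factor cancels. Stated for one 2-bit field (an inferred assumption, not code from the change):

// Sketch of the identity being relied on:
//   (w & 0x30) == ((w >> 4) & 0x3) * 16, and x_thread holds x / 16,
//   so (w & 0x30) * (x / 16) == ((w >> 4) & 0x3) * x.
inline float masked_term_2bit(unsigned w, float x_prescaled) {
  return float(w & 0x30) * x_prescaled; // equals the unpacked 2-bit value times the raw x
}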

template <typename T, int group_size, int bits>
METAL_FUNC void affine_packed_qmv_fast_impl(
const device vec<uint32_t, 4>* w,
const device vec<T, 4>* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int packs_per_thread = (bits == 2) ? 1 : 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = 32 / bits;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;

typedef float U;

thread U x_thread[values_per_thread];
vec<U, results_per_simdgroup> result = 0;

// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size * 2 / group_size;
const int w_row = tid.x * num_simdgroups + simd_gid;
const int out_row = w_row * results_per_simdgroup;

w += w_row * in_vec_size_w + simd_lid * packs_per_thread;
scales += w_row * in_vec_size_g + 2 * (simd_lid / scale_step_per_thread);
x += tid.y * in_vec_size + simd_lid * values_per_thread;
y += tid.y * out_vec_size + out_row;

for (int k = 0; k < in_vec_size; k += block_size) {
// Load the input vector
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

// Load the scales and biases
vec<T, 4> s = scales[0];
vec<T, 4> b = scales[1];

// Load the weights and perform the partial dot product
vec<U, 4> accum = 0;
for (int pack = 0; pack < packs_per_thread; pack++) {
accum +=
partial_qdot_vec<T, U, bits>(x_thread + pack * pack_factor, w[pack]);
}

// Finalize the dot product and accumulate it
for (int i = 0; i < 4; i++) {
result[i] += static_cast<U>(s[i]) * accum[i] + static_cast<U>(b[i]) * sum;
}

w += block_size / pack_factor;
scales += 2 * block_size / group_size;
x += block_size;
}

result = simd_sum(result);
if (simd_lid == 0) {
for (int row = 0; row < results_per_simdgroup; row++) {
y[row] = static_cast<T>(result[row]);
}
}
}
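The scale stream consumed above appears to interleave, per quantization group, one vec<T, 4> of scales followed by one vec<T, 4> of biases for the four rows a simdgroup produces (hence in_vec_size_g = in_vec_size * 2 / group_size and the scales[0] / scales[1] reads). A small indexing sketch under that assumption:

// Assumed layout per 4-row block: [s(g0, r0..3)][b(g0, r0..3)][s(g1, r0..3)]...
inline int packed_scale_vec(int group) { return 2 * group; }     // vec<T, 4> of scales
inline int packed_bias_vec(int group)  { return 2 * group + 1; } // vec<T, 4> of biases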

template <typename T, int group_size, int bits, int results_per_simdgroup>
METAL_FUNC void affine_packed_byte_qmv_fast_impl(
const device uint8_t* w,
const device vec<T, 2 * results_per_simdgroup>* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int packs_per_thread = 2;
constexpr int num_simdgroups = 2;
constexpr int pack_factor = (bits == 3) ? 8 : 4;
constexpr int bytes_per_pack = 3;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;

typedef float U;

thread U x_thread[values_per_thread];
vec<U, results_per_simdgroup> result = 0;

// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int scales_row = tid.x * num_simdgroups + simd_gid;
const int out_row = scales_row * results_per_simdgroup;

w += out_row * in_vec_size_w + simd_lid * (packs_per_thread * bytes_per_pack);
scales += scales_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread;
y += tid.y * out_vec_size + out_row;

for (int k = 0; k < in_vec_size; k += block_size) {
// Load the input vector
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

// Load the scales and biases
vec<T, 2 * results_per_simdgroup> sb = scales[0];

// Load the weights and perform the partial dot product
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] += qdot<U, values_per_thread, bits>(
w + row * in_vec_size_w, x_thread, sb[row], sb[2 + row], sum);
}

w += block_size * bytes_per_pack / pack_factor;
scales += block_size / group_size;
x += block_size;
}

for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
y[row] = static_cast<T>(result[row]);
}
}
}
|
||||
|
||||
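// Added note (not in the original source): the byte-packed variant above
// handles the non power-of-two widths. Both 3-bit and 6-bit values pack into
// 3-byte groups (bytes_per_pack = 3), holding 8 values (3-bit) or 4 values
// (6-bit) per pack. The scales and biases for the packed rows travel together
// in a single vec<T, 2 * results_per_simdgroup>: entries [0, rows) are the
// scales and entries [rows, 2 * rows) the biases, hence sb[row] and
// sb[2 + row] for the instantiated results_per_simdgroup == 2.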
template <typename T, int group_size, int bits>
[[kernel]] void affine_packed_qmv_fast(
    const device vec<uint32_t, 4>* w [[buffer(0)]],
    const device vec<T, 4>* scales [[buffer(1)]],
    const device T* x [[buffer(2)]],
    device T* y [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (bits & (bits - 1)) {
    affine_packed_byte_qmv_fast_impl<T, group_size, bits, 2>(
        (const device uint8_t*)w,
        scales,
        x,
        y,
        in_vec_size,
        out_vec_size,
        tid,
        simd_gid,
        simd_lid);
  } else {
    affine_packed_qmv_fast_impl<T, group_size, bits>(
        w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
  }
}

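// Added illustration (not part of the original source): the dispatch above
// uses the standard power-of-two test to select the layout. A sketch of the
// same check, assuming the supported widths {2, 3, 4, 6, 8}:
//
//   constexpr bool packs_into_uint32(int bits) {
//     return (bits & (bits - 1)) == 0; // 2, 4, 8 -> uint32-packed path
//   }
//   static_assert(packs_into_uint32(4) && !packs_into_uint32(6),
//                 "3- and 6-bit widths take the byte-packed path");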
template <
|
||||
typename T,
|
||||
short BROWS,
|
||||
short BCOLS,
|
||||
short dst_ld,
|
||||
short reduction_dim,
|
||||
short tgp_size,
|
||||
short group_size,
|
||||
short bits>
|
||||
struct AffinePackedQuantizedBlockLoader {
|
||||
static_assert(
|
||||
BCOLS <= group_size,
|
||||
"The group size should be larger than the columns");
|
||||
static_assert(
|
||||
group_size % BCOLS == 0,
|
||||
"The group size should be divisible by the columns");
|
||||
static_assert(
|
||||
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
|
||||
"Template undefined for bits not in {2, 3, 4, 6, 8}");
|
||||
|
||||
MLX_MTL_CONST short pack_factor = 32 / bits;
|
||||
MLX_MTL_CONST short row_pack_factor = 4;
|
||||
MLX_MTL_CONST short BCOLS_PACKED = BCOLS * row_pack_factor / pack_factor;
|
||||
MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor;
|
||||
MLX_MTL_CONST short TOTAL_INTS = BCOLS_PACKED * BROWS_PACKED;
|
||||
MLX_MTL_CONST short n_reads =
|
||||
(TOTAL_INTS < tgp_size) ? 1 : TOTAL_INTS / tgp_size;
|
||||
MLX_MTL_CONST short group_steps = group_size / BCOLS;
|
||||
|
||||
static_assert(
|
||||
n_reads <= row_pack_factor,
|
||||
"The loader only supports per thread reads <= row_pack_factor");
|
||||
|
||||
const int src_ld;
|
||||
const int tile_stride;
|
||||
short group_step_cnt;
|
||||
const int group_stride;
|
||||
|
||||
const short thread_idx;
|
||||
const short bi;
|
||||
const short bj;
|
||||
const short bii;
|
||||
const short bjj;
|
||||
|
||||
const device uint32_t* src;
|
||||
const device T* scales;
|
||||
const device T* biases;
|
||||
threadgroup T* dst;
|
||||
|
||||
AffinePackedQuantizedBlockLoader(
|
||||
const device uint32_t* src_,
|
||||
const device T* scales_,
|
||||
const int src_ld_,
|
||||
threadgroup T* dst_,
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(src_ld_),
|
||||
tile_stride(reduction_dim ? BCOLS_PACKED : BROWS_PACKED * src_ld),
|
||||
group_step_cnt(0),
|
||||
group_stride(BROWS_PACKED * 2 * src_ld / group_size),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(n_reads * thread_idx / BCOLS_PACKED),
|
||||
bj((n_reads * thread_idx) % BCOLS_PACKED),
|
||||
bii(bi * row_pack_factor + bj % row_pack_factor),
|
||||
bjj(bj / row_pack_factor),
|
||||
src(src_ + bi * src_ld * row_pack_factor / pack_factor + bj),
|
||||
scales(
|
||||
scales_ + bi * 2 * src_ld * row_pack_factor / group_size +
|
||||
bj % row_pack_factor),
|
||||
biases(scales + row_pack_factor),
|
||||
dst(dst_ + bii * dst_ld + bjj * pack_factor) {}
|
||||
|
||||
void load_unsafe() const {
|
||||
if (bits == 2 && BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
T scale = scales[i];
|
||||
T bias = biases[i];
|
||||
dequantize<T, pack_factor, bits>(
|
||||
(const device uint8_t*)(src + i), scale, bias, dst + i * dst_ld);
|
||||
}
|
||||
}
|
||||
|
||||
void load_safe(short2 src_tile_dim) const {
|
||||
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (reduction_dim == 1 && bii >= src_tile_dim.y) {
|
||||
for (int i = 0; i < n_reads * pack_factor; i++) {
|
||||
dst[i] = T(0);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (reduction_dim == 0 && bii >= src_tile_dim.x) {
|
||||
for (int i = 0; i < n_reads * pack_factor; i++) {
|
||||
dst[i] = T(0);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
T scale = scales[i];
|
||||
T bias = biases[i];
|
||||
dequantize<T, pack_factor, bits>(
|
||||
(const device uint8_t*)(src + i), scale, bias, dst + i * dst_ld);
|
||||
}
|
||||
}
|
||||
|
||||
void next() {
|
||||
src += tile_stride;
|
||||
if (reduction_dim == 1) {
|
||||
if (group_steps > 1) {
|
||||
group_step_cnt++;
|
||||
if (group_step_cnt == group_steps) {
|
||||
group_step_cnt = 0;
|
||||
scales += (2 * row_pack_factor);
|
||||
biases += (2 * row_pack_factor);
|
||||
}
|
||||
} else {
|
||||
scales += (2 * row_pack_factor);
|
||||
biases += (2 * row_pack_factor);
|
||||
}
|
||||
} else {
|
||||
scales += group_stride;
|
||||
biases += group_stride;
|
||||
}
|
||||
}
|
||||
};
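// Added note (not in the original source): this loader consumes the fully
// row-packed layout (row_pack_factor = 4 output rows per uint32 pack) and
// dequantizes it into the BROWS x dst_ld threadgroup tile used by the steel
// GEMM below. Each thread issues n_reads dequantize<T, pack_factor, bits>
// calls per tile, and the static_assert above caps n_reads at
// row_pack_factor.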
|
||||
|
||||
template <
|
||||
typename T,
|
||||
short BROWS,
|
||||
short BCOLS,
|
||||
short dst_ld,
|
||||
short reduction_dim,
|
||||
short tgp_size,
|
||||
short group_size,
|
||||
short bits>
|
||||
struct AffineScalesPackedQuantizedBlockLoader {
|
||||
static_assert(
|
||||
BCOLS <= group_size,
|
||||
"The group size should be larger than the columns");
|
||||
static_assert(
|
||||
group_size % BCOLS == 0,
|
||||
"The group size should be divisible by the columns");
|
||||
static_assert(
|
||||
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
|
||||
"Template undefined for bits not in {2, 3, 4, 6, 8}");
|
||||
|
||||
MLX_MTL_CONST short bytes_per_pack = (bits & (bits - 1)) ? 3 : 4;
|
||||
MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||
MLX_MTL_CONST short row_pack_factor = 2;
|
||||
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
|
||||
MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor;
|
||||
MLX_MTL_CONST short TOTAL_READS = BCOLS * BROWS / pack_factor;
|
||||
MLX_MTL_CONST short n_reads =
|
||||
(TOTAL_READS < tgp_size) ? 1 : TOTAL_READS / tgp_size;
|
||||
MLX_MTL_CONST short group_steps = group_size / BCOLS;
|
||||
|
||||
const int src_ld;
|
||||
const int tile_stride;
|
||||
short group_step_cnt;
|
||||
const int group_stride;
|
||||
|
||||
const short thread_idx;
|
||||
const short bi;
|
||||
const short bj;
|
||||
const short bii;
|
||||
|
||||
const device uint8_t* src;
|
||||
const device T* scales;
|
||||
const device T* biases;
|
||||
threadgroup T* dst;
|
||||
|
||||
AffineScalesPackedQuantizedBlockLoader(
|
||||
const device uint32_t* src_,
|
||||
const device T* scales_,
|
||||
const int src_ld_,
|
||||
threadgroup T* dst_,
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(src_ld_),
|
||||
tile_stride(
|
||||
reduction_dim ? BCOLS_PACKED * bytes_per_pack
|
||||
: BROWS * src_ld * bytes_per_pack / pack_factor),
|
||||
group_step_cnt(0),
|
||||
group_stride(BROWS_PACKED * 2 * src_ld / group_size),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(n_reads * thread_idx / BCOLS_PACKED),
|
||||
bj((n_reads * thread_idx) % BCOLS_PACKED),
|
||||
bii(bi / row_pack_factor),
|
||||
src(((const device uint8_t*)src_) +
|
||||
bi * src_ld * bytes_per_pack / pack_factor + bj * bytes_per_pack),
|
||||
scales(
|
||||
scales_ + bii * 2 * src_ld * row_pack_factor / group_size +
|
||||
bi % row_pack_factor),
|
||||
biases(scales + row_pack_factor),
|
||||
dst(dst_ + bi * dst_ld + bj * pack_factor) {}
|
||||
|
||||
void load_unsafe() const {
|
||||
if (bits == 2 && TOTAL_READS < tgp_size && bi >= BROWS) {
|
||||
return;
|
||||
}
|
||||
|
||||
T scale = *scales;
|
||||
T bias = *biases;
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
dequantize<T, pack_factor, bits>(
|
||||
(const device uint8_t*)(src + bytes_per_pack * i),
|
||||
scale,
|
||||
bias,
|
||||
dst + i * pack_factor);
|
||||
}
|
||||
}
|
||||
|
||||
void load_safe(short2 src_tile_dim) const {
|
||||
if (TOTAL_READS < tgp_size && bi >= BROWS) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (reduction_dim == 1 && bii >= src_tile_dim.y) {
|
||||
for (int i = 0; i < n_reads * pack_factor; i++) {
|
||||
dst[i] = T(0);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (reduction_dim == 0 && bii >= src_tile_dim.x) {
|
||||
for (int i = 0; i < n_reads * pack_factor; i++) {
|
||||
dst[i] = T(0);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
T scale = scales[i];
|
||||
T bias = biases[i];
|
||||
dequantize<T, pack_factor, bits>(
|
||||
(const device uint8_t*)(src + bytes_per_pack * i * src_ld),
|
||||
scale,
|
||||
bias,
|
||||
dst + i * dst_ld);
|
||||
}
|
||||
}
|
||||
|
||||
void next() {
|
||||
src += tile_stride;
|
||||
if (reduction_dim == 1) {
|
||||
if (group_steps > 1) {
|
||||
group_step_cnt++;
|
||||
if (group_step_cnt == group_steps) {
|
||||
group_step_cnt = 0;
|
||||
scales += (2 * row_pack_factor);
|
||||
biases += (2 * row_pack_factor);
|
||||
}
|
||||
} else {
|
||||
scales += (2 * row_pack_factor);
|
||||
biases += (2 * row_pack_factor);
|
||||
}
|
||||
} else {
|
||||
scales += group_stride;
|
||||
biases += group_stride;
|
||||
}
|
||||
}
|
||||
};
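// Added note (not in the original source): unlike the loader above, this one
// keeps the weight bytes in their per-row order (bytes_per_pack is 3 for the
// 3/6-bit widths and 4 otherwise) and only packs the scales/biases two rows
// at a time (row_pack_factor = 2); in next() the scale and bias pointers
// therefore advance by 2 * row_pack_factor entries whenever a quantization
// group boundary is crossed.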
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
const bool aligned_N,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
METAL_FUNC void affine_packed_qmm_t_impl(
|
||||
const device uint32_t* w,
|
||||
const device T* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
threadgroup T* Xs,
|
||||
threadgroup T* Ws,
|
||||
const constant int& K,
|
||||
const constant int& N,
|
||||
const constant int& M,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
||||
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
||||
|
||||
(void)lid;
|
||||
|
||||
constexpr bool power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||
constexpr int WM = 2;
|
||||
constexpr int WN = 2;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int row_pack_factor = (power_of_2_bits) ? 4 : 2;
|
||||
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
||||
|
||||
// Instantiate the appropriate BlockMMA and Loader
|
||||
using mma_t = mlx::steel::
|
||||
BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
|
||||
using loader_x_t =
|
||||
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
|
||||
using loader_fully_packed_t = AffinePackedQuantizedBlockLoader<
|
||||
T,
|
||||
BN,
|
||||
BK,
|
||||
BK_padded,
|
||||
1,
|
||||
WM * WN * SIMD_SIZE,
|
||||
group_size,
|
||||
bits>;
|
||||
using loader_scales_packed_t = AffineScalesPackedQuantizedBlockLoader<
|
||||
T,
|
||||
BN,
|
||||
BK,
|
||||
BK_padded,
|
||||
1,
|
||||
WM * WN * SIMD_SIZE,
|
||||
group_size,
|
||||
bits>;
|
||||
using loader_w_t = typename ConditionalType<
|
||||
power_of_2_bits,
|
||||
loader_fully_packed_t,
|
||||
loader_scales_packed_t>::type;
|
||||
|
||||
// Set the block
|
||||
const int K_w =
|
||||
(power_of_2_bits) ? K * row_pack_factor / pack_factor : K * bits / 32;
|
||||
const int K_g = K * 2 * row_pack_factor / group_size;
|
||||
const int y_row = tid.y * BM;
|
||||
const int y_col = tid.x * BN;
|
||||
const int packed_y_col = tid.x * (BN / row_pack_factor);
|
||||
|
||||
x += y_row * K;
|
||||
w += (power_of_2_bits) ? packed_y_col * K_w : y_col * K_w;
|
||||
scales += packed_y_col * K_g;
|
||||
y += y_row * N + y_col;
|
||||
|
||||
// Make the x loader and mma operation
|
||||
const short num_els = min(BM, M - y_row);
|
||||
const short num_outs = min(BN, N - y_col);
|
||||
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
|
||||
loader_w_t loader_w(w, scales, K, Ws, simd_gid, simd_lid);
|
||||
mma_t mma_op(simd_gid, simd_lid);
|
||||
|
||||
if (num_els < BM) {
|
||||
if (!aligned_N && num_outs < BN) {
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_x.load_safe(short2(BK, num_els));
|
||||
loader_w.load_safe(short2(BK, num_outs));
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_op.mma(Xs, Ws);
|
||||
loader_x.next();
|
||||
loader_w.next();
|
||||
}
|
||||
} else {
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_x.load_safe(short2(BK, num_els));
|
||||
loader_w.load_unsafe();
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_op.mma(Xs, Ws);
|
||||
loader_x.next();
|
||||
loader_w.next();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (!aligned_N && num_outs < BN) {
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_x.load_unsafe();
|
||||
loader_w.load_safe(short2(BK, num_outs));
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_op.mma(Xs, Ws);
|
||||
loader_x.next();
|
||||
loader_w.next();
|
||||
}
|
||||
} else {
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_x.load_unsafe();
|
||||
loader_w.load_unsafe();
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(Xs, Ws);
|
||||
loader_x.next();
|
||||
loader_w.next();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (num_els < BM || num_outs < BN) {
|
||||
mma_op.store_result_safe(y, N, short2(num_outs, num_els));
|
||||
} else {
|
||||
mma_op.store_result(y, N);
|
||||
}
|
||||
}
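// Added note (not in the original source): the four branches above are the
// usual edge-tile handling. load_safe is only needed when the output tile
// hangs off the end of M (num_els < BM) for the activation loader or, with
// unaligned N, off the end of N (num_outs < BN) for the weight loader;
// otherwise the cheaper load_unsafe path is taken for both operands.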
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
const bool aligned_N,
|
||||
const bool batched,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
[[kernel]] void affine_packed_qmm_t(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* x [[buffer(2)]],
|
||||
device T* y [[buffer(3)]],
|
||||
const constant int& K [[buffer(4)]],
|
||||
const constant int& N [[buffer(5)]],
|
||||
const constant int& M [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
(void)lid;
|
||||
|
||||
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
||||
|
||||
threadgroup T Xs[BM * BK_padded];
|
||||
threadgroup T Ws[BN * BK_padded];
|
||||
|
||||
if (batched) {
|
||||
adjust_matrix_offsets<T>(
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
y,
|
||||
M * N,
|
||||
x_batch_ndims,
|
||||
x_shape,
|
||||
x_strides,
|
||||
w_batch_ndims,
|
||||
w_shape,
|
||||
w_strides,
|
||||
s_strides,
|
||||
tid);
|
||||
}
|
||||
affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
||||
}
|
||||
|
@@ -60,6 +60,14 @@
      bits, \
      split_k)

#define instantiate_quantized_affine_packed(name, type, group_size, bits) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits, \
      name, \
      type, \
      group_size, \
      bits)

#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
  instantiate_quantized_batched(name, type, group_size, bits, 1) \
  instantiate_quantized_batched(name, type, group_size, bits, 0)
@@ -96,12 +104,20 @@
  instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
  instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)

#define instantiate_quantized_all_affine_packed(type, group_size, bits) \
  instantiate_quantized_affine_packed(affine_packed_qmv_fast, type, group_size, bits) \
  instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, true, true) \
  instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, true, false) \
  instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, false, true) \
  instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, false, false)

#define instantiate_quantized_funcs(type, group_size, bits) \
  instantiate_quantized_all_single(type, group_size, bits) \
  instantiate_quantized_all_batched(type, group_size, bits) \
  instantiate_quantized_all_aligned(type, group_size, bits) \
  instantiate_quantized_all_quad(type, group_size, bits) \
  instantiate_quantized_all_splitk(type, group_size, bits)
  instantiate_quantized_all_splitk(type, group_size, bits) \
  instantiate_quantized_all_affine_packed(type, group_size, bits)

#define instantiate_quantized_types(group_size, bits) \
  instantiate_quantized_funcs(float, group_size, bits) \

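// Added example (not in the original source): with the macros above, a single
// instantiate_quantized_affine_packed(affine_packed_qmv_fast, float16_t, 64, 4)
// expands to an instantiate_kernel call whose generated kernel name is
// "affine_packed_qmv_fast_float16_t_gs_64_b_4", following the
// #name "_" #type "_gs_" #group_size "_b_" #bits pattern.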
@@ -71,7 +71,7 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
    constant const uint& bytes_per_key,
    constant const int& ndim,
    constant const int* key_shape,
    constant const size_t* key_strides,
    constant const int64_t* key_strides,
    uint2 grid_dim [[threads_per_grid]],
    uint2 index [[thread_position_in_grid]]) {
  auto kidx = 2 * index.x;

@@ -53,32 +53,32 @@ instantiate_init_min_max(max, Max)
|
||||
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
|
||||
col_reduce_small, \
|
||||
itype, otype, op, uint, dim) \
|
||||
itype, otype, op, int, dim) \
|
||||
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
|
||||
col_reduce_longcolumn, \
|
||||
itype, otype, op, uint, dim) \
|
||||
itype, otype, op, int, dim) \
|
||||
instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \
|
||||
col_reduce_small, \
|
||||
itype, otype, op, size_t, dim) \
|
||||
itype, otype, op, int64_t, dim) \
|
||||
instantiate_kernel("col_reduce_longcolumn_large_" #dim "_reduce_" #name, \
|
||||
col_reduce_longcolumn, \
|
||||
itype, otype, op, size_t, dim)
|
||||
itype, otype, op, int64_t, dim)
|
||||
|
||||
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_looped, \
|
||||
itype, otype, op, uint, dim, bm, bn) \
|
||||
itype, otype, op, int, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_looped, \
|
||||
itype, otype, op, size_t, dim, bm, bn)
|
||||
itype, otype, op, int64_t, dim, bm, bn)
|
||||
|
||||
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_2pass, \
|
||||
itype, otype, op, uint, dim, bm, bn) \
|
||||
itype, otype, op, int, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_2pass, \
|
||||
itype, otype, op, size_t, dim, bm, bn)
|
||||
itype, otype, op, int64_t, dim, bm, bn)
|
||||
|
||||
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
|
||||
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
|
||||
@@ -95,18 +95,18 @@ instantiate_init_min_max(max, Max)
|
||||
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
|
||||
row_reduce_small, \
|
||||
itype, otype, op, uint, dim) \
|
||||
itype, otype, op, int, dim) \
|
||||
instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
|
||||
row_reduce_small, \
|
||||
itype, otype, op, size_t, dim)
|
||||
itype, otype, op, int64_t, dim)
|
||||
|
||||
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
|
||||
row_reduce_looped, \
|
||||
itype, otype, op, uint, dim) \
|
||||
itype, otype, op, int, dim) \
|
||||
instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \
|
||||
row_reduce_looped, \
|
||||
itype, otype, op, size_t, dim)
|
||||
itype, otype, op, int64_t, dim)
|
||||
|
||||
#define instantiate_row_reduce_general(name, itype, otype, op) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 1) \
|
||||
@@ -125,7 +125,7 @@ instantiate_init_min_max(max, Max)
|
||||
instantiate_col_reduce_general(name##tname, itype, otype, op<otype>)
|
||||
|
||||
#define instantiate_and_or(name, op) \
|
||||
instantiate_reduce_functions(name, bool_, bool, bool, op) \
|
||||
instantiate_reduce_functions(name, bool_, bool, bool, op) \
|
||||
instantiate_reduce_functions(name, int16, int16_t, bool, op) \
|
||||
instantiate_reduce_functions(name, int32, int32_t, bool, op) \
|
||||
instantiate_reduce_functions(name, int64, int64_t, bool, op)
|
||||
|
@@ -5,12 +5,12 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant int64_t& reduction_stride [[buffer(3)]],
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int64_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int64_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
const constant size_t& non_col_reductions [[buffer(10)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
@@ -34,7 +34,7 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
|
||||
bool safe = column + n_reads <= reduction_stride;
|
||||
|
||||
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
|
||||
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
|
||||
in += in_idx + column;
|
||||
|
||||
IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
|
||||
@@ -100,10 +100,10 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int64_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int64_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
const constant size_t& non_col_reductions [[buffer(10)]],
|
||||
const constant size_t& out_size [[buffer(11)]],
|
||||
@@ -116,7 +116,7 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
|
||||
const device T* row;
|
||||
|
||||
IdxT out_idx = gid.x + gsize.x * IdxT(gid.y);
|
||||
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
|
||||
in += in_idx + lid.x;
|
||||
|
||||
U total = Op::init;
|
||||
@@ -164,12 +164,12 @@ template <
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant int64_t& reduction_stride [[buffer(3)]],
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int64_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int64_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
const constant size_t& non_col_reductions [[buffer(10)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
@@ -197,7 +197,7 @@ template <
|
||||
bool safe = column + n_reads <= reduction_stride;
|
||||
|
||||
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
|
||||
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
|
||||
in += in_idx + column;
|
||||
|
||||
IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
|
||||
@@ -303,12 +303,12 @@ template <
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant int64_t& reduction_stride [[buffer(3)]],
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int64_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int64_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
const constant size_t& non_col_reductions [[buffer(10)]],
|
||||
const constant size_t& out_size [[buffer(11)]],
|
||||
@@ -342,7 +342,7 @@ template <
|
||||
IdxT full_idx = gid.y + gsize.y * IdxT(gid.z);
|
||||
IdxT block_idx = full_idx / IdxT(out_size);
|
||||
IdxT out_idx = full_idx % IdxT(out_size);
|
||||
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
|
||||
in += in_idx + column;
|
||||
|
||||
IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
|
||||
|
@@ -98,11 +98,11 @@ template <
|
||||
METAL_FUNC void per_thread_row_reduce(
|
||||
thread U totals[N_WRITES],
|
||||
const device T* in,
|
||||
const size_t row_idx,
|
||||
const int64_t row_idx,
|
||||
int blocks,
|
||||
int extra,
|
||||
const constant int* shape,
|
||||
const constant size_t* strides,
|
||||
const constant int64_t* strides,
|
||||
const constant int& ndim,
|
||||
uint lsize_x,
|
||||
uint lid_x) {
|
||||
@@ -199,13 +199,13 @@ template <
|
||||
[[kernel]] void row_reduce_small(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& row_size [[buffer(2)]],
|
||||
const constant size_t& non_row_reductions [[buffer(3)]],
|
||||
const constant int64_t& row_size [[buffer(2)]],
|
||||
const constant int64_t& non_row_reductions [[buffer(3)]],
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int64_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int64_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
@@ -225,7 +225,7 @@ template <
|
||||
if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
|
||||
// Simple loop over non_row_reductions and reduce the row in the thread.
|
||||
IdxT out_idx = tid.x + tsize.y * IdxT(tid.y);
|
||||
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
|
||||
|
||||
for (uint r = 0; r < non_row_reductions; r++) {
|
||||
row = in + loop.location();
|
||||
@@ -238,7 +238,7 @@ template <
|
||||
// Collaboratively reduce over non_row_reductions in the simdgroup. Each
|
||||
// thread reduces every 32nd row and then a simple simd reduce.
|
||||
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
|
||||
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
|
||||
|
||||
loop.next(simd_lane_id, reduce_shape, reduce_strides);
|
||||
|
||||
@@ -260,14 +260,14 @@ template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename IdxT = size_t,
|
||||
typename IdxT = int64_t,
|
||||
int N_READS = REDUCE_N_READS,
|
||||
int N_WRITES = REDUCE_N_WRITES>
|
||||
[[kernel]] void row_reduce_simple(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant int64_t& out_size [[buffer(3)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
@@ -314,13 +314,13 @@ template <
|
||||
[[kernel]] void row_reduce_looped(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& row_size [[buffer(2)]],
|
||||
const constant size_t& non_row_reductions [[buffer(3)]],
|
||||
const constant int64_t& row_size [[buffer(2)]],
|
||||
const constant int64_t& non_row_reductions [[buffer(3)]],
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int64_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int64_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
@@ -337,8 +337,7 @@ template <
|
||||
|
||||
// lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
|
||||
// needs a small refactor.
|
||||
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim) +
|
||||
lid.x * N_READS;
|
||||
in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim) + lid.x * N_READS;
|
||||
|
||||
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
|
||||
const device T* row;
|
||||
|
@@ -20,28 +20,4 @@ using namespace metal;
|
||||
instantiate_sdpa_vector_heads(float)
|
||||
instantiate_sdpa_vector_heads(bfloat16_t)
|
||||
instantiate_sdpa_vector_heads(float16_t)
|
||||
|
||||
// Quantized SDPA vector instantiations
|
||||
#define instantiate_quant_sdpa_vector(type, head_dim, group_size, bits) \
|
||||
instantiate_kernel( \
|
||||
"quant_sdpa_vector_2pass_1_" #type "_" #head_dim "_" #group_size "_" #bits, \
|
||||
quant_sdpa_vector_2pass_1, type, head_dim, group_size, bits)
|
||||
|
||||
#define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \
|
||||
instantiate_quant_sdpa_vector(type, heads, group_size, 4) \
|
||||
instantiate_quant_sdpa_vector(type, heads, group_size, 8)
|
||||
|
||||
#define instantiate_quant_sdpa_vector_group_size(type, heads) \
|
||||
instantiate_quant_sdpa_vector_bits(type, heads, 32) \
|
||||
instantiate_quant_sdpa_vector_bits(type, heads, 64) \
|
||||
instantiate_quant_sdpa_vector_bits(type, heads, 128)
|
||||
|
||||
#define instantiate_quant_sdpa_vector_heads(type) \
|
||||
instantiate_quant_sdpa_vector_group_size(type, 64) \
|
||||
instantiate_quant_sdpa_vector_group_size(type, 128)
|
||||
|
||||
instantiate_quant_sdpa_vector_heads(float)
|
||||
instantiate_quant_sdpa_vector_heads(bfloat16_t)
|
||||
instantiate_quant_sdpa_vector_heads(float16_t)
|
||||
|
||||
// clang-format on
|
||||
|
@@ -16,11 +16,11 @@ METAL_FUNC void scatter_impl(
|
||||
const device T* updates,
|
||||
device mlx_atomic<T>* out,
|
||||
const constant int* upd_shape,
|
||||
const constant size_t* upd_strides,
|
||||
const constant int64_t* upd_strides,
|
||||
const constant size_t& upd_ndim,
|
||||
const constant size_t& upd_size,
|
||||
const constant int* out_shape,
|
||||
const constant size_t* out_strides,
|
||||
const constant int64_t* out_strides,
|
||||
const constant size_t& out_ndim,
|
||||
const constant int* axes,
|
||||
const constant size_t& idx_size,
|
||||
@@ -31,7 +31,7 @@ METAL_FUNC void scatter_impl(
|
||||
auto ind_idx = gid.y * NWORK;
|
||||
LocT out_offset = 0;
|
||||
if (upd_size > 1) {
|
||||
out_offset = elem_to_loc<size_t, LocT>(
|
||||
out_offset = elem_to_loc<LocT>(
|
||||
gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ METAL_FUNC void scatter_impl(
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
auto idx_loc = indices.row_contiguous[i]
|
||||
? ind_idx
|
||||
: elem_to_loc<size_t, LocT>(
|
||||
: elem_to_loc<LocT>(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
@@ -52,8 +52,7 @@ METAL_FUNC void scatter_impl(
|
||||
}
|
||||
auto upd_idx = ind_idx * static_cast<LocT>(upd_size) + gid.x;
|
||||
if constexpr (!UPD_ROW_CONTIG) {
|
||||
upd_idx =
|
||||
elem_to_loc<size_t, LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);
|
||||
upd_idx = elem_to_loc<LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);
|
||||
}
|
||||
op.atomic_update(out, updates[upd_idx], out_idx);
|
||||
}
|
||||
|
@@ -113,67 +113,6 @@ template <typename T, int D>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, int elem_per_thread, int bits>
|
||||
METAL_FUNC U load_queries(const device T* queries, thread U* q, U scale) {
|
||||
U query_sum = 0;
|
||||
if (bits == 4) {
|
||||
for (int i = 0; i < elem_per_thread; i += 4) {
|
||||
q[i] = scale * queries[i];
|
||||
q[i + 1] = scale * queries[i + 1];
|
||||
q[i + 2] = scale * queries[i + 2];
|
||||
q[i + 3] = scale * queries[i + 3];
|
||||
query_sum += q[i] + q[i + 1] + q[i + 2] + q[i + 3];
|
||||
q[i + 1] /= 16.0f;
|
||||
q[i + 2] /= 256.0f;
|
||||
q[i + 3] /= 4096.0f;
|
||||
}
|
||||
} else if (bits == 8) {
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
q[i] = scale * queries[i];
|
||||
query_sum += q[i];
|
||||
}
|
||||
}
|
||||
return query_sum;
|
||||
}
|
||||
|
||||
template <typename U, int elem_per_thread, int bits>
|
||||
METAL_FUNC void load_keys(const device uint32_t* keys, thread U* k) {
|
||||
if (bits == 4) {
|
||||
auto ks = (const device uint16_t*)keys;
|
||||
for (int i = 0; i < elem_per_thread / 4; i++) {
|
||||
k[4 * i] = ks[i] & 0x000f;
|
||||
k[4 * i + 1] = ks[i] & 0x00f0;
|
||||
k[4 * i + 2] = ks[i] & 0x0f00;
|
||||
k[4 * i + 3] = ks[i] & 0xf000;
|
||||
}
|
||||
} else if (bits == 8) {
|
||||
auto ks = (const device uint8_t*)keys;
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
k[i] = ks[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int elem_per_thread, int bits>
|
||||
METAL_FUNC void load_values(
|
||||
const device uint32_t* values,
|
||||
thread U* v,
|
||||
U value_scale,
|
||||
U value_bias) {
|
||||
auto vs = (const device uint8_t*)values;
|
||||
if (bits == 4) {
|
||||
U s[2] = {value_scale, value_scale / 16.0f};
|
||||
for (int i = 0; i < elem_per_thread / 2; i++) {
|
||||
v[2 * i] = s[0] * (vs[i] & 0x0f) + value_bias;
|
||||
v[2 * i + 1] = s[1] * (vs[i] & 0xf0) + value_bias;
|
||||
}
|
||||
} else if (bits == 8) {
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
v[i] = value_scale * vs[i] + value_bias;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int D>
|
||||
[[kernel]] void sdpa_vector_2pass_1(
|
||||
const device T* queries [[buffer(0)]],
|
||||
@@ -351,158 +290,3 @@ template <typename T, int D>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int D, int group_size, int bits>
|
||||
[[kernel]] void quant_sdpa_vector_2pass_1(
|
||||
const device T* queries [[buffer(0)]],
|
||||
const device uint32_t* keys [[buffer(1)]],
|
||||
const device T* key_scales [[buffer(2)]],
|
||||
const device T* key_biases [[buffer(3)]],
|
||||
const device uint32_t* values [[buffer(4)]],
|
||||
const device T* value_scales [[buffer(5)]],
|
||||
const device T* value_biases [[buffer(6)]],
|
||||
device float* out [[buffer(7)]],
|
||||
device float* sums [[buffer(8)]],
|
||||
device float* maxs [[buffer(9)]],
|
||||
const constant int& gqa_factor,
|
||||
const constant int& N,
|
||||
const constant size_t& k_stride,
|
||||
const constant size_t& v_stride,
|
||||
const constant size_t& k_group_stride,
|
||||
const constant size_t& v_group_stride,
|
||||
const constant float& scale,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]],
|
||||
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
||||
uint quad_lid [[thread_index_in_quadgroup]]) {
|
||||
constexpr int BN = 8;
|
||||
constexpr int BD = 4;
|
||||
constexpr int elem_per_thread = D / BD;
|
||||
const int stride = BN * D;
|
||||
constexpr int blocks = 32;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
|
||||
typedef float U;
|
||||
|
||||
thread U q[elem_per_thread];
|
||||
thread U k[elem_per_thread];
|
||||
thread U v[elem_per_thread];
|
||||
thread U o[elem_per_thread];
|
||||
|
||||
threadgroup U outputs[BN * BD];
|
||||
threadgroup U max_scores[BN];
|
||||
threadgroup U sum_exp_scores[BN];
|
||||
|
||||
// Adjust positions
|
||||
const int block_idx = tid.z;
|
||||
const int head_idx = tid.y;
|
||||
const int kv_head_idx = head_idx / gqa_factor;
|
||||
queries += head_idx * D + quad_lid * elem_per_thread;
|
||||
|
||||
const int kv_idx =
|
||||
(block_idx * BN + quad_gid) * D + quad_lid * elem_per_thread;
|
||||
const int packed_idx = kv_idx / pack_factor;
|
||||
const int k_group_idx = kv_head_idx * k_group_stride + kv_idx / group_size;
|
||||
const int v_group_idx = kv_head_idx * v_group_stride + kv_idx / group_size;
|
||||
|
||||
keys += kv_head_idx * k_stride + packed_idx;
|
||||
key_scales += k_group_idx;
|
||||
key_biases += k_group_idx;
|
||||
values += kv_head_idx * v_stride + packed_idx;
|
||||
value_scales += v_group_idx;
|
||||
value_biases += v_group_idx;
|
||||
|
||||
out += head_idx * blocks * D + block_idx * D + quad_lid * elem_per_thread;
|
||||
sums += head_idx * blocks + block_idx;
|
||||
maxs += head_idx * blocks + block_idx;
|
||||
|
||||
// Read the query and 0 the output accumulator
|
||||
U query_sum = load_queries<T, U, elem_per_thread, bits>(
|
||||
queries, q, static_cast<U>(scale));
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
o[i] = 0;
|
||||
}
|
||||
|
||||
U max_score = -1e9;
|
||||
U sum_exp_score = 0;
|
||||
|
||||
// For each key
|
||||
for (int i = block_idx * BN + quad_gid; i < N; i += blocks * BN) {
|
||||
// Read the key
|
||||
load_keys<U, elem_per_thread, bits>(keys, k);
|
||||
|
||||
// Assume D % group_size == 0 so all the keys are in the same group
|
||||
U key_scale = key_scales[0];
|
||||
U key_bias = key_biases[0];
|
||||
|
||||
// Compute the i-th score
|
||||
U score = 0;
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
score += q[i] * k[i];
|
||||
}
|
||||
score = score * key_scale + query_sum * key_bias;
|
||||
score = quad_sum(score);
|
||||
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
U factor = fast::exp(max_score - new_max);
|
||||
U exp_score = fast::exp(score - new_max);
|
||||
|
||||
max_score = new_max;
|
||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||
|
||||
U value_scale = value_scales[0];
|
||||
U value_bias = value_biases[0];
|
||||
load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias);
|
||||
|
||||
// Update the output accumulator
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
o[i] = o[i] * factor + exp_score * v[i];
|
||||
}
|
||||
|
||||
// Move the pointers to the next kv
|
||||
keys += blocks * stride / pack_factor;
|
||||
key_scales += blocks * stride / group_size;
|
||||
key_biases += blocks * stride / group_size;
|
||||
values += blocks * stride / pack_factor;
|
||||
value_scales += blocks * stride / group_size;
|
||||
value_biases += blocks * stride / group_size;
|
||||
}
|
||||
|
||||
// Each thread has a partial part of the output so we need to combine them.
|
||||
|
||||
// First let's communicate the max and sum_exp
|
||||
if (quad_lid == 0) {
|
||||
max_scores[quad_gid] = max_score;
|
||||
sum_exp_scores[quad_gid] = sum_exp_score;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
|
||||
U new_max = simd_max(max_score);
|
||||
U factor = fast::exp(max_score - new_max);
|
||||
sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
|
||||
sum_exp_score = simd_sum(sum_exp_score * factor);
|
||||
|
||||
// Write the sum and new max
|
||||
if (simd_gid == 0) {
|
||||
sums[0] = sum_exp_score;
|
||||
maxs[0] = new_max;
|
||||
}
|
||||
|
||||
// Now we need to aggregate all the outputs
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
outputs[quad_lid * BN + quad_gid] =
|
||||
o[i] * fast::exp(max_scores[quad_gid] - new_max);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (quad_gid == 0) {
|
||||
U output = outputs[quad_lid * BN];
|
||||
for (int j = 1; j < BN; j++) {
|
||||
output += outputs[quad_lid * BN + j];
|
||||
}
|
||||
out[i] = static_cast<T>(output);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
}
|
||||
|
@@ -343,8 +343,8 @@ template <
|
||||
const constant int& out_stride_sorted_axis [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const constant int* nc_shape [[buffer(6)]],
|
||||
const constant size_t* in_nc_strides [[buffer(7)]],
|
||||
const constant size_t* out_nc_strides [[buffer(8)]],
|
||||
const constant int64_t* in_nc_strides [[buffer(7)]],
|
||||
const constant int64_t* out_nc_strides [[buffer(8)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
using sort_kernel =
|
||||
@@ -486,7 +486,7 @@ template <
|
||||
const constant int& stride_sorted_axis [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const constant int* nc_shape [[buffer(6)]],
|
||||
const constant size_t* nc_strides [[buffer(7)]],
|
||||
const constant int64_t* nc_strides [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
using sort_kernel = KernelMultiBlockMergeSort<
|
||||
|
@@ -26,10 +26,10 @@ struct AttnParams {
  int NQ_aligned; ///< Number of full query blocks
  int NK_aligned; ///< Number of full key/value blocks

  size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
  size_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
  size_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
  size_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
  int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
  int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
  int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
  int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
};

} // namespace steel

@@ -14,9 +14,9 @@ struct MLXConvParams {
  const int pad[NDIM]; // Input padding
  const int kdil[NDIM]; // Kernel dilation
  const int idil[NDIM]; // Input dilation
  const size_t in_strides[NDIM + 2]; // In strides
  const size_t wt_strides[NDIM + 2]; // Wt strides
  const size_t out_strides[NDIM + 2]; // Out strides
  const int64_t in_strides[NDIM + 2]; // In strides
  const int64_t wt_strides[NDIM + 2]; // Wt strides
  const int64_t out_strides[NDIM + 2]; // Out strides
  const int groups; // Input channel groups
  const bool flip;
};
@@ -59,4 +59,4 @@ struct Conv2DGeneralBaseInfo {
};

} // namespace steel
} // namespace mlx
} // namespace mlx

@@ -38,12 +38,12 @@ template <
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
const constant int64_t* batch_strides [[buffer(7)]],
|
||||
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
|
||||
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
|
||||
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
|
||||
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
|
||||
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
||||
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
||||
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
@@ -88,9 +88,8 @@ template <
|
||||
uint32_t indx_A, indx_B, indx_C;
|
||||
|
||||
if (has_batch) {
|
||||
const constant size_t* indx_A_bstrides = batch_strides;
|
||||
const constant size_t* indx_B_bstrides =
|
||||
batch_strides + params->batch_ndim;
|
||||
const constant auto* indx_A_bstrides = batch_strides;
|
||||
const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 indx_offsets = elem_to_loc_broadcast(
|
||||
tid.z,
|
||||
@@ -102,7 +101,7 @@ template <
|
||||
indx_B = rhs_indices[indx_offsets.y];
|
||||
|
||||
if (use_out_source) {
|
||||
const constant size_t* indx_C_bstrides =
|
||||
const constant auto* indx_C_bstrides =
|
||||
indx_B_bstrides + params->batch_ndim;
|
||||
auto indx_offset_C = elem_to_loc(
|
||||
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
|
||||
@@ -120,18 +119,18 @@ template <
|
||||
// Translate indices to offsets
|
||||
int batch_ndim_A = operand_batch_ndim.x;
|
||||
const constant int* batch_shape_A = operand_shape;
|
||||
const constant size_t* batch_strides_A = operand_strides;
|
||||
const constant auto* batch_strides_A = operand_strides;
|
||||
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
|
||||
|
||||
int batch_ndim_B = operand_batch_ndim.y;
|
||||
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
|
||||
const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
|
||||
const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A;
|
||||
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
|
||||
|
||||
if (use_out_source) {
|
||||
int batch_ndim_C = operand_batch_ndim.z;
|
||||
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
|
||||
const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
|
||||
const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B;
|
||||
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
|
||||
}
|
||||
|
||||
@@ -140,8 +139,8 @@ template <
|
||||
// Handle regular batch
|
||||
else {
|
||||
if (has_batch) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
const constant auto* A_bstrides = batch_strides;
|
||||
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
@@ -150,7 +149,7 @@ template <
|
||||
B += batch_offsets.y;
|
||||
|
||||
if (use_out_source) {
|
||||
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
|
||||
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
|
||||
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
|
||||
}
|
||||
} else {
|
||||
|
@@ -7,26 +7,10 @@
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h"
|
||||
|
||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
template [[host_name("steel_gemm_fused_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
|
||||
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float>( \
|
||||
const device itype *A [[buffer(0)]], \
|
||||
const device itype *B [[buffer(1)]], \
|
||||
const device itype *C [[buffer(2), function_constant(use_out_source)]], \
|
||||
device itype *D [[buffer(3)]], \
|
||||
const constant GEMMParams* params [[buffer(4)]], \
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], \
|
||||
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], \
|
||||
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], \
|
||||
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], \
|
||||
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], \
|
||||
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
instantiate_kernel( \
|
||||
"steel_gemm_fused_" #tname "_" #iname "_" #oname \
|
||||
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \
|
||||
gemm, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float)
|
||||
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
|
@@ -56,7 +56,7 @@ block_masked_gemm(
|
||||
device T* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
const constant int64_t* batch_strides [[buffer(7)]],
|
||||
const device out_mask_t* out_mask [[buffer(10)]],
|
||||
const device op_mask_t* lhs_mask [[buffer(11)]],
|
||||
const device op_mask_t* rhs_mask [[buffer(12)]],
|
||||
@@ -104,7 +104,7 @@ block_masked_gemm(
|
||||
return;
|
||||
}
|
||||
|
||||
const constant size_t* mask_batch_strides =
|
||||
const constant auto* mask_batch_strides =
|
||||
batch_strides + 2 * params->batch_ndim;
|
||||
|
||||
if (params->batch_ndim > 1) {
|
||||
@@ -116,8 +116,8 @@ block_masked_gemm(
|
||||
}
|
||||
|
||||
if (has_operand_mask) {
|
||||
const constant size_t* mask_strides_lhs = mask_batch_strides;
|
||||
const constant size_t* mask_strides_rhs =
|
||||
const constant auto* mask_strides_lhs = mask_batch_strides;
|
||||
const constant auto* mask_strides_rhs =
|
||||
mask_strides_lhs + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
@@ -144,8 +144,8 @@ block_masked_gemm(
|
||||
|
||||
// Adjust for batch
|
||||
if (params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
const constant auto* A_bstrides = batch_strides;
|
||||
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
@@ -442,7 +442,7 @@ block_masked_gemm(
|
||||
device T* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
const constant int64_t* batch_strides [[buffer(7)]],
|
||||
const device bool* out_mask [[buffer(10)]],
|
||||
const device bool* lhs_mask [[buffer(11)]],
|
||||
const device bool* rhs_mask [[buffer(12)]],
|
||||
@@ -476,15 +476,15 @@ block_masked_gemm(
|
||||
}
|
||||
|
||||
if (params->batch_ndim > 1) {
|
||||
const constant size_t* mask_batch_strides =
|
||||
const constant auto* mask_batch_strides =
|
||||
batch_strides + 2 * params->batch_ndim;
|
||||
out_mask +=
|
||||
elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
|
||||
|
||||
if (has_operand_mask) {
|
||||
const constant size_t* mask_strides_lhs =
|
||||
const constant auto* mask_strides_lhs =
|
||||
mask_batch_strides + params->batch_ndim;
|
||||
const constant size_t* mask_strides_rhs =
|
||||
const constant auto* mask_strides_rhs =
|
||||
mask_strides_lhs + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
@@ -507,8 +507,8 @@ block_masked_gemm(
|
||||
|
||||
// Adjust for batch
|
||||
if (params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
const constant auto* A_bstrides = batch_strides;
|
||||
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
|
@@ -5,58 +5,45 @@
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"

#define instantiate_gemm( \
outmaskname, \
outmasktype, \
opmaskname, \
opmasktype, \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
template [[host_name("steel_gemm_block_outmask_" #outmaskname \
"_opmask_" #opmaskname "_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
block_masked_gemm< \
itype, \
outmasktype, \
opmasktype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const device outmasktype* out_mask [[buffer(10)]], \
const device opmasktype* lhs_mask [[buffer(11)]], \
const device opmasktype* rhs_mask [[buffer(12)]], \
const constant int* mask_strides [[buffer(13)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm( \
outmaskname, \
outmasktype, \
opmaskname, \
opmasktype, \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
instantiate_kernel( \
"steel_gemm_block_outmask_" #outmaskname \
"_opmask_" #opmaskname "_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname, \
block_masked_gemm, \
itype, \
outmasktype, \
opmasktype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned)

#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(bool_, bool, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
@@ -5,46 +5,39 @@
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h"

#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
gemm_splitk< \
itype, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device otype* C [[buffer(2)]], \
const constant GEMMSpiltKParams* params [[buffer(3)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
instantiate_kernel( \
"steel_gemm_splitk_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname, \
gemm_splitk, \
itype, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned)

#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
@@ -68,30 +61,13 @@ instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float);

#define instantiate_accum(oname, otype, aname, atype) \
template [[host_name("steel_gemm_splitk_accum_" #oname \
"_" #aname)]] [[kernel]] void \
gemm_splitk_accum<atype, otype>( \
const device atype* C_split [[buffer(0)]], \
device otype* D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
uint2 gid [[thread_position_in_grid]]); \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \
"_axbpy")]] [[kernel]] void \
gemm_splitk_accum_axpby<atype, otype>( \
const device atype* C_split [[buffer(0)]], \
device otype* D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
const device otype* C [[buffer(5)]], \
const constant int& ldc [[buffer(6)]], \
const constant int& fdc [[buffer(7)]], \
const constant float& alpha [[buffer(8)]], \
const constant float& beta [[buffer(9)]], \
uint2 gid [[thread_position_in_grid]]);
#define instantiate_accum(oname, otype, aname, atype) \
instantiate_kernel( \
"steel_gemm_splitk_accum_" #oname "_" #aname, \
gemm_splitk_accum, atype, otype) \
instantiate_kernel( \
"steel_gemm_splitk_accum_" #oname "_" #aname "_axbpy", \
gemm_splitk_accum_axpby, atype, otype) \

instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
@@ -21,9 +21,9 @@ struct GEMMParams {
const int tiles_n;
const int tiles_m;

const size_t batch_stride_a;
const size_t batch_stride_b;
const size_t batch_stride_d;
const int64_t batch_stride_a;
const int64_t batch_stride_b;
const int64_t batch_stride_d;

const int swizzle_log;
const int gemm_k_iterations_aligned;
@@ -54,7 +54,7 @@ struct GEMMAddMMParams {
const int ldc;
const int fdc;

const size_t batch_stride_c;
const int64_t batch_stride_c;

const float alpha;
const float beta;
@@ -7,8 +7,8 @@
METAL_FUNC ulong2 elem_to_loc_broadcast(
uint elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
int ndim) {
ulong loc_a{0};
ulong loc_b{0};
@@ -24,9 +24,9 @@ METAL_FUNC ulong2 elem_to_loc_broadcast(
METAL_FUNC ulong3 elem_to_loc_broadcast(
uint elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
constant const int64_t* c_strides,
int ndim) {
ulong loc_a{0};
ulong loc_b{0};
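Editor's note: as a reference for the index math above, here is a hedged host-side C++ sketch of what elem_to_loc_broadcast computes (the loop body mirrors the elem_to_loc loop shown later in this diff; the function name and host-side types are assumptions, not code from the repository).

#include <cstdint>
#include <utility>

// Walk the shape from the innermost dimension outward and accumulate each
// array's own strides, so two broadcast-compatible arrays get their element
// offsets from one linear element id.
std::pair<int64_t, int64_t> elem_to_loc_broadcast_ref(
    uint32_t elem,
    const int* shape,
    const int64_t* a_strides,
    const int64_t* b_strides,
    int ndim) {
  int64_t loc_a = 0;
  int64_t loc_b = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    int pos = elem % shape[i];
    loc_a += pos * a_strides[i];
    loc_b += pos * b_strides[i];
    elem /= shape[i];
  }
  return {loc_a, loc_b};
}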
@@ -18,72 +18,72 @@ template <typename T, typename Op>
device T* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
d[offset] = Op()(a[offset], b[offset], c[offset]);
}

template <typename T, typename Op, typename IdxT = size_t>
template <typename T, typename Op, typename IdxT = int64_t>
[[kernel]] void ternary_g_nd1(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t& a_strides,
constant const size_t& b_strides,
constant const size_t& c_strides,
constant const int64_t& a_strides,
constant const int64_t& b_strides,
constant const int64_t& c_strides,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_1<size_t, IdxT>(index, c_strides);
auto a_idx = elem_to_loc_1<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_1<IdxT>(index, b_strides);
auto c_idx = elem_to_loc_1<IdxT>(index, c_strides);
d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

template <typename T, typename Op, typename IdxT = size_t>
template <typename T, typename Op, typename IdxT = int64_t>
[[kernel]] void ternary_g_nd2(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
constant const size_t c_strides[2],
constant const int64_t a_strides[2],
constant const int64_t b_strides[2],
constant const int64_t c_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_2<size_t, IdxT>(index, c_strides);
auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);
auto c_idx = elem_to_loc_2<IdxT>(index, c_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

template <typename T, typename Op, typename IdxT = size_t>
template <typename T, typename Op, typename IdxT = int64_t>
[[kernel]] void ternary_g_nd3(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
constant const size_t c_strides[3],
constant const int64_t a_strides[3],
constant const int64_t b_strides[3],
constant const int64_t c_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_3<size_t, IdxT>(index, c_strides);
auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);
auto c_idx = elem_to_loc_3<IdxT>(index, c_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

template <typename T, typename Op, int N = 1, typename IdxT = size_t>
template <typename T, typename Op, int N = 1, typename IdxT = int64_t>
[[kernel]] void ternary_g(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
constant const int64_t* c_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
@@ -8,17 +8,17 @@
#include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h"

#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, uint) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \
instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, int) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, int) \
instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \

#define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \
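Editor's note: the instantiations above keep two families of general kernels, a 32-bit-indexed one ("g1_", "g2_", ...) and a 64-bit-indexed one ("g1large_", ...). A hedged sketch of how a host might choose between them is below; the helper name and threshold are illustrative assumptions, not code from this diff.

#include <cstdint>
#include <string>

// Illustrative only: fall back to the "large" variants once the flat index
// space no longer fits in a signed 32-bit int, matching the int / int64_t
// IdxT defaults of the kernels instantiated above.
std::string ternary_kernel_name(const std::string& base, int64_t total_elems) {
  constexpr int64_t k32_bit_limit = INT32_MAX;
  return total_elems > k32_bit_limit ? base + "large" : base;
}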
@@ -14,7 +14,7 @@ template <typename T, typename U, typename Op>
device U* out,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
out[offset] = Op()(in[offset]);
}

@@ -23,16 +23,16 @@ template <
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
typename IdxT = int64_t>
[[kernel]] void unary_g(
device const T* in,
device U* out,
constant const int* in_shape,
constant const size_t* in_strides,
constant const int64_t* in_strides,
device const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc<size_t, IdxT>(
auto idx = elem_to_loc<IdxT>(
{N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
auto xshape = in_shape[ndim - 1];
IdxT xstride = in_strides[ndim - 1];
@@ -9,19 +9,19 @@
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel( \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, uint) \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
instantiate_kernel( \
"gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)

#define instantiate_unary_all_same(op, tname, type) \
instantiate_unary_all(op, tname, tname, type, type)

#define instantiate_unary_float(op) \
#define instantiate_unary_float(op) \
instantiate_unary_all_same(op, float16, half) \
instantiate_unary_all_same(op, float32, float) \
instantiate_unary_all_same(op, bfloat16, bfloat16_t)

#define instantiate_unary_types(op) \
#define instantiate_unary_types(op) \
instantiate_unary_all_same(op, bool_, bool) \
instantiate_unary_all_same(op, uint8, uint8_t) \
instantiate_unary_all_same(op, uint16, uint16_t) \
@@ -89,25 +89,11 @@ struct Limits<complex64_t> {
///////////////////////////////////////////////////////////////////////////////
// Single Array with generic dims

template <typename StrideT, typename IdxT = StrideT>
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
uint elem,
IdxT elem,
constant const int* shape,
constant const StrideT* strides,
int ndim) {
IdxT loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}

template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc(
StrideT elem,
constant const int* shape,
constant const StrideT* strides,
constant const int64_t* strides,
int ndim) {
IdxT loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
@@ -118,11 +104,11 @@ METAL_FUNC IdxT elem_to_loc(
}

// Non templated version to handle arbitrary dims
template <typename StrideT, typename IdxT = StrideT>
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
uint3 elem,
constant const int* shape,
constant const StrideT* strides,
constant const int64_t* strides,
int ndim) {
IdxT loc =
elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
@@ -136,18 +122,18 @@ METAL_FUNC IdxT elem_to_loc(
///////////////////////////////////////////////////////////////////////////////
// Single Array with fixed N dims

template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const StrideT& stride) {
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t& stride) {
return elem * IdxT(stride);
}

template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const StrideT strides[2]) {
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2]) {
return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
}

template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) {
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) {
return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
elem.z * IdxT(strides[0]);
}
@@ -155,12 +141,12 @@ METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) {
///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with generic dims

template <typename StrideT, typename IdxT = StrideT>
template <typename IdxT = int64_t>
METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,
constant const StrideT* a_strides,
constant const StrideT* b_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
int ndim) {
vec<IdxT, 2> loc = {
IdxT(
@@ -178,18 +164,21 @@ METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
return loc;
}

template <typename IdxT = size_t>
template <typename IdxT = int64_t>
METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
constant const int64_t* c_strides,
int ndim) {
vec<IdxT, 3> loc = {
elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]),
elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])};
IdxT(elem.x * IdxT(a_strides[ndim - 1])) +
IdxT(elem.y * IdxT(a_strides[ndim - 2])),
IdxT(elem.x * IdxT(b_strides[ndim - 1])) +
IdxT(elem.y * IdxT(b_strides[ndim - 2])),
IdxT(elem.x * IdxT(c_strides[ndim - 1])) +
IdxT(elem.y * IdxT(c_strides[ndim - 2]))};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * IdxT(a_strides[d]);
@@ -213,7 +202,7 @@ struct LoopedElemToLoc {

LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}

void next(const constant int* shape, const constant size_t* strides) {
void next(const constant int* shape, const constant int64_t* strides) {
if (dim == 0) {
return;
}
@@ -226,7 +215,7 @@ struct LoopedElemToLoc {
}
}

void next(int n, const constant int* shape, const constant size_t* strides) {
void next(int n, const constant int* shape, const constant int64_t* strides) {
if (dim == 0) {
return;
}
@@ -262,19 +251,19 @@ struct LoopedElemToLoc<1, OffsetT, true> {

LoopedElemToLoc(int dim) : dim(dim) {}

void next(const constant int* shape, const constant size_t* strides) {
void next(const constant int* shape, const constant int64_t* strides) {
index++;
if (dim > 1) {
offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
} else {
offset += OffsetT(strides[0]);
}
}

void next(int n, const constant int* shape, const constant size_t* strides) {
void next(int n, const constant int* shape, const constant int64_t* strides) {
index += n;
if (dim > 1) {
offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
} else {
offset = index * OffsetT(strides[0]);
}
@@ -291,11 +280,11 @@ struct LoopedElemToLoc<1, OffsetT, false> {

LoopedElemToLoc(int) {}

void next(const constant int*, const constant size_t* strides) {
void next(const constant int*, const constant int64_t* strides) {
offset += OffsetT(strides[0]);
}

void next(int n, const constant int*, const constant size_t* strides) {
void next(int n, const constant int*, const constant int64_t* strides) {
offset += n * OffsetT(strides[0]);
}

@@ -421,3 +410,14 @@ inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
return complex64_t(
simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));
}

// std::conditional is not included with Metal
template <bool condition, typename T, typename U>
struct ConditionalType {
using type = U;
};

template <typename T, typename U>
struct ConditionalType<true, T, U> {
using type = T;
};
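Editor's note: a small hedged sketch of how the ConditionalType helper added above can be used to pick an index type at compile time. The helper definition is copied from the diff for self-containment; the surrounding names ("Example", "IdxT") are illustrative assumptions.

// Copied from the diff above so this sketch compiles on its own.
template <bool condition, typename T, typename U>
struct ConditionalType {
  using type = U;
};

template <typename T, typename U>
struct ConditionalType<true, T, U> {
  using type = T;
};

// Pick 64-bit indexing only when asked for; otherwise use 32-bit int.
template <bool use_64bit_indexing>
struct Example {
  using IdxT =
      typename ConditionalType<use_64bit_indexing, long long, int>::type;
};

static_assert(sizeof(Example<true>::IdxT) == 8, "64-bit path");
static_assert(sizeof(Example<false>::IdxT) == 4, "32-bit path");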
@@ -21,8 +21,8 @@ namespace {

inline auto collapse_batches(const array& a, const array& b) {
// Get and check the shape for the batched dims
std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
if (A_bshape != B_bshape) {
std::ostringstream msg;
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
@@ -30,8 +30,8 @@ inline auto collapse_batches(const array& a, const array& b) {
throw std::runtime_error(msg.str());
}

std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};

auto [batch_shape, batch_strides] =
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
@@ -50,9 +50,9 @@ inline auto collapse_batches(const array& a, const array& b) {

inline auto collapse_batches(const array& a, const array& b, const array& c) {
// Get and check the shape for the batched dims
std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
std::vector<int> C_bshape{c.shape().begin(), c.shape().end() - 2};
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
if (A_bshape != B_bshape || A_bshape != C_bshape) {
std::ostringstream msg;
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
@@ -60,9 +60,9 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
throw std::runtime_error(msg.str());
}

std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
std::vector<size_t> C_bstride{c.strides().begin(), c.strides().end() - 2};
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};

auto [batch_shape, batch_strides] = collapse_contiguous_dims(
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
@@ -82,6 +82,25 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
}

std::tuple<bool, int64_t, array> check_transpose(
std::vector<array>& copies,
const Stream& s,
const array& arr,
bool is_vector) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy);
}
};

} // namespace

///////////////////////////////////////////////////////////////////////////////
@@ -180,11 +199,11 @@ void steel_matmul_regular(
int ldd,
bool transpose_a,
bool transpose_b,
std::vector<int> batch_shape,
std::vector<size_t> batch_strides,
size_t A_batch_stride,
size_t B_batch_stride,
size_t matrix_stride_out,
Shape batch_shape,
Strides batch_strides,
int64_t A_batch_stride,
int64_t B_batch_stride,
int64_t matrix_stride_out,
std::vector<array>& copies) {
using namespace mlx::steel;

@@ -268,9 +287,9 @@ void steel_matmul_regular(
/* const int ldd = */ ldd,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const size_t batch_stride_a = */ A_batch_stride,
/* const size_t batch_stride_b = */ B_batch_stride,
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int64_t batch_stride_a = */ A_batch_stride,
/* const int64_t batch_stride_b = */ B_batch_stride,
/* const int64_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -314,9 +333,9 @@ void steel_matmul(
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
std::vector<int> batch_shape /* = {} */,
std::vector<size_t> A_batch_stride /* = {} */,
std::vector<size_t> B_batch_stride /* = {} */) {
Shape batch_shape /* = {} */,
Strides A_batch_stride /* = {} */,
Strides B_batch_stride /* = {} */) {
using namespace mlx::steel;

if (batch_shape.empty()) {
@@ -447,7 +466,7 @@ void steel_matmul(

/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
std::vector<size_t> batch_strides = A_batch_stride;
auto batch_strides = A_batch_stride;
batch_strides.insert(
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());

@@ -505,24 +524,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};

auto [a_transposed, a_cols, a] = check_transpose(a_pre, M == 1);
auto [b_transposed, b_cols, b] = check_transpose(b_pre, N == 1);
auto [a_transposed, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
auto [b_transposed, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);

/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
@@ -662,9 +665,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
/* bool transpose_a = */ a_transposed,
/* bool transpose_b = */ b_transposed,
/* std::vector<array>& = */ copies,
/* std::vector<int> batch_shape = */ batch_shape,
/* std::vector<size_t> A_batch_stride = */ A_batch_stride,
/* std::vector<size_t> B_batch_stride = */ B_batch_stride);
/* Shape batch_shape = */ batch_shape,
/* Strides A_batch_stride = */ A_batch_stride,
/* Strides B_batch_stride = */ B_batch_stride);
}

void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -691,24 +694,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};

auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);

array c = c_pre;
int ldc = c.strides()[c.ndim() - 2];
@@ -723,7 +710,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =
collapse_batches(a, b, c);

size_t matrix_stride_out = size_t(M) * size_t(N);
int64_t matrix_stride_out = M * static_cast<int64_t>(N);
auto batch_size_out = out.size() / (matrix_stride_out);

// Collapse batches into M if needed
@@ -1044,9 +1031,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const size_t batch_stride_a = */ A_batch_stride.back(),
/* const size_t batch_stride_b = */ B_batch_stride.back(),
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int64_t batch_stride_a = */ A_batch_stride.back(),
/* const int64_t batch_stride_b = */ B_batch_stride.back(),
/* const int64_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -1054,7 +1041,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
GEMMAddMMParams params{
/* const int ldc = */ ldc,
/* const int fdc = */ fdc,
/* const size_t batch_stride_c = */ C_batch_stride.back(),
/* const int64_t batch_stride_c = */ C_batch_stride.back(),
/* const float alpha = */ alpha_,
/* const float beta = */ beta_};

@@ -1065,7 +1052,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);

std::vector<size_t> batch_strides = A_batch_stride;
Strides batch_strides = A_batch_stride;
batch_strides.insert(
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
batch_strides.insert(
@@ -1120,24 +1107,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};

auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);

int lda = a_cols;
int ldb = b_cols;
@@ -1156,20 +1127,20 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
return decltype(v){v.begin(), v.end() - 2};
};

std::vector<int> batch_shape{1};
std::vector<size_t> A_batch_stride{0};
std::vector<size_t> B_batch_stride{0};
std::vector<size_t> outmask_bstride{0};
std::vector<size_t> Amask_bstride{0};
std::vector<size_t> Bmask_bstride{0};
size_t A_batch_str = 0;
size_t B_batch_str = 0;
Shape batch_shape{1};
Strides A_batch_stride{0};
Strides B_batch_stride{0};
Strides outmask_bstride{0};
Strides Amask_bstride{0};
Strides Bmask_bstride{0};
int64_t A_batch_str = 0;
int64_t B_batch_str = 0;

std::vector<size_t> batch_strides;
Strides batch_strides;

if (out.ndim() > 2) {
std::vector<int> bshape{out.shape().begin(), out.shape().end() - 2};
std::vector<std::vector<size_t>> bstrides;
Shape bshape{out.shape().begin(), out.shape().end() - 2};
std::vector<Strides> bstrides;

for (auto& arr : inputs) {
bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2);
@@ -1196,10 +1167,10 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
}

} else {
batch_strides = std::vector<size_t>(inputs.size(), 0);
batch_strides = Strides(inputs.size(), 0);
}

size_t matrix_stride_out = size_t(M) * N;
int64_t matrix_stride_out = static_cast<int64_t>(M) * N;
size_t batch_size_out = out.size() / (matrix_stride_out);

/////////////////////////////////////////////////////////////////////////////
@@ -1306,7 +1277,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {

// Get mask params
std::vector<int> mask_strides;
std::vector<size_t> mask_batch_strides;
Strides mask_batch_strides;
if (has_out_mask) {
auto& out_mask = inputs[2];

@@ -1436,9 +1407,9 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const size_t batch_stride_a = */ A_batch_str,
/* const size_t batch_stride_b = */ B_batch_str,
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int64_t batch_stride_a = */ A_batch_str,
/* const int64_t batch_stride_b = */ B_batch_str,
/* const int64_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -1524,24 +1495,8 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};

auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);

int lda = a_cols;
int ldb = b_cols;
@@ -1556,20 +1511,20 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];

std::vector<int> batch_shape = get_batch_dims(out.shape());
std::vector<size_t> batch_strides;
Shape batch_shape = get_batch_dims(out.shape());
Strides batch_strides;

batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
size_t lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();

batch_strides.insert(
batch_strides.end(),
rhs_indices.strides().begin(),
rhs_indices.strides().end());
size_t rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();

int batch_ndim = batch_shape.size();

@@ -1582,10 +1537,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int batch_ndim_B = b.ndim() - 2;
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B};

std::vector<int> batch_shape_A = get_batch_dims(a.shape());
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
Shape batch_shape_A = get_batch_dims(a.shape());
Strides batch_strides_A = get_batch_dims(a.strides());
Shape batch_shape_B = get_batch_dims(b.shape());
Strides batch_strides_B = get_batch_dims(b.strides());

if (batch_ndim_A == 0) {
batch_shape_A = {1};
@@ -1597,7 +1552,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_strides_B = {0};
}

size_t matrix_stride_out = size_t(M) * N;
auto matrix_stride_out = static_cast<int64_t>(M) * N;
auto batch_size_out = out.size() / matrix_stride_out;

/////////////////////////////////////////////////////////////////////////////
@@ -1801,9 +1756,9 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const size_t batch_stride_a = */ lhs_indices_str,
/* const size_t batch_stride_b = */ rhs_indices_str,
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int64_t batch_stride_a = */ lhs_indices_str,
/* const int64_t batch_stride_b = */ rhs_indices_str,
/* const int64_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ batch_ndim};
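Editor's note: a hedged sketch, on plain integers, of the layout decision the shared check_transpose helper above encodes (ignoring its is_vector special case). The enum and function names here are illustrative, not part of the repository.

#include <cstdint>

enum class MatrixLayout { RowContiguous, ColumnContiguous, NeedsCopy };

// A matrix is used as-is when its last stride is 1 (row-contiguous, ld is the
// second-to-last stride), dispatched as transposed when its second-to-last
// stride is 1 (column-contiguous, ld is the last stride), and copied to a
// contiguous buffer otherwise.
MatrixLayout classify(int64_t stride_row, int64_t stride_col) {
  if (stride_col == 1) {
    return MatrixLayout::RowContiguous;
  } else if (stride_row == 1) {
    return MatrixLayout::ColumnContiguous;
  }
  return MatrixLayout::NeedsCopy;
}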
@@ -21,11 +21,11 @@ void steel_matmul_regular(
int ldd,
bool transpose_a,
bool transpose_b,
std::vector<int> batch_shape,
std::vector<size_t> batch_strides,
size_t A_batch_stride,
size_t B_batch_stride,
size_t matrix_stride_out,
Shape batch_shape,
Strides batch_strides,
int64_t A_batch_stride,
int64_t B_batch_stride,
int64_t matrix_stride_out,
std::vector<array>& copies);

void steel_matmul(
@@ -43,8 +43,8 @@ void steel_matmul(
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
std::vector<int> batch_shape = {},
std::vector<size_t> A_batch_stride = {},
std::vector<size_t> B_batch_stride = {});
Shape batch_shape = {},
Strides A_batch_stride = {},
Strides B_batch_stride = {});

} // namespace mlx::core
@@ -5,6 +5,7 @@
#include <sstream>

#include "mlx/backend/common/load.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
@@ -24,6 +25,25 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
enc.set_bytes(step, 1);
}

void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}

void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -101,10 +121,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
}

// Prepare the shapes, strides and axis arguments.
std::vector<size_t> in_strides = in.strides();
std::vector<int> shape = in.shape();
std::vector<size_t> out_strides = out.strides();
size_t axis_stride = in_strides[axis_];
auto in_strides = in.strides();
auto shape = in.shape();
auto out_strides = out.strides();
auto axis_stride = in_strides[axis_];
size_t axis_size = shape[axis_];
if (out_strides.size() == in_strides.size()) {
out_strides.erase(out_strides.begin() + axis_);
@@ -136,7 +156,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
if (ndim == 0) {
// Pass place holders so metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
int64_t stride_ = 0;
compute_encoder.set_bytes(shape_, 2);
compute_encoder.set_bytes(stride_, 3);
compute_encoder.set_bytes(stride_, 4);
@@ -210,6 +230,18 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
copy_gpu(in, out, ctype);
}

void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}

void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}

void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}

void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto read_task = [out = out,
@@ -304,27 +336,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
}

void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];

auto [copy_necessary, out_strides] = prepare_reshape(in, out);

if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto out_strides = make_contiguous_strides<size_t>(in.shape());
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
out_strides,
0,
0,
CopyType::General,
stream());
} else {
shared_buffer_reshape(in, out_strides, out);
}
reshape(inputs[0], out, stream());
}

void Split::eval_gpu(
@@ -366,22 +378,25 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());

// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(out);
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);

// Do copy
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
copy_gpu_inplace<int64_t>(
copy_gpu_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd_strides,
/* const std::vector<stride_t>& o_strides = */ out_strides,
/* const Shape& data_shape = */ upd.shape(),
/* const Strides& i_strides = */ upd.strides(),
/* const Strides& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral,
/* const Stream& s = */ stream());
}

void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}

void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
Some files were not shown because too many files have changed in this diff.