mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 07:34:42 +08:00
Compare commits
35 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
83762691ba | ||
![]() |
2a41caa00e | ||
![]() |
6593281d25 | ||
![]() |
da98e8bce8 | ||
![]() |
be57a16a80 | ||
![]() |
1704809f29 | ||
![]() |
0cae0bdac8 | ||
![]() |
5a1a5d5ed1 | ||
![]() |
1683975acf | ||
![]() |
af705590ac | ||
![]() |
825124af8f | ||
![]() |
9c5e7da507 | ||
![]() |
481349495b | ||
![]() |
9daa6b003f | ||
![]() |
a3a632d567 | ||
![]() |
e496c5a4b4 | ||
![]() |
ea890d8710 | ||
![]() |
aa5d84f102 | ||
![]() |
f1606486d2 | ||
![]() |
87720a8908 | ||
![]() |
bb6565ef14 | ||
![]() |
7bb063bcb3 | ||
![]() |
b36dd472bb | ||
![]() |
167b759a38 | ||
![]() |
99b9868859 | ||
![]() |
6b2d5448f2 | ||
![]() |
eaf709b83e | ||
![]() |
f0e70afff0 | ||
![]() |
86984cad68 | ||
![]() |
fbc89e3ced | ||
![]() |
38c1e720c2 | ||
![]() |
600e87e03c | ||
![]() |
3836445241 | ||
![]() |
1d2c9d6a07 | ||
![]() |
e8ac6bd2f5 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -36,6 +36,7 @@ share/python-wheels/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
uv.lock
|
||||
|
||||
# vim
|
||||
*.swp
|
||||
|
@@ -34,6 +34,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
|
||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||
@@ -83,6 +84,10 @@ if(MLX_BUILD_METAL)
|
||||
set(QUARTZ_LIB "-framework QuartzCore")
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
enable_language(CUDA)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
|
@@ -1,4 +1,6 @@
|
||||
include CMakeLists.txt
|
||||
include mlx.pc.in
|
||||
recursive-include mlx/ *
|
||||
include cmake/*
|
||||
include python/src/*
|
||||
include python/mlx/py.typed # support type hinting as in PEP-561
|
||||
|
@@ -20,3 +20,5 @@ FFT
|
||||
irfft2
|
||||
rfftn
|
||||
irfftn
|
||||
fftshift
|
||||
ifftshift
|
||||
|
@@ -49,5 +49,16 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||
if(MLX_BUILD_METAL)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
|
||||
endif()
|
||||
|
@@ -356,7 +356,7 @@ class array {
|
||||
}
|
||||
|
||||
enum Status {
|
||||
// The ouptut of a computation which has not been scheduled.
|
||||
// The output of a computation which has not been scheduled.
|
||||
// For example, the status of `x` in `auto x = a + b`.
|
||||
unscheduled,
|
||||
|
||||
|
@@ -6,4 +6,5 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/transpose.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||
|
@@ -2,6 +2,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/broadcasting.h"
|
||||
#include "mlx/backend/common/transpose.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -19,26 +20,19 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||
"AsStrided must be used with row contiguous arrays only.");
|
||||
}
|
||||
|
||||
// Compute the flags given the shape and strides
|
||||
bool row_contiguous = true, col_contiguous = true;
|
||||
size_t r = 1, c = 1;
|
||||
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
|
||||
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
|
||||
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
|
||||
r *= shape_[i];
|
||||
c *= shape_[j];
|
||||
}
|
||||
// Calculate the contiguity based on the given shape and strides
|
||||
auto [ds, rc, cc] = check_contiguity(shape_, strides_);
|
||||
auto flags = in.flags();
|
||||
|
||||
// TODO: Compute the contiguous flag in a better way cause now we are
|
||||
// unnecessarily strict.
|
||||
flags.contiguous = row_contiguous || col_contiguous;
|
||||
flags.row_contiguous = row_contiguous;
|
||||
flags.col_contiguous = col_contiguous;
|
||||
flags.contiguous = rc || cc;
|
||||
flags.row_contiguous = rc;
|
||||
flags.col_contiguous = cc;
|
||||
|
||||
// There is no easy way to compute the actual data size so we use out.size().
|
||||
// The contiguous flag will almost certainly not be set so no code should
|
||||
// rely on data_size anyway.
|
||||
size_t data_size = out.size();
|
||||
// There is no easy way to compute the actual data size so we use out.size()
|
||||
// when the array is not contiguous.
|
||||
size_t data_size = flags.contiguous ? ds : out.size();
|
||||
|
||||
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
||||
}
|
||||
@@ -270,36 +264,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
Strides out_strides(out.ndim());
|
||||
auto& in = inputs[0];
|
||||
for (int ax = 0; ax < axes_.size(); ++ax) {
|
||||
out_strides[ax] = in.strides()[axes_[ax]];
|
||||
}
|
||||
|
||||
// Conditions for {row/col}_contiguous
|
||||
// - array must be contiguous (no gaps)
|
||||
// - underlying buffer size should have the same size as the array
|
||||
// - cumulative product of shapes is equal to the strides (we can ignore axes
|
||||
// with size == 1)
|
||||
// - in the forward direction (column contiguous)
|
||||
// - in the reverse direction (row contiguous)
|
||||
// - vectors are both row and col contiguous (hence if both row/col are
|
||||
// true, they stay true)
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous && in.data_size() == in.size()) {
|
||||
int64_t f_stride = 1;
|
||||
int64_t b_stride = 1;
|
||||
flags.col_contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
|
||||
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
|
||||
f_stride *= out.shape(i);
|
||||
flags.row_contiguous &=
|
||||
(out_strides[ri] == b_stride || out.shape(ri) == 1);
|
||||
b_stride *= out.shape(ri);
|
||||
}
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
transpose(inputs[0], out, axes_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -99,7 +99,11 @@ inline std::pair<int, int> decompose_hadamard(int n) {
|
||||
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
|
||||
}
|
||||
}
|
||||
if (n > (1 << 26)) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] Only supports n = m*2^k where k <= 26");
|
||||
}
|
||||
return {n, m};
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
57
mlx/backend/common/transpose.cpp
Normal file
57
mlx/backend/common/transpose.cpp
Normal file
@@ -0,0 +1,57 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void transpose(const array& in, array& out, const std::vector<int>& axes) {
|
||||
Strides out_strides(out.ndim());
|
||||
for (int ax = 0; ax < axes.size(); ++ax) {
|
||||
out_strides[ax] = in.strides()[axes[ax]];
|
||||
}
|
||||
|
||||
// Conditions for {row/col}_contiguous
|
||||
// - array must be contiguous (no gaps)
|
||||
// - underlying buffer size should have the same size as the array
|
||||
// - cumulative product of shapes is equal to the strides (we can ignore axes
|
||||
// with size == 1)
|
||||
// - in the forward direction (column contiguous)
|
||||
// - in the reverse direction (row contiguous)
|
||||
// - vectors are both row and col contiguous (hence if both row/col are
|
||||
// true, they stay true)
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous && in.data_size() == in.size()) {
|
||||
auto [_, rc, cc] = check_contiguity(out.shape(), out_strides);
|
||||
flags.row_contiguous = rc;
|
||||
flags.col_contiguous = cc;
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void as_transposed(array& out, const std::vector<int>& axes) {
|
||||
assert(out.data_size() == out.size() && out.flags().contiguous);
|
||||
|
||||
// Calculate the contiguous strides.
|
||||
Strides strides(out.ndim(), 1);
|
||||
for (int i = out.ndim() - 2; i >= 0; i--) {
|
||||
strides[i] = strides[i + 1] * out.shape(i);
|
||||
}
|
||||
|
||||
// Calculate the new strides for transposing.
|
||||
Strides new_strides;
|
||||
new_strides.reserve(out.ndim());
|
||||
for (auto ax : axes) {
|
||||
new_strides.push_back(strides[ax]);
|
||||
}
|
||||
|
||||
auto [ds, rc, cc] = check_contiguity(out.shape(), new_strides);
|
||||
auto flags = out.flags();
|
||||
flags.row_contiguous = rc;
|
||||
flags.col_contiguous = cc;
|
||||
|
||||
out.copy_shared_buffer(out, new_strides, flags, ds);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
12
mlx/backend/common/transpose.h
Normal file
12
mlx/backend/common/transpose.h
Normal file
@@ -0,0 +1,12 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void transpose(const array& in, array& out, const std::vector<int>& axes);
|
||||
void as_transposed(array& out, const std::vector<int>& axes);
|
||||
|
||||
} // namespace mlx::core
|
@@ -132,6 +132,11 @@ struct ContiguousIterator {
|
||||
};
|
||||
|
||||
inline auto check_contiguity(const Shape& shape, const Strides& strides) {
|
||||
// Conditions for {row/col}_contiguous
|
||||
// - cumulative product of shapes is equal to the strides (we can ignore axes
|
||||
// with size == 1)
|
||||
// - in the forward direction (column contiguous)
|
||||
// - in the reverse direction (row contiguous)
|
||||
size_t no_broadcast_data_size = 1;
|
||||
int64_t f_stride = 1;
|
||||
int64_t b_stride = 1;
|
||||
|
@@ -40,7 +40,8 @@ add_dependencies(mlx cpu_compiled_preamble)
|
||||
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
|
11
mlx/backend/cpu/available.cpp
Normal file
11
mlx/backend/cpu/available.cpp
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cpu/available.h"
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cpu
|
9
mlx/backend/cpu/available.h
Normal file
9
mlx/backend/cpu/available.h
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
bool is_available();
|
||||
|
||||
} // namespace mlx::core::cpu
|
@@ -172,9 +172,12 @@ void binary_float(
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[binary_float] Only supports non-complex floating point types.");
|
||||
"[binary_float] Only supports floating point types.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@@ -40,7 +40,10 @@ struct CompilerCache {
|
||||
std::shared_mutex mtx;
|
||||
};
|
||||
|
||||
static CompilerCache cache{};
|
||||
static CompilerCache& cache() {
|
||||
static CompilerCache cache_;
|
||||
return cache_;
|
||||
};
|
||||
|
||||
// GPU compile is always available if the GPU is available and since we are in
|
||||
// this file CPU compile is also available.
|
||||
@@ -56,14 +59,16 @@ void* compile(
|
||||
const std::string& kernel_name,
|
||||
const std::function<std::string(void)>& source_builder) {
|
||||
{
|
||||
std::shared_lock lock(cache.mtx);
|
||||
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
||||
std::shared_lock lock(cache().mtx);
|
||||
if (auto it = cache().kernels.find(kernel_name);
|
||||
it != cache().kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_lock lock(cache.mtx);
|
||||
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
||||
std::unique_lock lock(cache().mtx);
|
||||
if (auto it = cache().kernels.find(kernel_name);
|
||||
it != cache().kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
std::string source_code = source_builder();
|
||||
@@ -120,10 +125,10 @@ void* compile(
|
||||
}
|
||||
|
||||
// load library
|
||||
cache.libs.emplace_back(shared_lib_path);
|
||||
cache().libs.emplace_back(shared_lib_path);
|
||||
|
||||
// Load function
|
||||
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
|
||||
void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
|
||||
if (!fun) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Compile::eval_cpu] Failed to load compiled function "
|
||||
@@ -131,7 +136,7 @@ void* compile(
|
||||
<< dlerror();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
cache.kernels.insert({kernel_name, fun});
|
||||
cache().kernels.insert({kernel_name, fun});
|
||||
return fun;
|
||||
}
|
||||
|
||||
|
@@ -330,7 +330,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case complex64:
|
||||
throw std::runtime_error("Scan ops do not support complex types yet");
|
||||
scan_dispatch<complex64_t, complex64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
@@ -88,12 +88,33 @@ DEFAULT_UNARY(expm1, std::expm1)
|
||||
DEFAULT_UNARY(floor, std::floor)
|
||||
DEFAULT_UNARY(log, std::log)
|
||||
DEFAULT_UNARY(log10, std::log10)
|
||||
DEFAULT_UNARY(log1p, std::log1p)
|
||||
DEFAULT_UNARY(sinh, std::sinh)
|
||||
DEFAULT_UNARY(sqrt, std::sqrt)
|
||||
DEFAULT_UNARY(tan, std::tan)
|
||||
DEFAULT_UNARY(tanh, std::tanh)
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> log1p(Simd<T, 1> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
auto x = in.value.real();
|
||||
auto y = in.value.imag();
|
||||
auto zabs = std::abs(in.value);
|
||||
auto theta = std::atan2(y, x + 1);
|
||||
if (zabs < 0.5) {
|
||||
auto r = x * (2 + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return Simd<T, 1>{T{x, theta}};
|
||||
}
|
||||
return Simd<T, 1>{T{((typeof(x))(0.5)) * std::log1p(r), theta}};
|
||||
} else {
|
||||
auto z0 = std::hypot(x + 1, y);
|
||||
return Simd<T, 1>{T{std::log(z0), theta}};
|
||||
}
|
||||
} else {
|
||||
return Simd<T, 1>{std::log1p(in.value)};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> log2(Simd<T, 1> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
|
57
mlx/backend/cuda/CMakeLists.txt
Normal file
57
mlx/backend/cuda/CMakeLists.txt
Normal file
@@ -0,0 +1,57 @@
|
||||
# Filename rules in cuda backend:
|
||||
#
|
||||
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
|
||||
# * Device-only kernel code should be put in kernels/ subdir.
|
||||
# * Files in kernels/ subdir should not include files outside.
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
target_compile_definitions(mlx PUBLIC MLX_USE_CUDA)
|
||||
|
||||
# Enable defining device lambda functions.
|
||||
target_compile_options(mlx
|
||||
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
|
||||
|
||||
# Compute capability 7 is required for synchronization between CPU/GPU with
|
||||
# managed memory. TODO: Add more architectures for potential performance gain.
|
||||
set(MLX_CUDA_ARCHITECTURES
|
||||
"75;80"
|
||||
CACHE STRING "CUDA architectures")
|
||||
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
||||
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
||||
"${MLX_CUDA_ARCHITECTURES}")
|
||||
|
||||
# Use fixed version of CCCL.
|
||||
FetchContent_Declare(
|
||||
cccl
|
||||
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
|
||||
FetchContent_MakeAvailable(cccl)
|
||||
target_include_directories(mlx PRIVATE BEFORE "${cccl_SOURCE_DIR}/include")
|
||||
|
||||
# Use fixed version of NVTX.
|
||||
FetchContent_Declare(
|
||||
nvtx3
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git
|
||||
GIT_TAG v3.1.1
|
||||
GIT_SHALLOW TRUE
|
||||
SOURCE_SUBDIR c EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(nvtx3)
|
||||
target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
|
||||
|
||||
# Make cuda runtime APIs available in non-cuda files.
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
|
||||
|
||||
# Suppress nvcc warnings on MLX headers.
|
||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||
--diag_suppress=997>)
|
154
mlx/backend/cuda/allocator.cpp
Normal file
154
mlx/backend/cuda/allocator.cpp
Normal file
@@ -0,0 +1,154 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
CudaAllocator::CudaAllocator() {
|
||||
// TODO: Set memory limit for multi-device.
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
memory_limit_ = total * 0.8;
|
||||
}
|
||||
|
||||
Buffer CudaAllocator::malloc(size_t size) {
|
||||
// TODO: Check memory limit.
|
||||
auto* buf = new CudaBuffer{nullptr, size};
|
||||
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||
}
|
||||
std::lock_guard lock(mutex_);
|
||||
active_memory_ += size;
|
||||
peak_memory_ = std::max(active_memory_, peak_memory_);
|
||||
return Buffer{buf};
|
||||
}
|
||||
|
||||
void CudaAllocator::free(Buffer buffer) {
|
||||
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
|
||||
if (!buf) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If free() is called from a unregistered thread, reschedule the call to
|
||||
// worker.
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
||||
if (!worker_) {
|
||||
worker_.reset(new Worker);
|
||||
}
|
||||
worker_->add_task([buffer]() { allocator().free(buffer); });
|
||||
worker_->end_batch();
|
||||
worker_->commit();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
size_t size = buf->size;
|
||||
cudaFree(buf->data);
|
||||
delete buf;
|
||||
std::lock_guard lock(mutex_);
|
||||
active_memory_ -= size;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::size(Buffer buffer) const {
|
||||
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
|
||||
if (!buf) {
|
||||
return 0;
|
||||
}
|
||||
return buf->size;
|
||||
}
|
||||
|
||||
void CudaAllocator::register_this_thread() {
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
allowed_threads_.insert(std::this_thread::get_id());
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_active_memory() const {
|
||||
return active_memory_;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_peak_memory() const {
|
||||
return peak_memory_;
|
||||
}
|
||||
|
||||
void CudaAllocator::reset_peak_memory() {
|
||||
std::lock_guard lock(mutex_);
|
||||
peak_memory_ = 0;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_memory_limit() {
|
||||
return memory_limit_;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::set_memory_limit(size_t limit) {
|
||||
std::lock_guard lock(mutex_);
|
||||
std::swap(limit, memory_limit_);
|
||||
return limit;
|
||||
}
|
||||
|
||||
CudaAllocator& allocator() {
|
||||
// By creating the |allocator_| on heap, the destructor of CudaAllocator
|
||||
// will not be called on exit and buffers in the cache will be leaked. This
|
||||
// can save some time at program exit.
|
||||
static CudaAllocator* allocator_ = new CudaAllocator;
|
||||
return *allocator_;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace allocator {
|
||||
|
||||
Allocator& allocator() {
|
||||
return cu::allocator();
|
||||
}
|
||||
|
||||
void* Buffer::raw_ptr() {
|
||||
if (!ptr_) {
|
||||
return nullptr;
|
||||
}
|
||||
return static_cast<cu::CudaBuffer*>(ptr_)->data;
|
||||
}
|
||||
|
||||
} // namespace allocator
|
||||
|
||||
size_t get_active_memory() {
|
||||
return cu::allocator().get_active_memory();
|
||||
}
|
||||
size_t get_peak_memory() {
|
||||
return cu::allocator().get_peak_memory();
|
||||
}
|
||||
void reset_peak_memory() {
|
||||
return cu::allocator().reset_peak_memory();
|
||||
}
|
||||
size_t set_memory_limit(size_t limit) {
|
||||
return cu::allocator().set_memory_limit(limit);
|
||||
}
|
||||
size_t get_memory_limit() {
|
||||
return cu::allocator().get_memory_limit();
|
||||
}
|
||||
|
||||
// TODO: Implement buffer cache.
|
||||
size_t get_cache_memory() {
|
||||
return 0;
|
||||
}
|
||||
size_t set_cache_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
size_t set_wired_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
void clear_cache() {}
|
||||
|
||||
} // namespace mlx::core
|
58
mlx/backend/cuda/allocator.h
Normal file
58
mlx/backend/cuda/allocator.h
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class Worker;
|
||||
|
||||
using allocator::Buffer;
|
||||
|
||||
// Stores cuda-managed unified memory.
|
||||
struct CudaBuffer {
|
||||
void* data;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
class CudaAllocator : public allocator::Allocator {
|
||||
public:
|
||||
Buffer malloc(size_t size) override;
|
||||
void free(Buffer buffer) override;
|
||||
size_t size(Buffer buffer) const override;
|
||||
|
||||
// Register current thread as safe to free buffers.
|
||||
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
|
||||
// that may be waited by gpu stream (for example cpu stream threads), freeing
|
||||
// buffers there would result in dead lock.
|
||||
void register_this_thread();
|
||||
|
||||
size_t get_active_memory() const;
|
||||
size_t get_peak_memory() const;
|
||||
void reset_peak_memory();
|
||||
size_t get_memory_limit();
|
||||
size_t set_memory_limit(size_t limit);
|
||||
|
||||
private:
|
||||
CudaAllocator();
|
||||
friend CudaAllocator& allocator();
|
||||
|
||||
std::mutex worker_mutex_;
|
||||
std::unique_ptr<Worker> worker_;
|
||||
std::set<std::thread::id> allowed_threads_;
|
||||
|
||||
std::mutex mutex_;
|
||||
size_t memory_limit_;
|
||||
size_t active_memory_{0};
|
||||
size_t peak_memory_{0};
|
||||
};
|
||||
|
||||
CudaAllocator& allocator();
|
||||
|
||||
} // namespace mlx::core::cu
|
26
mlx/backend/cuda/copy.cpp
Normal file
26
mlx/backend/cuda/copy.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& data_shape,
|
||||
const Strides& strides_in_pre,
|
||||
const Strides& strides_out_pre,
|
||||
int64_t inp_offset,
|
||||
int64_t out_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s,
|
||||
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
|
||||
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
|
||||
throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend.");
|
||||
}
|
||||
|
||||
void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
throw std::runtime_error("fill_gpu not implemented in CUDA backend.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
117
mlx/backend/cuda/device.cpp
Normal file
117
mlx/backend/cuda/device.cpp
Normal file
@@ -0,0 +1,117 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {}
|
||||
|
||||
void DeviceStream::synchronize() {
|
||||
cudaStreamSynchronize(stream_);
|
||||
}
|
||||
|
||||
cudaStream_t DeviceStream::schedule_cuda_stream() {
|
||||
// TODO: Return a stream that maximizes parallelism.
|
||||
return stream_;
|
||||
}
|
||||
|
||||
cudaStream_t DeviceStream::last_cuda_stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
CommandEncoder& DeviceStream::get_encoder() {
|
||||
if (!encoder_) {
|
||||
encoder_ = std::make_unique<CommandEncoder>(*this);
|
||||
}
|
||||
return *encoder_;
|
||||
}
|
||||
|
||||
Device::Device(int device) : device_(device) {
|
||||
// Validate the requirements of device.
|
||||
int attr = 0;
|
||||
cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_);
|
||||
if (attr != 1) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Device {} does not support synchronization in managed memory.",
|
||||
device_));
|
||||
}
|
||||
}
|
||||
|
||||
void Device::make_current() {
|
||||
// We need to set/get current CUDA device very frequently, cache it to reduce
|
||||
// actual calls of CUDA APIs. This function assumes single-thread in host.
|
||||
static int current = 0;
|
||||
if (current != device_) {
|
||||
CHECK_CUDA_ERROR(cudaSetDevice(device_));
|
||||
current = device_;
|
||||
}
|
||||
}
|
||||
|
||||
DeviceStream& Device::get_stream(Stream s) {
|
||||
auto it = streams_.find(s.index);
|
||||
if (it == streams_.end()) {
|
||||
it = streams_.try_emplace(s.index, *this).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
CommandEncoder::CommandEncoder(DeviceStream& s)
|
||||
: device_(s.device()), stream_(s) {}
|
||||
|
||||
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
||||
worker_.add_task(std::move(task));
|
||||
}
|
||||
|
||||
void CommandEncoder::end_encoding() {
|
||||
if (!temporaries_.empty()) {
|
||||
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
||||
}
|
||||
|
||||
// There is no kernel running, run completion handlers immediately.
|
||||
if (!has_gpu_work_) {
|
||||
worker_.consume_in_this_thread();
|
||||
return;
|
||||
}
|
||||
has_gpu_work_ = false;
|
||||
|
||||
// Put completion handlers in a batch.
|
||||
worker_.end_batch();
|
||||
|
||||
// Signaling kernel completion is expensive, delay until enough batches.
|
||||
// TODO: This number is arbitrarily picked, profile for a better stragety.
|
||||
if (worker_.uncommited_batches() > 8) {
|
||||
commit();
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::commit() {
|
||||
worker_.commit(stream_.last_cuda_stream());
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device device) {
|
||||
static std::unordered_map<int, Device> devices;
|
||||
auto it = devices.find(device.index);
|
||||
if (it == devices.end()) {
|
||||
it = devices.try_emplace(device.index, device.index).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
DeviceStream& get_stream(Stream s) {
|
||||
return device(s.device).get_stream(s);
|
||||
}
|
||||
|
||||
CommandEncoder& get_command_encoder(Stream s) {
|
||||
return get_stream(s).get_encoder();
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
} // namespace mlx::core
|
131
mlx/backend/cuda/device.h
Normal file
131
mlx/backend/cuda/device.h
Normal file
@@ -0,0 +1,131 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <thrust/execution_policy.h>
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class Device;
|
||||
class CommandEncoder;
|
||||
|
||||
class DeviceStream {
|
||||
public:
|
||||
explicit DeviceStream(Device& device);
|
||||
|
||||
DeviceStream(const DeviceStream&) = delete;
|
||||
DeviceStream& operator=(const DeviceStream&) = delete;
|
||||
|
||||
// Wait until kernels in the stream complete.
|
||||
void synchronize();
|
||||
|
||||
// Return a cuda stream for launching kernels.
|
||||
cudaStream_t schedule_cuda_stream();
|
||||
|
||||
// Return the last cuda stream used.
|
||||
cudaStream_t last_cuda_stream();
|
||||
|
||||
CommandEncoder& get_encoder();
|
||||
|
||||
Device& device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
CudaStream stream_;
|
||||
std::unique_ptr<CommandEncoder> encoder_;
|
||||
};
|
||||
|
||||
class Device {
|
||||
public:
|
||||
explicit Device(int device);
|
||||
|
||||
Device(const Device&) = delete;
|
||||
Device& operator=(const Device&) = delete;
|
||||
|
||||
// Make this device the current cuda device, required by some cuda calls.
|
||||
void make_current();
|
||||
|
||||
DeviceStream& get_stream(Stream s);
|
||||
|
||||
int cuda_device() const {
|
||||
return device_;
|
||||
}
|
||||
|
||||
private:
|
||||
int device_;
|
||||
std::unordered_map<int, DeviceStream> streams_;
|
||||
};
|
||||
|
||||
class CommandEncoder {
|
||||
public:
|
||||
explicit CommandEncoder(DeviceStream& stream);
|
||||
|
||||
CommandEncoder(const CommandEncoder&) = delete;
|
||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||
|
||||
void set_input_array(const array& arr) {}
|
||||
void set_output_array(const array& arr) {}
|
||||
|
||||
void add_temporary(const array& arr) {
|
||||
temporaries_.push_back(arr.data_shared_ptr());
|
||||
}
|
||||
|
||||
void add_completed_handler(std::function<void()> task);
|
||||
void end_encoding();
|
||||
void commit();
|
||||
|
||||
// Schedule a cuda stream for |fun| to launch kernels, and check error
|
||||
// afterwards.
|
||||
template <typename F>
|
||||
void launch_kernel(F&& fun) {
|
||||
launch_kernel(stream_.schedule_cuda_stream(), std::forward<F>(fun));
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void launch_kernel(cudaStream_t stream, F&& fun) {
|
||||
device_.make_current();
|
||||
fun(stream);
|
||||
check_cuda_error("kernel launch", cudaGetLastError());
|
||||
has_gpu_work_ = true;
|
||||
}
|
||||
|
||||
Device& device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
DeviceStream& stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
bool has_gpu_work() const {
|
||||
return has_gpu_work_;
|
||||
}
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
DeviceStream& stream_;
|
||||
Worker worker_;
|
||||
bool has_gpu_work_{false};
|
||||
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
||||
};
|
||||
|
||||
Device& device(mlx::core::Device device);
|
||||
DeviceStream& get_stream(Stream s);
|
||||
CommandEncoder& get_command_encoder(Stream s);
|
||||
|
||||
// Return an execution policy that does not sync for result.
|
||||
// Note that not all thrust APIs support async policy, confirm before using.
|
||||
inline auto thrust_policy(cudaStream_t stream) {
|
||||
// TODO: Connect thrust's custom allocator with mlx's allocator.
|
||||
return thrust::cuda::par_nosync.on(stream);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
35
mlx/backend/cuda/dtype_utils.cuh
Normal file
35
mlx/backend/cuda/dtype_utils.cuh
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Maps CPU types to CUDA types.
|
||||
template <typename T>
|
||||
struct CTypeToCudaType {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CTypeToCudaType<float16_t> {
|
||||
using type = __half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CTypeToCudaType<bfloat16_t> {
|
||||
using type = __nv_bfloat16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CTypeToCudaType<complex64_t> {
|
||||
using type = cuComplex;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using cuda_type_t = typename CTypeToCudaType<T>::type;
|
||||
|
||||
} // namespace mlx::core
|
68
mlx/backend/cuda/eval.cpp
Normal file
68
mlx/backend/cuda/eval.cpp
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/eval.h"
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/gpu/available.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
void new_stream(Stream s) {
|
||||
// Force initalization of cuda, so cuda runtime get destroyed at last.
|
||||
cudaFree(nullptr);
|
||||
// Ensure the static stream objects get created.
|
||||
cu::get_command_encoder(s);
|
||||
// The main thread is safe to free buffers.
|
||||
cu::allocator().register_this_thread();
|
||||
}
|
||||
|
||||
void eval(array& arr) {
|
||||
nvtx3::scoped_range r("gpu::eval");
|
||||
auto outputs = arr.outputs();
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
// to its inputs so they don't get donated
|
||||
std::vector<array> inputs;
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
|
||||
if (encoder.has_gpu_work()) {
|
||||
// Keep used buffers alive until kernel finishes running.
|
||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.insert(in.data_shared_ptr());
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.insert(s.data_shared_ptr());
|
||||
}
|
||||
// Remove the output if it was donated to by an input.
|
||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
|
||||
}
|
||||
encoder.end_encoding();
|
||||
}
|
||||
|
||||
void finalize(Stream s) {
|
||||
nvtx3::scoped_range r("gpu::finalize");
|
||||
cu::get_command_encoder(s).commit();
|
||||
}
|
||||
|
||||
void synchronize(Stream s) {
|
||||
nvtx3::scoped_range r("gpu::synchronize");
|
||||
cu::get_stream(s).synchronize();
|
||||
}
|
||||
|
||||
} // namespace mlx::core::gpu
|
265
mlx/backend/cuda/event.cu
Normal file
265
mlx/backend/cuda/event.cu
Normal file
@@ -0,0 +1,265 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CudaEvent implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Cuda event managed with RAII.
|
||||
class CudaEventHandle {
|
||||
public:
|
||||
CudaEventHandle() {
|
||||
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
|
||||
&event_, cudaEventDisableTiming | cudaEventBlockingSync));
|
||||
}
|
||||
|
||||
~CudaEventHandle() {
|
||||
CHECK_CUDA_ERROR(cudaEventDestroy(event_));
|
||||
}
|
||||
|
||||
CudaEventHandle(const CudaEventHandle&) = delete;
|
||||
CudaEventHandle& operator=(const CudaEventHandle&) = delete;
|
||||
|
||||
operator cudaEvent_t() const {
|
||||
return event_;
|
||||
}
|
||||
|
||||
private:
|
||||
cudaEvent_t event_;
|
||||
};
|
||||
|
||||
CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
|
||||
|
||||
void CudaEvent::wait() {
|
||||
nvtx3::scoped_range r("cu::CudaEvent::wait");
|
||||
if (!recorded_) {
|
||||
throw std::runtime_error("Should not wait on a CudaEvent before record.");
|
||||
}
|
||||
cudaEventSynchronize(*event_);
|
||||
}
|
||||
|
||||
void CudaEvent::wait(cudaStream_t stream) {
|
||||
if (!recorded_) {
|
||||
throw std::runtime_error("Should not wait on a CudaEvent before record.");
|
||||
}
|
||||
cudaStreamWaitEvent(stream, *event_);
|
||||
}
|
||||
|
||||
void CudaEvent::wait(Stream s) {
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this]() mutable { wait(); });
|
||||
} else {
|
||||
wait(cu::get_stream(s).last_cuda_stream());
|
||||
}
|
||||
}
|
||||
|
||||
void CudaEvent::record(cudaStream_t stream) {
|
||||
cudaEventRecord(*event_, stream);
|
||||
recorded_ = true;
|
||||
}
|
||||
|
||||
void CudaEvent::record(Stream s) {
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
|
||||
} else {
|
||||
record(cu::get_stream(s).last_cuda_stream());
|
||||
}
|
||||
}
|
||||
|
||||
bool CudaEvent::completed() const {
|
||||
return cudaEventQuery(*event_) == cudaSuccess;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// SharedEvent implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace {
|
||||
|
||||
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
uint64_t current;
|
||||
while ((current = ac->load()) < value) {
|
||||
ac->wait(current);
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
ac->store(value);
|
||||
ac->notify_all();
|
||||
}
|
||||
|
||||
__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
event_wait(ac, value);
|
||||
}
|
||||
|
||||
__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
event_signal(ac, value);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
SharedEvent::SharedEvent() {
|
||||
// Allocate cuda::atomic on managed memory.
|
||||
allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
|
||||
Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
|
||||
new (ac) Atomic(0);
|
||||
ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
|
||||
ptr->~Atomic();
|
||||
allocator::free(buffer);
|
||||
});
|
||||
}
|
||||
|
||||
void SharedEvent::wait(uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::wait");
|
||||
event_wait(ac_.get(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
|
||||
event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::wait(Stream s, uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
|
||||
} else {
|
||||
auto& encoder = get_command_encoder(s);
|
||||
encoder.launch_kernel(
|
||||
encoder.stream().last_cuda_stream(),
|
||||
[this, value](cudaStream_t stream) { wait(stream, value); });
|
||||
encoder.add_completed_handler([ac = ac_]() {});
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
void SharedEvent::signal(uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::signal");
|
||||
event_signal(ac_.get(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
||||
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::signal(Stream s, uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this, value]() mutable { signal(value); });
|
||||
} else {
|
||||
auto& encoder = get_command_encoder(s);
|
||||
encoder.launch_kernel(
|
||||
encoder.stream().last_cuda_stream(),
|
||||
[this, value](cudaStream_t stream) { signal(stream, value); });
|
||||
encoder.add_completed_handler([ac = ac_]() {});
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
bool SharedEvent::is_signaled(uint64_t value) const {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
|
||||
return ac_->load() >= value;
|
||||
}
|
||||
|
||||
uint64_t SharedEvent::value() const {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::value");
|
||||
return ac_->load();
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Event implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace {
|
||||
|
||||
struct EventImpl {
|
||||
// CudaEvent is preferred when possible because it is fast, however we have
|
||||
// to fallback to SharedEvent in following cases:
|
||||
// 1. the event is used to wait/signal a cpu stream;
|
||||
// 2. signal value other than 1 has been specified.
|
||||
std::unique_ptr<cu::CudaEvent> cuda;
|
||||
std::unique_ptr<cu::SharedEvent> shared;
|
||||
|
||||
bool is_created() const {
|
||||
return cuda || shared;
|
||||
}
|
||||
|
||||
void ensure_created(Stream s, uint64_t signal_value) {
|
||||
if (is_created()) {
|
||||
return;
|
||||
}
|
||||
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
|
||||
nvtx3::mark("Using slow SharedEvent");
|
||||
shared = std::make_unique<cu::SharedEvent>();
|
||||
} else {
|
||||
cuda = std::make_unique<cu::CudaEvent>();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
Event::Event(Stream s) : stream_(s) {
|
||||
event_ = std::shared_ptr<void>(
|
||||
new EventImpl(), [](void* ptr) { delete static_cast<EventImpl*>(ptr); });
|
||||
}
|
||||
|
||||
void Event::wait() {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
assert(event->is_created());
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
event->cuda->wait();
|
||||
} else {
|
||||
event->shared->wait(value());
|
||||
}
|
||||
}
|
||||
|
||||
void Event::wait(Stream s) {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
assert(event->is_created());
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
event->cuda->wait(s);
|
||||
} else {
|
||||
event->shared->wait(s, value());
|
||||
}
|
||||
}
|
||||
|
||||
void Event::signal(Stream s) {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
event->ensure_created(s, value());
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
event->cuda->record(s);
|
||||
} else {
|
||||
event->shared->signal(s, value());
|
||||
}
|
||||
}
|
||||
|
||||
bool Event::is_signaled() const {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
if (!event->is_created()) {
|
||||
return false;
|
||||
}
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
return event->cuda->recorded() && event->cuda->completed();
|
||||
} else {
|
||||
return event->shared->is_signaled(value());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
66
mlx/backend/cuda/event.h
Normal file
66
mlx/backend/cuda/event.h
Normal file
@@ -0,0 +1,66 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda/atomic>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class CudaEventHandle;
|
||||
|
||||
// Wrapper of native cuda event. It can synchronize between GPU streams, or wait
|
||||
// on GPU stream in CPU stream, but can not wait on CPU stream.
|
||||
class CudaEvent {
|
||||
public:
|
||||
CudaEvent();
|
||||
|
||||
void wait();
|
||||
void wait(cudaStream_t stream);
|
||||
void wait(Stream s);
|
||||
void record(cudaStream_t stream);
|
||||
void record(Stream s);
|
||||
|
||||
// Return whether the recorded kernels have completed. Note that this method
|
||||
// returns true if record() has not been called.
|
||||
bool completed() const;
|
||||
|
||||
bool recorded() const {
|
||||
return recorded_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool recorded_{false};
|
||||
std::shared_ptr<CudaEventHandle> event_;
|
||||
};
|
||||
|
||||
// Event that can synchronize between CPU and GPU. It is much slower than
|
||||
// CudaEvent so the latter should always be preferred when possible.
|
||||
class SharedEvent {
|
||||
public:
|
||||
using Atomic = cuda::atomic<uint64_t>;
|
||||
|
||||
SharedEvent();
|
||||
|
||||
void wait(uint64_t value);
|
||||
void wait(cudaStream_t stream, uint64_t value);
|
||||
void wait(Stream s, uint64_t value);
|
||||
void signal(uint64_t value);
|
||||
void signal(cudaStream_t stream, uint64_t value);
|
||||
void signal(Stream s, uint64_t value);
|
||||
bool is_signaled(uint64_t value) const;
|
||||
uint64_t value() const;
|
||||
|
||||
const std::shared_ptr<Atomic>& atomic() const {
|
||||
return ac_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Atomic> ac_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
70
mlx/backend/cuda/fence.cu
Normal file
70
mlx/backend/cuda/fence.cu
Normal file
@@ -0,0 +1,70 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
__host__ __device__ void busy_wait(cuda::atomic<uint64_t>* ac, uint64_t value) {
|
||||
while (true) {
|
||||
// In theory the atomic_thread_fence is not needed, but for CUDA 11 without
|
||||
// it the load() may never return new value.
|
||||
cuda::atomic_thread_fence(cuda::memory_order_seq_cst);
|
||||
uint64_t current = ac->load();
|
||||
if (current >= value) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void busy_wait_kernel(cuda::atomic<uint64_t>* ac, uint64_t value) {
|
||||
busy_wait(ac, value);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
struct FenceImpl {
|
||||
uint32_t count;
|
||||
cu::SharedEvent event;
|
||||
};
|
||||
|
||||
Fence::Fence(Stream s) {
|
||||
fence_ = std::shared_ptr<void>(
|
||||
new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
|
||||
}
|
||||
|
||||
void Fence::wait(Stream s, const array&) {
|
||||
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
||||
// We can't use SharedEvent::wait because it could hang in CUDA 11, see also:
|
||||
// https://github.com/ml-explore/mlx/issues/2137
|
||||
const auto& ac = fence->event.atomic();
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [ac, count = fence->count]() {
|
||||
nvtx3::scoped_range r("Fence::wait()");
|
||||
busy_wait(ac.get(), count);
|
||||
});
|
||||
} else {
|
||||
nvtx3::scoped_range r("Fence::wait(s)");
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.launch_kernel(
|
||||
encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) {
|
||||
busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count);
|
||||
});
|
||||
encoder.add_completed_handler([ac]() {});
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
void Fence::update(Stream s, const array&) {
|
||||
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
||||
fence->count++;
|
||||
fence->event.signal(s, fence->count);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
15
mlx/backend/cuda/kernels/arange.cuh
Normal file
15
mlx/backend/cuda/kernels/arange.cuh
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
template <typename T>
|
||||
struct Arange {
|
||||
const T start;
|
||||
const T step;
|
||||
|
||||
__device__ T operator()(uint32_t i) const {
|
||||
return start + i * step;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
107
mlx/backend/cuda/kernels/fp16_math.cuh
Normal file
107
mlx/backend/cuda/kernels/fp16_math.cuh
Normal file
@@ -0,0 +1,107 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda/std/limits>
|
||||
#include <cuda/std/type_traits>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Missing C++ operator overrides for CUDA 7.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
|
||||
#define MLX_DEFINE_BF16_OP(OP) \
|
||||
__forceinline__ __device__ __nv_bfloat16 operator OP( \
|
||||
__nv_bfloat16 x, __nv_bfloat16 y) { \
|
||||
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
|
||||
}
|
||||
|
||||
#define MLX_DEFINE_BF16_CMP(OP) \
|
||||
__forceinline__ __device__ bool operator OP( \
|
||||
__nv_bfloat16 x, __nv_bfloat16 y) { \
|
||||
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
|
||||
}
|
||||
|
||||
MLX_DEFINE_BF16_OP(+)
|
||||
MLX_DEFINE_BF16_OP(-)
|
||||
MLX_DEFINE_BF16_OP(*)
|
||||
MLX_DEFINE_BF16_OP(/)
|
||||
MLX_DEFINE_BF16_CMP(>)
|
||||
MLX_DEFINE_BF16_CMP(<)
|
||||
MLX_DEFINE_BF16_CMP(>=)
|
||||
MLX_DEFINE_BF16_CMP(<=)
|
||||
|
||||
#undef MLX_DEFINE_BF16_OP
|
||||
#undef MLX_DEFINE_BF16_CMP
|
||||
|
||||
#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Additional C++ operator overrides between half types and native types.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U>
|
||||
constexpr bool is_integral_except =
|
||||
cuda::std::is_integral_v<T> && !cuda::std::is_same_v<T, U>;
|
||||
|
||||
template <typename T, typename U>
|
||||
constexpr bool is_arithmetic_except =
|
||||
cuda::std::is_arithmetic_v<T> && !cuda::std::is_same_v<T, U>;
|
||||
|
||||
#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
|
||||
__forceinline__ __device__ HALF operator OP(HALF x, T y) { \
|
||||
return FLOAT2HALF(HALF2FLOAT(x) OP static_cast<float>(y)); \
|
||||
} \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
|
||||
__forceinline__ __device__ HALF operator OP(T x, HALF y) { \
|
||||
return FLOAT2HALF(static_cast<float>(x) OP HALF2FLOAT(y)); \
|
||||
}
|
||||
|
||||
#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
|
||||
__forceinline__ __device__ bool operator OP(HALF x, T y) { \
|
||||
return HALF2FLOAT(x) OP static_cast<float>(y); \
|
||||
} \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
|
||||
__forceinline__ __device__ bool operator OP(T x, HALF y) { \
|
||||
return static_cast<float>(y) OP HALF2FLOAT(x); \
|
||||
}
|
||||
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, <)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, >)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, <=)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, >=)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, ==)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, !=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=)
|
||||
|
||||
#undef MLX_DEFINE_HALF_OP
|
||||
#undef MLX_DEFINE_HALF_CMP
|
||||
|
||||
} // namespace mlx::core::cu
|
163
mlx/backend/cuda/primitives.cu
Normal file
163
mlx/backend/cuda/primitives.cu
Normal file
@@ -0,0 +1,163 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/dtype_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/arange.cuh"
|
||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/transform.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Arange::eval_gpu");
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&, this](cudaStream_t stream) {
|
||||
MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, {
|
||||
using OutType = cuda_type_t<CTYPE>;
|
||||
CTYPE step =
|
||||
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
|
||||
thrust::transform(
|
||||
cu::thrust_policy(stream),
|
||||
thrust::counting_iterator<uint32_t>(0),
|
||||
thrust::counting_iterator<uint32_t>(out.data_size()),
|
||||
thrust::device_pointer_cast(out.data<OutType>()),
|
||||
cu::Arange<OutType>{
|
||||
static_cast<OutType>(start_), static_cast<OutType>(step)});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||
}
|
||||
|
||||
#define NO_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||
}
|
||||
|
||||
NO_GPU(Abs)
|
||||
NO_GPU(Add)
|
||||
NO_GPU(AddMM)
|
||||
NO_GPU(ArcCos)
|
||||
NO_GPU(ArcCosh)
|
||||
NO_GPU(ArcSin)
|
||||
NO_GPU(ArcSinh)
|
||||
NO_GPU(ArcTan)
|
||||
NO_GPU(ArcTan2)
|
||||
NO_GPU(ArcTanh)
|
||||
NO_GPU(ArgPartition)
|
||||
NO_GPU(ArgReduce)
|
||||
NO_GPU(ArgSort)
|
||||
NO_GPU(BitwiseBinary)
|
||||
NO_GPU(BitwiseInvert)
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU(Ceil)
|
||||
NO_GPU_MULTI(Compiled)
|
||||
NO_GPU(Conjugate)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU(Cos)
|
||||
NO_GPU(Cosh)
|
||||
NO_GPU(Divide)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
NO_GPU(DynamicSlice)
|
||||
NO_GPU(DynamicSliceUpdate)
|
||||
NO_GPU(Remainder)
|
||||
NO_GPU(Equal)
|
||||
NO_GPU(Erf)
|
||||
NO_GPU(ErfInv)
|
||||
NO_GPU(Exp)
|
||||
NO_GPU(Expm1)
|
||||
NO_GPU(FFT)
|
||||
NO_GPU(Floor)
|
||||
NO_GPU(Gather)
|
||||
NO_GPU(GatherAxis)
|
||||
NO_GPU(GatherMM)
|
||||
NO_GPU(GatherQMM)
|
||||
NO_GPU(Greater)
|
||||
NO_GPU(GreaterEqual)
|
||||
NO_GPU(Hadamard)
|
||||
NO_GPU(Imag)
|
||||
NO_GPU(Less)
|
||||
NO_GPU(LessEqual)
|
||||
NO_GPU(Load)
|
||||
NO_GPU(Log)
|
||||
NO_GPU(Log1p)
|
||||
NO_GPU(LogicalNot)
|
||||
NO_GPU(LogicalAnd)
|
||||
NO_GPU(LogicalOr)
|
||||
NO_GPU(LogAddExp)
|
||||
NO_GPU(LogSumExp)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU(Matmul)
|
||||
NO_GPU(Maximum)
|
||||
NO_GPU(Minimum)
|
||||
NO_GPU(Multiply)
|
||||
NO_GPU(Negative)
|
||||
NO_GPU(NotEqual)
|
||||
NO_GPU(Partition)
|
||||
NO_GPU(Power)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(RandomBits)
|
||||
NO_GPU(Real)
|
||||
NO_GPU(Reduce)
|
||||
NO_GPU(Round)
|
||||
NO_GPU(Scan)
|
||||
NO_GPU(Scatter)
|
||||
NO_GPU(ScatterAxis)
|
||||
NO_GPU(Select)
|
||||
NO_GPU(Sigmoid)
|
||||
NO_GPU(Sign)
|
||||
NO_GPU(Sin)
|
||||
NO_GPU(Sinh)
|
||||
NO_GPU(SliceUpdate)
|
||||
NO_GPU(Softmax)
|
||||
NO_GPU(Sort)
|
||||
NO_GPU(Square)
|
||||
NO_GPU(Sqrt)
|
||||
NO_GPU(Subtract)
|
||||
NO_GPU_MULTI(SVD)
|
||||
NO_GPU(Tan)
|
||||
NO_GPU(Tanh)
|
||||
NO_GPU(Inverse)
|
||||
NO_GPU(Cholesky)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_MULTI(LayerNorm)
|
||||
NO_GPU_MULTI(LayerNormVJP)
|
||||
NO_GPU_MULTI(RMSNorm)
|
||||
NO_GPU_MULTI(RMSNormVJP)
|
||||
NO_GPU_MULTI(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
NO_GPU_MULTI(AllReduce)
|
||||
NO_GPU_MULTI(AllGather)
|
||||
NO_GPU_MULTI(Send)
|
||||
NO_GPU_MULTI(Recv)
|
||||
} // namespace distributed
|
||||
|
||||
} // namespace mlx::core
|
15
mlx/backend/cuda/slicing.cpp
Normal file
15
mlx/backend/cuda/slicing.cpp
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void concatenate_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
int axis,
|
||||
const Stream& s) {
|
||||
throw std::runtime_error("concatenate_gpu not implemented in CUDA backend.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
26
mlx/backend/cuda/utils.cpp
Normal file
26
mlx/backend/cuda/utils.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
CudaStream::CudaStream(cu::Device& device) {
|
||||
device.make_current();
|
||||
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
|
||||
}
|
||||
|
||||
CudaStream::~CudaStream() {
|
||||
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
|
||||
}
|
||||
|
||||
void check_cuda_error(const char* name, cudaError_t err) {
|
||||
if (err != cudaSuccess) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("{} failed: {}", name, cudaGetErrorString(err)));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
36
mlx/backend/cuda/utils.h
Normal file
36
mlx/backend/cuda/utils.h
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
class Device;
|
||||
}
|
||||
|
||||
// Cuda stream managed with RAII.
|
||||
class CudaStream {
|
||||
public:
|
||||
explicit CudaStream(cu::Device& device);
|
||||
~CudaStream();
|
||||
|
||||
CudaStream(const CudaStream&) = delete;
|
||||
CudaStream& operator=(const CudaStream&) = delete;
|
||||
|
||||
operator cudaStream_t() const {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
private:
|
||||
cudaStream_t stream_;
|
||||
};
|
||||
|
||||
// Throw exception if the cuda API does not succeed.
|
||||
void check_cuda_error(const char* name, cudaError_t err);
|
||||
|
||||
// The macro version that prints the command that failed.
|
||||
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
||||
|
||||
} // namespace mlx::core
|
90
mlx/backend/cuda/worker.cpp
Normal file
90
mlx/backend/cuda/worker.cpp
Normal file
@@ -0,0 +1,90 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
Worker::Worker()
|
||||
: signal_stream_(device(mlx::core::Device::gpu)),
|
||||
worker_(&Worker::thread_fn, this) {}
|
||||
|
||||
Worker::~Worker() {
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
stop_ = true;
|
||||
}
|
||||
worker_event_.signal(batch_ + 1);
|
||||
worker_.join();
|
||||
}
|
||||
|
||||
void Worker::add_task(std::function<void()> task) {
|
||||
pending_tasks_.push_back(std::move(task));
|
||||
}
|
||||
|
||||
void Worker::consume_in_this_thread() {
|
||||
for (auto& task : pending_tasks_) {
|
||||
task();
|
||||
}
|
||||
pending_tasks_.clear();
|
||||
}
|
||||
|
||||
void Worker::end_batch() {
|
||||
batch_++;
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
worker_tasks_[batch_] = std::move(pending_tasks_);
|
||||
}
|
||||
uncommited_batches_++;
|
||||
}
|
||||
|
||||
void Worker::commit() {
|
||||
if (uncommited_batches_ == 0) {
|
||||
return;
|
||||
}
|
||||
uncommited_batches_ = 0;
|
||||
worker_event_.signal(batch_);
|
||||
}
|
||||
|
||||
void Worker::commit(cudaStream_t stream) {
|
||||
if (uncommited_batches_ == 0) {
|
||||
return;
|
||||
}
|
||||
uncommited_batches_ = 0;
|
||||
// Signal the |worker_event_| in |signal_stream_| after the kernels in
|
||||
// |stream_| finish running.
|
||||
signal_event_.record(stream);
|
||||
signal_event_.wait(signal_stream_);
|
||||
worker_event_.signal(signal_stream_, batch_);
|
||||
}
|
||||
|
||||
void Worker::thread_fn() {
|
||||
// The worker thread is safe to free buffers.
|
||||
allocator().register_this_thread();
|
||||
|
||||
while (!stop_) {
|
||||
uint64_t batch = worker_event_.value();
|
||||
Tasks tasks;
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
// Move tasks in signaled batches.
|
||||
auto end = worker_tasks_.upper_bound(batch);
|
||||
for (auto it = worker_tasks_.begin(); it != end; ++it) {
|
||||
if (tasks.empty()) {
|
||||
tasks = std::move(it->second);
|
||||
} else {
|
||||
std::move(
|
||||
it->second.begin(), it->second.end(), std::back_inserter(tasks));
|
||||
}
|
||||
}
|
||||
worker_tasks_.erase(worker_tasks_.begin(), end);
|
||||
}
|
||||
for (auto& task : tasks) {
|
||||
task();
|
||||
}
|
||||
worker_event_.wait(batch + 1);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
68
mlx/backend/cuda/worker.h
Normal file
68
mlx/backend/cuda/worker.h
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Run tasks in worker thread, synchronized with cuda stream.
|
||||
class Worker {
|
||||
public:
|
||||
Worker();
|
||||
~Worker();
|
||||
|
||||
Worker(const Worker&) = delete;
|
||||
Worker& operator=(const Worker&) = delete;
|
||||
|
||||
// Add a pending |task| that will run when consumed or commited.
|
||||
void add_task(std::function<void()> task);
|
||||
|
||||
// Run pending tasks immediately in current thread.
|
||||
void consume_in_this_thread();
|
||||
|
||||
// Put pending tasks in a batch.
|
||||
void end_batch();
|
||||
|
||||
// Inform worker thread to run current batches now.
|
||||
void commit();
|
||||
|
||||
// Inform worker thread to run current batches after kernels in |stream|
|
||||
// finish running.
|
||||
void commit(cudaStream_t stream);
|
||||
|
||||
// Return how many batches have been added but not committed yet.
|
||||
size_t uncommited_batches() const {
|
||||
return uncommited_batches_;
|
||||
}
|
||||
|
||||
private:
|
||||
void thread_fn();
|
||||
|
||||
uint64_t batch_{0};
|
||||
size_t uncommited_batches_{0};
|
||||
|
||||
// Cuda stream and event for signaling kernel completion.
|
||||
CudaStream signal_stream_;
|
||||
CudaEvent signal_event_;
|
||||
|
||||
// Worker thread.
|
||||
SharedEvent worker_event_;
|
||||
std::thread worker_;
|
||||
std::mutex worker_mutex_;
|
||||
bool stop_{false};
|
||||
|
||||
// Tasks are put in |pending_tasks_| first, and then moved to
|
||||
// |worker_tasks_| when end_batch() is called.
|
||||
using Tasks = std::vector<std::function<void()>>;
|
||||
Tasks pending_tasks_;
|
||||
std::map<uint64_t, Tasks> worker_tasks_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
5
mlx/backend/gpu/CMakeLists.txt
Normal file
5
mlx/backend/gpu/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)
|
9
mlx/backend/gpu/available.h
Normal file
9
mlx/backend/gpu/available.h
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
bool is_available();
|
||||
|
||||
} // namespace mlx::core::gpu
|
49
mlx/backend/gpu/copy.cpp
Normal file
49
mlx/backend/gpu/copy.cpp
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
bool donated = set_copy_output_data(in, out, ctype);
|
||||
if (donated && in.dtype() == out.dtype()) {
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
// copy, just use the buffer.
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu_inplace(in, out, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Strides& i_strides,
|
||||
int64_t i_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -5,6 +5,8 @@
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Generic copy inplace
|
@@ -8,14 +8,11 @@
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
void new_stream(Stream stream);
|
||||
|
||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
||||
|
||||
void eval(array& arr);
|
||||
void finalize(Stream s);
|
||||
void synchronize(Stream s);
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
} // namespace mlx::core::gpu
|
222
mlx/backend/gpu/primitives.cpp
Normal file
222
mlx/backend/gpu/primitives.cpp
Normal file
@@ -0,0 +1,222 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#define MLX_PROFILER_RANGE(message)
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void reshape(const array& in, array& out, Stream s) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
make_contiguous_strides(in.shape()),
|
||||
0,
|
||||
0,
|
||||
CopyType::General,
|
||||
s);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("AsStrided::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("AsType::eval_gpu");
|
||||
CopyType ctype =
|
||||
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy_gpu(inputs[0], out, ctype);
|
||||
}
|
||||
|
||||
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Broadcast::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("BroadcastAxes::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Concatenate::eval_gpu");
|
||||
concatenate_gpu(inputs, out, axis_, stream());
|
||||
}
|
||||
|
||||
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Contiguous::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
constexpr size_t extra_bytes = 16384;
|
||||
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
|
||||
(in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Copy::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void CustomTransforms::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
MLX_PROFILER_RANGE("CustomTransforms::eval_gpu");
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void Depends::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
MLX_PROFILER_RANGE("Depends::eval_gpu");
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Full::eval_gpu");
|
||||
auto in = inputs[0];
|
||||
CopyType ctype;
|
||||
if (in.data_size() == 1) {
|
||||
ctype = CopyType::Scalar;
|
||||
} else if (in.flags().contiguous) {
|
||||
ctype = CopyType::Vector;
|
||||
} else {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu(in, out, ctype);
|
||||
}
|
||||
|
||||
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Flatten::eval_gpu");
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("NumberOfElements::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Inputs must be base input array and scalar val array
|
||||
assert(inputs.size() == 2);
|
||||
auto& in = inputs[0];
|
||||
auto& val = inputs[1];
|
||||
|
||||
// Padding value must be a scalar
|
||||
assert(val.size() == 1);
|
||||
|
||||
// Padding value, input and output must be of the same type
|
||||
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
|
||||
|
||||
pad_gpu(in, val, out, axes_, low_pad_size_, stream());
|
||||
}
|
||||
|
||||
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Reshape::eval_gpu");
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void Split::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
MLX_PROFILER_RANGE("Split::eval_gpu");
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Slice::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
slice_gpu(in, out, start_indices_, strides_, stream());
|
||||
}
|
||||
|
||||
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Squeeze::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("StopGradient::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Transpose::eval_gpu");
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Unflatten::eval_gpu");
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("View::eval_gpu");
|
||||
auto& in = inputs[0];
|
||||
auto ibytes = size_of(in.dtype());
|
||||
auto obytes = size_of(out.dtype());
|
||||
// Conditions for buffer copying (disjunction):
|
||||
// - type size is the same
|
||||
// - type size is smaller and the last axis is contiguous
|
||||
// - the entire array is row contiguous
|
||||
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
|
||||
in.flags().row_contiguous) {
|
||||
auto strides = in.strides();
|
||||
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
|
||||
strides[i] *= ibytes;
|
||||
strides[i] /= obytes;
|
||||
}
|
||||
out.copy_shared_buffer(
|
||||
in, strides, in.flags(), in.data_size() * ibytes / obytes);
|
||||
} else {
|
||||
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc(tmp.nbytes()));
|
||||
copy_gpu_inplace(in, tmp, CopyType::General, stream());
|
||||
|
||||
auto flags = out.flags();
|
||||
flags.contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
44
mlx/backend/gpu/slicing.cpp
Normal file
44
mlx/backend/gpu/slicing.cpp
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void slice_gpu(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& start_indices,
|
||||
const Shape& strides,
|
||||
const Stream& s) {
|
||||
slice(in, out, start_indices, strides);
|
||||
}
|
||||
|
||||
void pad_gpu(
|
||||
const array& in,
|
||||
const array& val,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
const Shape& low_pad_size,
|
||||
const Stream& s) {
|
||||
// Fill output with val
|
||||
fill_gpu(val, out, s);
|
||||
|
||||
// Find offset for start of input values
|
||||
size_t data_offset = 0;
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
|
||||
data_offset += out.strides()[ax] * low_pad_size[i];
|
||||
}
|
||||
|
||||
// Extract slice from output where input will be pasted
|
||||
array out_slice(in.shape(), out.dtype(), nullptr, {});
|
||||
out_slice.copy_shared_buffer(
|
||||
out, out.strides(), out.flags(), out_slice.size(), data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -93,6 +93,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
|
@@ -1,7 +1,6 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include "mlx/backend/metal/allocator.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/resident.h"
|
||||
#include "mlx/memory.h"
|
||||
|
||||
|
@@ -90,7 +90,7 @@ void binary_op_gpu_inplace(
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
large = out.data_size() > UINT32_MAX;
|
||||
work_per_thread = 1;
|
||||
work_per_thread = get_work_per_thread(a.dtype());
|
||||
}
|
||||
std::string kernel_name =
|
||||
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
|
||||
@@ -137,13 +137,20 @@ void binary_op_gpu_inplace(
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D or 2D grid of threads
|
||||
size_t nthreads = out.data_size();
|
||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims;
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(out.data_size(), arg_idx++);
|
||||
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(out.data_size(), arg_idx++);
|
||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
}
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
@@ -64,6 +64,7 @@ inline void build_kernel(
|
||||
cnt++);
|
||||
}
|
||||
|
||||
std::string idx_type = use_big_index ? "int64_t" : "uint";
|
||||
if (add_indices) {
|
||||
os += fmt::format(
|
||||
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
|
||||
@@ -83,6 +84,9 @@ inline void build_kernel(
|
||||
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
|
||||
os += fmt::format(
|
||||
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
|
||||
} else {
|
||||
os += fmt::format(
|
||||
" constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++);
|
||||
}
|
||||
if (dynamic_dims) {
|
||||
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
|
||||
@@ -92,13 +96,14 @@ inline void build_kernel(
|
||||
os += " uint3 pos [[thread_position_in_grid]],\n";
|
||||
os += " uint3 grid [[threads_per_grid]]) {\n";
|
||||
|
||||
std::string idx_type = use_big_index ? "int64_t" : "uint";
|
||||
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
|
||||
if (contiguous && use_big_index) {
|
||||
// This is only used for contiguous kernels which don't have
|
||||
// a third grid dimension
|
||||
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
|
||||
os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n";
|
||||
} else if (contiguous) {
|
||||
os += " uint index = N_ * pos.x;\n";
|
||||
} else if (work_per_thread > 1) {
|
||||
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
|
||||
os += fmt::format(
|
||||
" int xshape = output_shape[{0}];\n",
|
||||
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
|
||||
@@ -110,6 +115,9 @@ inline void build_kernel(
|
||||
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
|
||||
idx_type);
|
||||
}
|
||||
if (work_per_thread > 1 && contiguous) {
|
||||
os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
|
||||
}
|
||||
|
||||
// Read constant / contiguous inputs in tmps
|
||||
std::vector<array> nc_inputs;
|
||||
@@ -193,7 +201,7 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Open per-thread loop
|
||||
if (work_per_thread > 1) {
|
||||
if (work_per_thread > 1 && !contiguous) {
|
||||
os +=
|
||||
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
|
||||
}
|
||||
@@ -272,6 +280,7 @@ void Compiled::eval_gpu(
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto lib = d.get_library(kernel_lib_, [&]() {
|
||||
int work_per_thread = get_work_per_thread(outputs_[0].dtype());
|
||||
std::string kernel = metal::utils();
|
||||
concatenate(
|
||||
kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
|
||||
@@ -284,7 +293,9 @@ void Compiled::eval_gpu(
|
||||
constant_ids_,
|
||||
/* contiguous = */ true,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ false);
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ false,
|
||||
/* work_per_thread = */ work_per_thread);
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous_large",
|
||||
@@ -295,7 +306,8 @@ void Compiled::eval_gpu(
|
||||
/* contiguous = */ true,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ true);
|
||||
/* use_big_index = */ true,
|
||||
/* work_per_thread = */ work_per_thread);
|
||||
for (int i = 1; i < 8; i++) {
|
||||
build_kernel(
|
||||
kernel,
|
||||
@@ -468,6 +480,13 @@ void Compiled::eval_gpu(
|
||||
if (!contiguous) {
|
||||
compute_encoder.set_vector_bytes(strides[0], cnt++);
|
||||
compute_encoder.set_vector_bytes(shape, cnt++);
|
||||
} else {
|
||||
auto size = outputs[0].data_size();
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(size, cnt++);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(size, cnt++);
|
||||
}
|
||||
}
|
||||
|
||||
// Put the number of dims in if it is dynamic
|
||||
@@ -477,12 +496,13 @@ void Compiled::eval_gpu(
|
||||
|
||||
// Launch the kernel
|
||||
if (contiguous) {
|
||||
size_t nthreads = outputs[0].data_size();
|
||||
int work_per_thread = get_work_per_thread(outputs[0].dtype());
|
||||
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
|
||||
MTL::Size group_dims(
|
||||
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||
|
||||
MTL::Size grid_dims = large
|
||||
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
|
||||
? get_2d_grid_dims(
|
||||
outputs[0].shape(), outputs[0].strides(), work_per_thread)
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
|
@@ -5,7 +5,7 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
|
@@ -1,35 +1,15 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
bool donated = set_copy_output_data(in, out, ctype);
|
||||
if (donated && in.dtype() == out.dtype()) {
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
// copy, just use the buffer.
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu_inplace(in, out, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
@@ -104,6 +84,8 @@ void copy_gpu_inplace(
|
||||
"[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
work_per_thread = get_work_per_thread(in.dtype());
|
||||
}
|
||||
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
|
||||
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
|
||||
@@ -165,39 +147,23 @@ void copy_gpu_inplace(
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t nthreads = out.data_size();
|
||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims;
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
|
||||
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(out.data_size(), 2);
|
||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
}
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Strides& i_strides,
|
||||
int64_t i_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
|
||||
}
|
||||
|
||||
void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
@@ -214,14 +180,21 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
compute_encoder.set_input_array(val, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
int work_per_thread = get_work_per_thread(val.dtype());
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t nthreads = out.data_size();
|
||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims;
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
|
||||
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(out.data_size(), 2);
|
||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
}
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
@@ -1,20 +1,20 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
#include <sstream>
|
||||
|
||||
#include <sys/sysctl.h>
|
||||
|
||||
#define NS_PRIVATE_IMPLEMENTATION
|
||||
#define CA_PRIVATE_IMPLEMENTATION
|
||||
#define MTL_PRIVATE_IMPLEMENTATION
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
namespace {
|
||||
@@ -66,8 +66,8 @@ MTL::Library* try_load_bundle(
|
||||
if (bundle != nullptr) {
|
||||
std::string resource_path =
|
||||
std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
|
||||
lib_name + ".metallib" auto [lib, error] =
|
||||
load_library_from_path(device, resource_path.c_str());
|
||||
lib_name + ".metallib";
|
||||
auto [lib, error] = load_library_from_path(device, resource_path.c_str());
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
@@ -79,12 +79,18 @@ MTL::Library* try_load_bundle(
|
||||
// Firstly, search for the metallib in the same path as this binary
|
||||
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
|
||||
MTL::Device* device,
|
||||
const std::string& lib_name) {
|
||||
std::string lib_path = get_colocated_mtllib_path(lib_name);
|
||||
if (lib_path.size() != 0) {
|
||||
return load_library_from_path(device, lib_path.c_str());
|
||||
const std::string& relative_path) {
|
||||
std::string binary_dir = get_binary_directory();
|
||||
if (binary_dir.size() == 0) {
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
return {nullptr, nullptr};
|
||||
|
||||
auto path = fs::path(binary_dir) / relative_path;
|
||||
if (!path.has_extension()) {
|
||||
path.replace_extension(".metallib");
|
||||
}
|
||||
|
||||
return load_library_from_path(device, path.c_str());
|
||||
}
|
||||
|
||||
std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
||||
@@ -99,7 +105,7 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
||||
auto bundles = NS::Bundle::allBundles();
|
||||
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
|
||||
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
|
||||
library = try_load_bundle(device, bundle->resourceURL());
|
||||
library = try_load_bundle(device, bundle->resourceURL(), lib_name);
|
||||
if (library != nullptr) {
|
||||
return {library, nullptr};
|
||||
}
|
||||
@@ -109,33 +115,34 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
||||
}
|
||||
|
||||
MTL::Library* load_default_library(MTL::Device* device) {
|
||||
NS::Error *error1, *error2, *error3;
|
||||
NS::Error* error[4];
|
||||
MTL::Library* lib;
|
||||
// First try the colocated mlx.metallib
|
||||
std::tie(lib, error1) = load_colocated_library(device, "mlx");
|
||||
std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
std::tie(lib, error[1]) = load_colocated_library(device, "Resources/mlx");
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
// Then try default.metallib in a SwiftPM bundle if we have one
|
||||
std::tie(lib, error2) = load_swiftpm_library(device, "default");
|
||||
std::tie(lib, error[2]) = load_swiftpm_library(device, "default");
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
// Finally try default_mtllib_path
|
||||
std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path);
|
||||
std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path);
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the default metallib. ";
|
||||
if (error1 != nullptr) {
|
||||
msg << error1->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
if (error2 != nullptr) {
|
||||
msg << error2->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
if (error3 != nullptr) {
|
||||
msg << error3->localizedDescription()->utf8String() << " ";
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (error[i] != nullptr) {
|
||||
msg << error[i]->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
@@ -156,6 +163,7 @@ MTL::Library* load_library(
|
||||
<< error->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
|
||||
}
|
||||
|
||||
// We have been given a path so try to load from lib_path / lib_name.metallib
|
||||
@@ -168,6 +176,7 @@ MTL::Library* load_library(
|
||||
<< "> with error " << error->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
|
||||
}
|
||||
|
||||
// Try to load the colocated library
|
||||
@@ -188,8 +197,8 @@ MTL::Library* load_library(
|
||||
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the metallib " << lib_name << ".metallib. "
|
||||
<< "We attempted to load it from <" << get_colocated_mtllib_path(lib_name)
|
||||
<< ">";
|
||||
<< "We attempted to load it from <" << get_binary_directory() << "/"
|
||||
<< lib_name << ".metallib" << ">";
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
msg << " and from the Swift PM bundle.";
|
||||
#endif
|
||||
@@ -760,42 +769,4 @@ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
|
||||
NS::AutoreleasePool::alloc()->init(), dtor);
|
||||
}
|
||||
|
||||
void new_stream(Stream stream) {
|
||||
if (stream.device == mlx::core::Device::gpu) {
|
||||
device(stream.device).new_queue(stream.index);
|
||||
}
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
||||
device_info() {
|
||||
auto init_device_info = []()
|
||||
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto raw_device = device(default_device()).mtl_device();
|
||||
auto name = std::string(raw_device->name()->utf8String());
|
||||
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
||||
|
||||
size_t memsize = 0;
|
||||
size_t length = sizeof(memsize);
|
||||
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
|
||||
|
||||
size_t rsrc_limit = 0;
|
||||
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
|
||||
if (rsrc_limit == 0) {
|
||||
rsrc_limit = 499000;
|
||||
}
|
||||
|
||||
return {
|
||||
{"device_name", name},
|
||||
{"architecture", arch},
|
||||
{"max_buffer_length", raw_device->maxBufferLength()},
|
||||
{"max_recommended_working_set_size",
|
||||
raw_device->recommendedMaxWorkingSetSize()},
|
||||
{"memory_size", memsize},
|
||||
{"resource_limit", rsrc_limit}};
|
||||
};
|
||||
static auto device_info_ = init_device_info();
|
||||
return device_info_;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -21,18 +21,14 @@ namespace mlx::core::metal {
|
||||
|
||||
// Note, this function must be left inline in a header so that it is not
|
||||
// dynamically linked.
|
||||
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
||||
inline std::string get_binary_directory() {
|
||||
Dl_info info;
|
||||
std::string mtllib_path;
|
||||
std::string lib_ext = lib_name + ".metallib";
|
||||
|
||||
int success = dladdr((void*)get_colocated_mtllib_path, &info);
|
||||
std::string directory;
|
||||
int success = dladdr((void*)get_binary_directory, &info);
|
||||
if (success) {
|
||||
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
|
||||
mtllib_path = mtllib.c_str();
|
||||
directory = fs::path(info.dli_fname).remove_filename().c_str();
|
||||
}
|
||||
|
||||
return mtllib_path;
|
||||
return directory;
|
||||
}
|
||||
|
||||
using MTLFCList =
|
||||
@@ -270,4 +266,6 @@ class Device {
|
||||
|
||||
Device& device(mlx::core::Device);
|
||||
|
||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -4,7 +4,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
|
102
mlx/backend/metal/eval.cpp
Normal file
102
mlx/backend/metal/eval.cpp
Normal file
@@ -0,0 +1,102 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/backend/gpu/available.h"
|
||||
#include "mlx/backend/gpu/eval.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
void new_stream(Stream stream) {
|
||||
if (stream.device == mlx::core::Device::gpu) {
|
||||
metal::device(stream.device).new_queue(stream.index);
|
||||
}
|
||||
}
|
||||
|
||||
inline void check_error(MTL::CommandBuffer* cbuf) {
|
||||
if (cbuf->status() == MTL::CommandBufferStatusError) {
|
||||
std::ostringstream msg;
|
||||
msg << "[METAL] Command buffer execution failed: "
|
||||
<< cbuf->error()->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
void eval(array& arr) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto s = arr.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
|
||||
auto outputs = arr.outputs();
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
// to its inputs so they don't get donated
|
||||
std::vector<array> inputs;
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
|
||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.insert(in.data_shared_ptr());
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.insert(s.data_shared_ptr());
|
||||
}
|
||||
// Remove the output if it was donated to by an input
|
||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
|
||||
if (d.command_buffer_needs_commit(s.index)) {
|
||||
d.end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||
scheduler::notify_task_completion(s);
|
||||
check_error(cbuf);
|
||||
});
|
||||
d.commit_command_buffer(s.index);
|
||||
d.get_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||
check_error(cbuf);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void finalize(Stream s) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
d.end_encoding(s.index);
|
||||
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
|
||||
d.commit_command_buffer(s.index);
|
||||
d.get_command_buffer(s.index);
|
||||
}
|
||||
|
||||
void synchronize(Stream s) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
cb->retain();
|
||||
d.end_encoding(s.index);
|
||||
d.commit_command_buffer(s.index);
|
||||
cb->waitUntilCompleted();
|
||||
check_error(cb);
|
||||
cb->release();
|
||||
}
|
||||
|
||||
} // namespace mlx::core::gpu
|
@@ -2,7 +2,6 @@
|
||||
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
@@ -1,7 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@@ -139,7 +138,7 @@ void Fence::update(Stream stream, const array& x) {
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_bytes(nthreads, 1);
|
||||
compute_encoder.dispatch_threadgroups(group_dims, grid_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
|
||||
// Barrier on previous kernels
|
||||
compute_encoder.barrier();
|
||||
|
@@ -1,16 +1,18 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <cassert>
|
||||
#include <complex>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
|
||||
#include "mlx/3rdparty/pocketfft.h"
|
||||
#include "mlx/backend/common/transpose.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
#include "mlx/backend/metal/binary.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/slicing.h"
|
||||
#include "mlx/backend/metal/unary.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
@@ -27,7 +29,7 @@ using MTLFC = std::tuple<const void*, MTL::DataType, NS::UInteger>;
|
||||
// For strided reads/writes, coalesce at least this many complex64s
|
||||
#define MIN_COALESCE_WIDTH 4
|
||||
|
||||
inline const std::vector<int> supported_radices() {
|
||||
inline constexpr std::array<int, 9> supported_radices() {
|
||||
// Ordered by preference in decomposition.
|
||||
return {13, 11, 8, 7, 6, 5, 4, 3, 2};
|
||||
}
|
||||
@@ -49,6 +51,35 @@ std::vector<int> prime_factors(int n) {
|
||||
return factors;
|
||||
}
|
||||
|
||||
int next_fast_n(int n) {
|
||||
return next_power_of_2(n);
|
||||
}
|
||||
|
||||
std::vector<int> stockham_decompose(int n) {
|
||||
auto radices = supported_radices();
|
||||
std::vector<int> steps(radices.size(), 0);
|
||||
int orig_n = n;
|
||||
|
||||
for (int i = 0; i < radices.size(); i++) {
|
||||
int radix = radices[i];
|
||||
|
||||
// Manually tuned radices for powers of 2
|
||||
if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
|
||||
continue;
|
||||
}
|
||||
|
||||
while (n % radix == 0) {
|
||||
steps[i] += 1;
|
||||
n /= radix;
|
||||
if (n == 1) {
|
||||
return steps;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
struct FourStepParams {
|
||||
bool required = false;
|
||||
bool first_step = true;
|
||||
@@ -65,9 +96,10 @@ void fft_op(
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
bool inplace,
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
struct FFTPlan {
|
||||
struct OldFFTPlan {
|
||||
int n = 0;
|
||||
// Number of steps for each radix in the Stockham decomposition
|
||||
std::vector<int> stockham;
|
||||
@@ -82,9 +114,104 @@ struct FFTPlan {
|
||||
int n2 = 0;
|
||||
};
|
||||
|
||||
int next_fast_n(int n) {
|
||||
return next_power_of_2(n);
|
||||
}
|
||||
class FFTPlan {
|
||||
public:
|
||||
enum FFTType {
|
||||
UNSUPPORTED,
|
||||
NOOP,
|
||||
STOCKHAM,
|
||||
RADER,
|
||||
BLUESTEIN,
|
||||
MULTIUPLOAD_BLUESTEIN,
|
||||
SMALL_FOUR_STEP,
|
||||
LARGE_FOUR_STEP
|
||||
};
|
||||
|
||||
FFTPlan(int n) : n_(n) {
|
||||
// NOOP
|
||||
if (n == 1) {
|
||||
type_ = NOOP;
|
||||
}
|
||||
|
||||
// Too large for Stockham so do four step fft for powers of 2
|
||||
else if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
|
||||
if (n <= 1 << 20) {
|
||||
type_ = SMALL_FOUR_STEP;
|
||||
n2_ = n > 65536 ? 1024 : 64;
|
||||
n1_ = n / n2_;
|
||||
steps1_ = stockham_decompose(n1_);
|
||||
steps2_ = stockham_decompose(n2_);
|
||||
} else {
|
||||
type_ = LARGE_FOUR_STEP;
|
||||
}
|
||||
}
|
||||
|
||||
// Too large and not power of 2 so do multi-upload Bluestein fft
|
||||
else if (n > MAX_STOCKHAM_FFT_SIZE) {
|
||||
type_ = MULTIUPLOAD_BLUESTEIN;
|
||||
bluestein_n_ = next_fast_n(2 * n - 1);
|
||||
}
|
||||
|
||||
// Stockham fft
|
||||
else if (auto steps = stockham_decompose(n); steps.size() > 0) {
|
||||
type_ = STOCKHAM;
|
||||
steps_ = steps;
|
||||
}
|
||||
|
||||
// Add rader but for now simply fall back to bluestein when stockham not
|
||||
// posssible
|
||||
else if (n > MAX_BLUESTEIN_FFT_SIZE) {
|
||||
type_ = MULTIUPLOAD_BLUESTEIN;
|
||||
bluestein_n_ = next_fast_n(2 * n - 1);
|
||||
} else {
|
||||
type_ = BLUESTEIN;
|
||||
bluestein_n_ = next_fast_n(2 * n - 1);
|
||||
steps_ = stockham_decompose(bluestein_n_);
|
||||
}
|
||||
}
|
||||
|
||||
FFTType type() const {
|
||||
return type_;
|
||||
}
|
||||
|
||||
int size() const {
|
||||
return n_;
|
||||
}
|
||||
|
||||
const std::vector<int>& steps() const {
|
||||
return steps_;
|
||||
}
|
||||
|
||||
int first_size() const {
|
||||
return n1_;
|
||||
}
|
||||
|
||||
const std::vector<int>& first_steps() const {
|
||||
return steps1_;
|
||||
}
|
||||
|
||||
int second_size() const {
|
||||
return n2_;
|
||||
}
|
||||
|
||||
const std::vector<int>& second_steps() const {
|
||||
return steps2_;
|
||||
}
|
||||
|
||||
int bluestein_size() const {
|
||||
return bluestein_n_;
|
||||
}
|
||||
|
||||
private:
|
||||
int n_;
|
||||
FFTType type_;
|
||||
std::vector<int> steps_;
|
||||
int n1_;
|
||||
std::vector<int> steps1_;
|
||||
int n2_;
|
||||
std::vector<int> steps2_;
|
||||
int bluestein_n_;
|
||||
};
|
||||
|
||||
std::vector<int> plan_stockham_fft(int n) {
|
||||
auto radices = supported_radices();
|
||||
@@ -110,15 +237,12 @@ std::vector<int> plan_stockham_fft(int n) {
|
||||
throw std::runtime_error("Unplannable");
|
||||
}
|
||||
|
||||
FFTPlan plan_fft(int n) {
|
||||
OldFFTPlan plan_fft(int n) {
|
||||
auto radices = supported_radices();
|
||||
std::set<int> radices_set(radices.begin(), radices.end());
|
||||
|
||||
FFTPlan plan;
|
||||
OldFFTPlan plan;
|
||||
plan.n = n;
|
||||
plan.rader = std::vector<int>(radices.size(), 0);
|
||||
auto factors = prime_factors(n);
|
||||
int remaining_n = n;
|
||||
|
||||
// Four Step FFT when N is too large for shared mem.
|
||||
if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
|
||||
@@ -128,16 +252,20 @@ FFTPlan plan_fft(int n) {
|
||||
plan.n2 = n > 65536 ? 1024 : 64;
|
||||
plan.n1 = n / plan.n2;
|
||||
return plan;
|
||||
} else if (n > MAX_STOCKHAM_FFT_SIZE) {
|
||||
}
|
||||
|
||||
if (n > MAX_STOCKHAM_FFT_SIZE) {
|
||||
// Otherwise we use a multi-upload Bluestein's
|
||||
plan.four_step = true;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
return plan;
|
||||
}
|
||||
|
||||
int remaining_n = n;
|
||||
auto factors = prime_factors(n);
|
||||
for (int factor : factors) {
|
||||
// Make sure the factor is a supported radix
|
||||
if (radices_set.find(factor) == radices_set.end()) {
|
||||
if (std::find(radices.begin(), radices.end(), factor) == radices.end()) {
|
||||
// We only support a single Rader factor currently
|
||||
// TODO(alexbarron) investigate weirdness with large
|
||||
// Rader sizes -- possibly a compiler issue?
|
||||
@@ -154,7 +282,7 @@ FFTPlan plan_fft(int n) {
|
||||
for (int rf : rader_factors) {
|
||||
// We don't nest Rader's algorithm so if `factor - 1`
|
||||
// isn't Stockham decomposable we give up and do Bluestein's.
|
||||
if (radices_set.find(rf) == radices_set.end()) {
|
||||
if (std::find(radices.begin(), radices.end(), rf) == radices.end()) {
|
||||
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
plan.stockham = plan_stockham_fft(plan.bluestein_n);
|
||||
@@ -172,7 +300,7 @@ FFTPlan plan_fft(int n) {
|
||||
return plan;
|
||||
}
|
||||
|
||||
int compute_elems_per_thread(FFTPlan plan) {
|
||||
int compute_elems_per_thread(OldFFTPlan plan) {
|
||||
// Heuristics for selecting an efficient number
|
||||
// of threads to use for a particular mixed-radix FFT.
|
||||
auto n = plan.n;
|
||||
@@ -355,9 +483,11 @@ void multi_upload_bluestein_fft(
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
FFTPlan& plan,
|
||||
OldFFTPlan& plan,
|
||||
std::vector<array>& copies,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// TODO(alexbarron) Implement fused kernels for mutli upload bluestein's
|
||||
// algorithm
|
||||
int n = inverse ? out.shape(axis) : in.shape(axis);
|
||||
@@ -420,6 +550,7 @@ void multi_upload_bluestein_fft(
|
||||
/*real=*/false,
|
||||
FourStepParams(),
|
||||
/*inplace=*/false,
|
||||
d,
|
||||
s);
|
||||
copies.push_back(pad_temp1);
|
||||
|
||||
@@ -435,6 +566,7 @@ void multi_upload_bluestein_fft(
|
||||
/* real= */ false,
|
||||
FourStepParams(),
|
||||
/*inplace=*/true,
|
||||
d,
|
||||
s);
|
||||
|
||||
int offset = plan.bluestein_n - (2 * n - 1);
|
||||
@@ -480,7 +612,7 @@ void four_step_fft(
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
FFTPlan& plan,
|
||||
OldFFTPlan& plan,
|
||||
std::vector<array>& copies,
|
||||
const Stream& s,
|
||||
bool in_place) {
|
||||
@@ -493,7 +625,15 @@ void four_step_fft(
|
||||
auto temp_shape = (real && inverse) ? out.shape() : in.shape();
|
||||
array temp(temp_shape, complex64, nullptr, {});
|
||||
fft_op(
|
||||
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
||||
in,
|
||||
temp,
|
||||
axis,
|
||||
inverse,
|
||||
real,
|
||||
four_step_params,
|
||||
/*inplace=*/false,
|
||||
d,
|
||||
s);
|
||||
four_step_params.first_step = false;
|
||||
fft_op(
|
||||
temp,
|
||||
@@ -503,6 +643,7 @@ void four_step_fft(
|
||||
real,
|
||||
four_step_params,
|
||||
/*inplace=*/in_place,
|
||||
d,
|
||||
s);
|
||||
copies.push_back(temp);
|
||||
} else {
|
||||
@@ -518,9 +659,8 @@ void fft_op(
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
bool inplace,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
|
||||
if (n == 1) {
|
||||
out.copy_shared_buffer(in);
|
||||
@@ -755,57 +895,517 @@ void fft_op(
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void fft_op(
|
||||
inline int compute_elems_per_thread(int n, const std::vector<int>& steps) {
|
||||
auto radices = supported_radices();
|
||||
std::set<int> used_radices;
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
if (steps[i] > 0) {
|
||||
used_radices.insert(radices[i % radices.size()]);
|
||||
}
|
||||
}
|
||||
|
||||
// Manual tuning for 7/11/13
|
||||
if (used_radices.find(7) != used_radices.end() &&
|
||||
(used_radices.find(11) != used_radices.end() ||
|
||||
used_radices.find(13) != used_radices.end())) {
|
||||
return 7;
|
||||
} else if (
|
||||
used_radices.find(11) != used_radices.end() &&
|
||||
used_radices.find(13) != used_radices.end()) {
|
||||
return 11;
|
||||
}
|
||||
|
||||
// TODO(alexbarron) Some really weird stuff is going on
|
||||
// for certain `elems_per_thread` on large composite n.
|
||||
// Possibly a compiler issue?
|
||||
if (n == 3159)
|
||||
return 13;
|
||||
if (n == 3645)
|
||||
return 5;
|
||||
if (n == 3969)
|
||||
return 7;
|
||||
if (n == 1982)
|
||||
return 5;
|
||||
|
||||
if (used_radices.size() == 1) {
|
||||
return *(used_radices.begin());
|
||||
}
|
||||
if (used_radices.size() == 2 &&
|
||||
(used_radices.find(11) != used_radices.end() ||
|
||||
used_radices.find(13) != used_radices.end())) {
|
||||
return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2;
|
||||
}
|
||||
|
||||
// In all other cases use the second smallest radix.
|
||||
return *(++used_radices.begin());
|
||||
}
|
||||
|
||||
inline array ensure_fastest_moving_axis(
|
||||
const array& x,
|
||||
int axis,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// The axis is already with a stride of 1 so check that we have no overlaps
|
||||
// and broadcasting and avoid the copy.
|
||||
if (x.strides(axis) == 1) {
|
||||
// This is a fairly strict test perhaps consider relaxing it in the future.
|
||||
if (x.flags().row_contiguous || x.flags().col_contiguous) {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
// To make it the fastest moving axis simply transpose it, then copy it and
|
||||
// then transpose it back.
|
||||
|
||||
// Transpose it
|
||||
std::vector<int> axes(x.ndim(), 0);
|
||||
for (int ax = 0; ax < axes.size(); ax++) {
|
||||
axes[ax] = (ax < axis) ? ax : ax + 1;
|
||||
}
|
||||
axes.back() = axis;
|
||||
Shape xtshape;
|
||||
xtshape.reserve(axes.size());
|
||||
for (auto ax : axes) {
|
||||
xtshape.push_back(x.shape(ax));
|
||||
}
|
||||
array xt(xtshape, x.dtype(), nullptr, {});
|
||||
transpose(x, xt, axes);
|
||||
|
||||
// Copy it
|
||||
array xtc(xt.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(
|
||||
xt,
|
||||
xtc,
|
||||
xt.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
s);
|
||||
d.add_temporary(xtc, s.index);
|
||||
|
||||
// Transpose it
|
||||
for (int ax = 0; ax < axes.size(); ax++) {
|
||||
axes[ax] = (ax < axis) ? ax : ((ax == axis) ? axes.size() - 1 : ax - 1);
|
||||
}
|
||||
array y(x.shape(), x.dtype(), nullptr, {});
|
||||
transpose(xtc, y, axes);
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
inline void prepare_output_array(const array& in, array& out, int axis) {
|
||||
// Prepare the output array such that it matches the input in terms of
|
||||
// stride ordering. Namely we might have moved `axis` around in the `in`
|
||||
// array. We must do the same in `out`. The difference is that we don't have
|
||||
// to copy anything because `out` contains garbage at the moment.
|
||||
|
||||
if (in.flags().row_contiguous && out.flags().row_contiguous) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int> axes(out.ndim(), 0);
|
||||
for (int ax = 0; ax < axes.size(); ax++) {
|
||||
axes[ax] = (ax < axis) ? ax : ax + 1;
|
||||
}
|
||||
axes.back() = axis;
|
||||
as_transposed(out, axes);
|
||||
}
|
||||
|
||||
void fft_stockham_inplace(
|
||||
const FFTPlan& plan,
|
||||
const array& in_,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Prepare the input and output arrays such that `axis` has stride 1.
|
||||
// Possibly copy the input but never the output as it doesn't have anything
|
||||
// useful in it yet.
|
||||
array in = ensure_fastest_moving_axis(in_, axis, d, s);
|
||||
prepare_output_array(in, out, axis);
|
||||
|
||||
// Prepare the arguments for stockham fft
|
||||
int n = plan.size();
|
||||
bool power_of_2 = is_power_of_2(n);
|
||||
int total_batch_size =
|
||||
out.dtype() == float32 ? out.size() / n : in.size() / n;
|
||||
auto& steps = plan.steps();
|
||||
int elems_per_thread = compute_elems_per_thread(n, steps);
|
||||
int threads_per_fft = ceildiv(n, elems_per_thread);
|
||||
int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / n, 1);
|
||||
int tg_mem_size = next_power_of_2(tg_batch_size * n);
|
||||
int batch_size = ceildiv(total_batch_size, tg_batch_size);
|
||||
batch_size = real ? ceildiv(batch_size, 2) : batch_size; // 2 RFFTs at once
|
||||
std::vector<MTLFC> func_consts = {
|
||||
{&inverse, MTL::DataType::DataTypeBool, 0},
|
||||
{&power_of_2, MTL::DataType::DataTypeBool, 1},
|
||||
{&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
|
||||
}
|
||||
|
||||
// Get the kernel
|
||||
auto in_type = in.dtype() == float32 ? "float" : "float2";
|
||||
auto out_type = out.dtype() == float32 ? "float" : "float2";
|
||||
std::string hash_name;
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
hash_name.reserve(64);
|
||||
concatenate(kname, "fft_mem_", tg_mem_size, "_", in_type, "_", out_type);
|
||||
concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
|
||||
auto template_def =
|
||||
get_template_definition(kname, "fft", tg_mem_size, in_type, out_type);
|
||||
auto kernel = get_fft_kernel(d, kname, hash_name, func_consts, template_def);
|
||||
|
||||
// Launch it
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_bytes(n, 2);
|
||||
compute_encoder.set_bytes(total_batch_size, 3);
|
||||
|
||||
MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
|
||||
MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void fft_four_step_inplace(
|
||||
const FFTPlan& plan,
|
||||
const array& in_,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Prepare the input and output arrays such that `axis` has stride 1.
|
||||
// Possibly copy the input but never the output as it doesn't have anything
|
||||
// useful in it yet.
|
||||
array in = ensure_fastest_moving_axis(in_, axis, d, s);
|
||||
prepare_output_array(in, out, axis);
|
||||
|
||||
// Also prepare the intermediate array for the four-step fft which is
|
||||
// implemented with 2 kernel calls.
|
||||
array intermediate(
|
||||
(real && inverse) ? out.shape() : in.shape(), complex64, nullptr, {});
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
prepare_output_array(in, intermediate, axis);
|
||||
d.add_temporary(intermediate, s.index);
|
||||
|
||||
// Make the two calls
|
||||
for (int step = 0; step < 2; step++) {
|
||||
// Create the parameters
|
||||
int n1 = plan.first_size();
|
||||
int n2 = plan.second_size();
|
||||
int n = (step == 0) ? n1 : n2;
|
||||
bool power_of_2 = true;
|
||||
int total_batch_size =
|
||||
out.dtype() == float32 ? out.size() / n : in.size() / n;
|
||||
auto& steps = (step == 0) ? plan.first_steps() : plan.second_steps();
|
||||
int elems_per_thread = compute_elems_per_thread(n, steps);
|
||||
int threads_per_fft = ceildiv(n, elems_per_thread);
|
||||
int tg_batch_size =
|
||||
std::max(MIN_THREADGROUP_MEM_SIZE / n, MIN_COALESCE_WIDTH);
|
||||
int tg_mem_size = next_power_of_2(tg_batch_size * n);
|
||||
int batch_size = ceildiv(total_batch_size, tg_batch_size);
|
||||
std::vector<MTLFC> func_consts = {
|
||||
{&inverse, MTL::DataType::DataTypeBool, 0},
|
||||
{&power_of_2, MTL::DataType::DataTypeBool, 1},
|
||||
{&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
|
||||
}
|
||||
|
||||
// Get the kernel
|
||||
auto to_type = [](const array& x) {
|
||||
return x.dtype() == float32 ? "float" : "float2";
|
||||
};
|
||||
auto in_type = step == 0 ? to_type(in) : to_type(intermediate);
|
||||
auto out_type = step == 0 ? to_type(intermediate) : to_type(out);
|
||||
std::string hash_name;
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
hash_name.reserve(64);
|
||||
concatenate(
|
||||
kname,
|
||||
"four_step_mem_",
|
||||
tg_mem_size,
|
||||
"_",
|
||||
in_type,
|
||||
"_",
|
||||
out_type,
|
||||
"_",
|
||||
step,
|
||||
(real ? "_true" : "_false"));
|
||||
concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
|
||||
auto template_def = get_template_definition(
|
||||
kname, "four_step_fft", tg_mem_size, in_type, out_type, step, real);
|
||||
auto kernel =
|
||||
get_fft_kernel(d, kname, hash_name, func_consts, template_def);
|
||||
|
||||
// Launch it
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array((step == 0) ? in : intermediate, 0);
|
||||
compute_encoder.set_output_array((step == 0) ? intermediate : out, 1);
|
||||
compute_encoder.set_bytes(n1, 2);
|
||||
compute_encoder.set_bytes(n2, 3);
|
||||
compute_encoder.set_bytes(total_batch_size, 4);
|
||||
|
||||
MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
|
||||
MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
void fft_bluestein(
|
||||
const FFTPlan& plan,
|
||||
const array& in_,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Prepare the input and output arrays such that `axis` has stride 1.
|
||||
// Possibly copy the input but never the output as it doesn't have anything
|
||||
// useful in it yet.
|
||||
array in = ensure_fastest_moving_axis(in_, axis, d, s);
|
||||
prepare_output_array(in, out, axis);
|
||||
|
||||
// Prepare the arguments for bluestein fft
|
||||
int n = plan.bluestein_size();
|
||||
bool power_of_2 = true;
|
||||
int total_batch_size = out.dtype() == float32 ? out.size() / plan.size()
|
||||
: in.size() / plan.size();
|
||||
auto& steps = plan.steps();
|
||||
int elems_per_thread = compute_elems_per_thread(n, steps);
|
||||
int threads_per_fft = ceildiv(n, elems_per_thread);
|
||||
int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / n, 1);
|
||||
int tg_mem_size = next_power_of_2(tg_batch_size * n);
|
||||
int batch_size = ceildiv(total_batch_size, tg_batch_size);
|
||||
batch_size = real ? ceildiv(batch_size, 2) : batch_size; // 2 RFFTs at once
|
||||
std::vector<MTLFC> func_consts = {
|
||||
{&inverse, MTL::DataType::DataTypeBool, 0},
|
||||
{&power_of_2, MTL::DataType::DataTypeBool, 1},
|
||||
{&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
|
||||
}
|
||||
|
||||
// Get the kernel
|
||||
auto in_type = in.dtype() == float32 ? "float" : "float2";
|
||||
auto out_type = out.dtype() == float32 ? "float" : "float2";
|
||||
std::string hash_name;
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
hash_name.reserve(64);
|
||||
concatenate(
|
||||
kname, "bluestein_fft_mem_", tg_mem_size, "_", in_type, "_", out_type);
|
||||
concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
|
||||
auto template_def = get_template_definition(
|
||||
kname, "bluestein_fft", tg_mem_size, in_type, out_type);
|
||||
auto kernel = get_fft_kernel(d, kname, hash_name, func_consts, template_def);
|
||||
|
||||
// Get the bluestein constants
|
||||
auto [w_k, w_q] =
|
||||
compute_bluestein_constants(plan.size(), plan.bluestein_size());
|
||||
d.add_temporary(w_k, s.index);
|
||||
d.add_temporary(w_q, s.index);
|
||||
|
||||
// Launch it
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_input_array(w_q, 2);
|
||||
compute_encoder.set_input_array(w_k, 3);
|
||||
compute_encoder.set_bytes(plan.size(), 4);
|
||||
compute_encoder.set_bytes(n, 5);
|
||||
compute_encoder.set_bytes(total_batch_size, 6);
|
||||
|
||||
MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
|
||||
MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void fft_multi_upload_bluestein(
|
||||
const FFTPlan& plan,
|
||||
const array& in_,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Get Bluestein's constants using the CPU (this is done in the submission
|
||||
// thread which is pretty bad).
|
||||
auto [w_k, w_q] =
|
||||
compute_bluestein_constants(plan.size(), plan.bluestein_size());
|
||||
d.add_temporary(w_k, s.index);
|
||||
d.add_temporary(w_q, s.index);
|
||||
|
||||
// Prepare the input
|
||||
auto in_shape = inverse ? out.shape() : in_.shape();
|
||||
array in(std::move(in_shape), complex64, nullptr, {});
|
||||
if (real && !inverse) {
|
||||
copy_gpu(
|
||||
in_,
|
||||
in,
|
||||
in_.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
s);
|
||||
d.add_temporary(in, s.index);
|
||||
} else if (real && inverse) {
|
||||
int back_offset = plan.size() % 2 == 0 ? 2 : 1;
|
||||
auto slice_shape = in.shape();
|
||||
slice_shape[axis] -= back_offset;
|
||||
array slice_temp(slice_shape, complex64, nullptr, {});
|
||||
array conj_temp(in.shape(), complex64, nullptr, {});
|
||||
Shape rstarts(in.ndim(), 0);
|
||||
Shape rstrides(in.ndim(), 1);
|
||||
rstarts[axis] = in.shape(axis) - back_offset;
|
||||
rstrides[axis] = -1;
|
||||
unary_op_gpu({in_}, conj_temp, "Conjugate", s);
|
||||
slice_gpu(in_, slice_temp, rstarts, rstrides, s);
|
||||
concatenate_gpu({conj_temp, slice_temp}, in, (int)axis, s);
|
||||
d.add_temporary(conj_temp, s.index);
|
||||
} else if (inverse) {
|
||||
unary_op_gpu({in_}, in, "Conjugate", s);
|
||||
d.add_temporary(in, s.index);
|
||||
} else {
|
||||
in.copy_shared_buffer(in_);
|
||||
}
|
||||
|
||||
// Multiply with
|
||||
Strides b_strides(in.ndim(), 0);
|
||||
b_strides[axis] = 1;
|
||||
array w_k_broadcast(in.shape(), complex64, nullptr, {});
|
||||
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
|
||||
array x(in.shape(), complex64, nullptr, {});
|
||||
binary_op_gpu({in, w_k_broadcast}, x, "Multiply", s);
|
||||
d.add_temporary(x, s.index);
|
||||
|
||||
// Pad
|
||||
auto padded_shape = out.shape();
|
||||
padded_shape[axis] = plan.bluestein_size();
|
||||
array padded_x(padded_shape, complex64, nullptr, {});
|
||||
auto zero = array(complex64_t{0.0f, 0.0f});
|
||||
pad_gpu(x, zero, padded_x, {(int)axis}, {0}, s);
|
||||
d.add_temporary(zero, s.index);
|
||||
d.add_temporary(padded_x, s.index);
|
||||
|
||||
// First fft
|
||||
}
|
||||
|
||||
void fft_op_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
bool inplace,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s);
|
||||
// Get the FFT size and plan it
|
||||
auto plan =
|
||||
FFTPlan(out.dtype() == float32 ? out.shape(axis) : in.shape(axis));
|
||||
|
||||
switch (plan.type()) {
|
||||
case FFTPlan::NOOP:
|
||||
std::cout << "--------------> 1-size FFT <-----------------" << std::endl;
|
||||
break;
|
||||
case FFTPlan::STOCKHAM:
|
||||
return fft_stockham_inplace(plan, in, out, axis, inverse, real, d, s);
|
||||
case FFTPlan::SMALL_FOUR_STEP:
|
||||
return fft_four_step_inplace(plan, in, out, axis, inverse, real, d, s);
|
||||
case FFTPlan::BLUESTEIN:
|
||||
return fft_bluestein(plan, in, out, axis, inverse, real, d, s);
|
||||
case FFTPlan::UNSUPPORTED: {
|
||||
std::string msg;
|
||||
concatenate(msg, "FFT of size ", plan.size(), " not supported");
|
||||
throw std::runtime_error(msg);
|
||||
}
|
||||
default:
|
||||
std::cout << "----- NYI ----" << std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void nd_fft_op(
|
||||
void nd_fft_op_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<size_t>& axes,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Perform ND FFT on GPU as a series of 1D FFTs
|
||||
auto temp_shape = inverse ? in.shape() : out.shape();
|
||||
array temp1(temp_shape, complex64, nullptr, {});
|
||||
array temp2(temp_shape, complex64, nullptr, {});
|
||||
std::vector<array> temp_arrs = {temp1, temp2};
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int reverse_index = axes.size() - i - 1;
|
||||
// For 5D and above, we don't want to reallocate our two temporary arrays
|
||||
bool inplace = reverse_index >= 3 && i != 0;
|
||||
// Opposite order for fft vs ifft
|
||||
int index = inverse ? reverse_index : i;
|
||||
size_t axis = axes[index];
|
||||
// Mirror np.fft.(i)rfftn and perform a real transform
|
||||
// only on the final axis.
|
||||
bool step_real = (real && index == axes.size() - 1);
|
||||
auto step_shape = inverse ? out.shape(axis) : in.shape(axis);
|
||||
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
|
||||
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
|
||||
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
|
||||
}
|
||||
// We are going to make and possibly reuse some intermediate arrays that will
|
||||
// hold the intermediate fft results.
|
||||
auto shape = inverse ? in.shape() : out.shape();
|
||||
std::vector<array> intermediates;
|
||||
intermediates.reserve(2);
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
d.add_temporaries(std::move(temp_arrs), s.index);
|
||||
// Utility to return either in or one of the intermediates.
|
||||
auto get_input_array = [&](int step) -> const array& {
|
||||
// The first step so use the input array
|
||||
if (step == 0) {
|
||||
return in;
|
||||
}
|
||||
|
||||
return intermediates[(step - 1) % 2];
|
||||
};
|
||||
|
||||
// Utility to return either out or one of the intermediates. It also informs
|
||||
// us if we should allocate memory for that output or there is already some
|
||||
// allocated.
|
||||
auto get_output_array = [&](int step) -> array& {
|
||||
// It is the final step so return the output array
|
||||
if (step == axes.size() - 1) {
|
||||
return out;
|
||||
}
|
||||
|
||||
// We already have made an array that we can use so go ahead and use it and
|
||||
// don't reallocate the memory.
|
||||
if (step % 2 < intermediates.size()) {
|
||||
return intermediates[step % 2];
|
||||
}
|
||||
|
||||
array x(shape, complex64, nullptr, {});
|
||||
x.set_data(allocator::malloc(x.nbytes()));
|
||||
intermediates.emplace_back(std::move(x));
|
||||
d.add_temporary(intermediates.back(), s.index);
|
||||
|
||||
return intermediates.back();
|
||||
};
|
||||
|
||||
// Perform ND FFT on GPU as a series of 1D FFTs
|
||||
for (int step = 0; step < axes.size(); step++) {
|
||||
auto x = get_input_array(step);
|
||||
auto y = get_output_array(step);
|
||||
auto step_axis = axes[inverse ? step : axes.size() - step - 1];
|
||||
auto step_real = real && (inverse ? step == axes.size() - 1 : step == 0);
|
||||
fft_op_inplace(x, y, step_axis, inverse, step_real, d, s);
|
||||
}
|
||||
}
|
||||
|
||||
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto& in = inputs[0];
|
||||
|
||||
// The FFT ops above have the *_inplace suffix. This means that the memory
|
||||
// needs to be already allocated in the output array. Similar to
|
||||
// copy_gpu_inplace and so on.
|
||||
//
|
||||
// Even though we allocate the memory, we do not necessarily want the
|
||||
// contiguous strides so the *_inplace ops may change the strides and flags
|
||||
// of the array but will not reallocate the memory.
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
if (axes_.size() > 1) {
|
||||
nd_fft_op(in, out, axes_, inverse_, real_, s);
|
||||
nd_fft_op_inplace(in, out, axes_, inverse_, real_, d, s);
|
||||
} else {
|
||||
fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s);
|
||||
fft_op_inplace(in, out, axes_[0], inverse_, real_, d, s);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,11 +1,9 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/hadamard.h"
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
@@ -15,7 +13,6 @@
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
|
||||
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
|
||||
|
||||
std::string gen_hadamard_codelet(int m) {
|
||||
// Generate a O(m^2) hadamard codelet for a given M
|
||||
@@ -60,121 +57,142 @@ std::string gen_hadamard_codelet(int m) {
|
||||
return source.str();
|
||||
}
|
||||
|
||||
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
void hadamard_mn_contiguous(
|
||||
const array& x,
|
||||
array& y,
|
||||
int m,
|
||||
int n1,
|
||||
int n2,
|
||||
float scale,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
int n = n1 * n2;
|
||||
int read_width_n1 = n1 == 2 ? 2 : 4;
|
||||
int read_width_n2 = n2 == 2 ? 2 : 4;
|
||||
int read_width_m = (n == 2 || m == 28) ? 2 : 4;
|
||||
int max_radix_1 = std::min(n1, 16);
|
||||
int max_radix_2 = std::min(n2, 16);
|
||||
float scale_n1 = 1.0;
|
||||
float scale_n2 = (m == 1) ? scale : 1.0;
|
||||
float scale_m = scale;
|
||||
|
||||
auto& in = inputs[0];
|
||||
// n2 is a row contiguous power of 2 hadamard transform
|
||||
MTL::Size group_dims_n2(n2 / max_radix_2, 1, 1);
|
||||
MTL::Size grid_dims_n2(n2 / max_radix_2, x.size() / n2, 1);
|
||||
|
||||
std::vector<array> copies;
|
||||
// Only support the last axis for now
|
||||
int axis = in.ndim() - 1;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
// TODO(alexbarron) pass strides to kernel to relax this constraint
|
||||
bool no_copy = x.flags().row_contiguous;
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
// n1 is a strided power of 2 hadamard transform with stride n2
|
||||
MTL::Size group_dims_n1(n1 / max_radix_1, 1, 1);
|
||||
MTL::Size grid_dims_n1(n1 / max_radix_1, x.size() / n, n2);
|
||||
|
||||
// m is a strided hadamard transform with stride n = n1 * n2
|
||||
MTL::Size group_dims_m(
|
||||
std::min(n / read_width_m, MAX_HADAMARD_THREADS_PER_GROUP), 1, 1);
|
||||
MTL::Size grid_dims_m(
|
||||
group_dims_m.width, x.size() / m / read_width_m / group_dims_m.width, 1);
|
||||
|
||||
// Make the kernel
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(kname, "hadamard_", n * m, "_", type_to_name(x));
|
||||
auto lib = d.get_library(kname, [&]() {
|
||||
std::string kernel;
|
||||
concatenate(
|
||||
kernel,
|
||||
metal::utils(),
|
||||
gen_hadamard_codelet(m),
|
||||
metal::hadamard(),
|
||||
get_template_definition(
|
||||
"n2" + kname,
|
||||
"hadamard_n",
|
||||
get_type_string(x.dtype()),
|
||||
n2,
|
||||
max_radix_2,
|
||||
read_width_n2));
|
||||
if (n1 > 1) {
|
||||
kernel += get_template_definition(
|
||||
"n1" + kname,
|
||||
"hadamard_n",
|
||||
get_type_string(x.dtype()),
|
||||
n1,
|
||||
max_radix_1,
|
||||
read_width_n1,
|
||||
n2);
|
||||
}
|
||||
};
|
||||
const array& in_contiguous = check_input(in);
|
||||
|
||||
if (in_contiguous.is_donatable()) {
|
||||
out.copy_shared_buffer(in_contiguous);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
|
||||
int n, m;
|
||||
std::tie(n, m) = decompose_hadamard(in.shape(axis));
|
||||
|
||||
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
|
||||
}
|
||||
|
||||
int max_radix = std::min(n, 16);
|
||||
// Use read_width 2 for m = 28 to avoid register spilling
|
||||
int read_width = (n == 2 || m == 28) ? 2 : 4;
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "hadamard_" << n * m << "_" << type_to_name(out);
|
||||
auto kernel_name = kname.str();
|
||||
auto& d = metal::device(s.device);
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
auto codelet = gen_hadamard_codelet(m);
|
||||
kernel_source << metal::utils() << codelet << metal::hadamard();
|
||||
kernel_source << get_template_definition(
|
||||
"n" + kernel_name,
|
||||
"hadamard_n",
|
||||
get_type_string(in.dtype()),
|
||||
n,
|
||||
max_radix,
|
||||
read_width);
|
||||
kernel_source << get_template_definition(
|
||||
"m" + kernel_name,
|
||||
"hadamard_m",
|
||||
get_type_string(in.dtype()),
|
||||
n,
|
||||
m,
|
||||
read_width);
|
||||
return kernel_source.str();
|
||||
if (m > 1) {
|
||||
kernel += get_template_definition(
|
||||
"m" + kname,
|
||||
"hadamard_m",
|
||||
get_type_string(x.dtype()),
|
||||
n,
|
||||
m,
|
||||
read_width_m);
|
||||
}
|
||||
return kernel;
|
||||
});
|
||||
|
||||
int batch_size = in.size() / n;
|
||||
int threads_per = n / max_radix;
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
auto launch_hadamard = [&](const array& in,
|
||||
array& out,
|
||||
const std::string& kernel_name,
|
||||
float scale) {
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
|
||||
// Launch the strided transform for n1
|
||||
if (n1 > 1) {
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel("n1" + kname, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_bytes(scale, 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
|
||||
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
};
|
||||
|
||||
if (m > 1) {
|
||||
// When m is greater than 1, we decompose the
|
||||
// computation into two uploads to the GPU:
|
||||
//
|
||||
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
|
||||
//
|
||||
// y = h48 @ x
|
||||
//
|
||||
// Upload 1:
|
||||
// tmp = a.reshape(12, 4) @ h4
|
||||
//
|
||||
// Upload 2:
|
||||
// y = h12 @ tmp
|
||||
array temp(in.shape(), in.dtype(), nullptr, {});
|
||||
temp.set_data(allocator::malloc(temp.nbytes()));
|
||||
copies.push_back(temp);
|
||||
|
||||
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
|
||||
|
||||
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
|
||||
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
|
||||
batch_size = in.size() / m / read_width / threads_per;
|
||||
launch_hadamard(temp, out, "m" + kernel_name, scale_);
|
||||
} else {
|
||||
launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_output_array(y, 1);
|
||||
compute_encoder.set_bytes(scale_n1, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims_n1, group_dims_n1);
|
||||
}
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
// Launch the transform for n2
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel("n2" + kname, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(n1 > 1 ? y : x, 0);
|
||||
compute_encoder.set_output_array(y, 1);
|
||||
compute_encoder.set_bytes(scale_n2, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims_n2, group_dims_n2);
|
||||
|
||||
// Launch the strided transform for m
|
||||
if (m > 1) {
|
||||
auto kernel = d.get_kernel("m" + kname, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(y, 0);
|
||||
compute_encoder.set_output_array(y, 1);
|
||||
compute_encoder.set_bytes(scale_m, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims_m, group_dims_m);
|
||||
}
|
||||
}
|
||||
|
||||
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Split the hadamard transform so that all of them work on vectors smaller
|
||||
// than 8192 elements.
|
||||
//
|
||||
// We decompose it in the following way:
|
||||
//
|
||||
// n = m * n1 * n2 = m * 2^k1 * 2^k2
|
||||
//
|
||||
// where m is in (1, 12, 20, 28) and n1 and n2 <= 8192
|
||||
auto [n, m] = decompose_hadamard(in.shape().back());
|
||||
int n1 = 1, n2 = n;
|
||||
if (n > 8192) {
|
||||
for (n2 = 2; n2 * n2 < n; n2 *= 2) {
|
||||
}
|
||||
n1 = n / n2;
|
||||
}
|
||||
|
||||
if (in.flags().row_contiguous) {
|
||||
if (in.is_donatable()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
hadamard_mn_contiguous(in, out, m, n1, n2, scale_, d, s);
|
||||
} else {
|
||||
copy_gpu(in, out, CopyType::General, s);
|
||||
hadamard_mn_contiguous(out, out, m, n1, n2, scale_, d, s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -2,7 +2,7 @@
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/indexing.h"
|
||||
|
@@ -9,64 +9,85 @@ template <typename T, typename U, typename Op>
|
||||
c[index] = Op()(a[0], b[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_sv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[0], b[index]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[0], b[index + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vs(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[index], b[0]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[index], b[index]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[index + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_sv2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
c[offset] = Op()(a[0], b[offset]);
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[0], b[offset + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vs2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
c[offset] = Op()(a[offset], b[0]);
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vv2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
c[offset] = Op()(a[offset], b[offset]);
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[offset + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
|
@@ -71,6 +71,7 @@ instantiate_binary_types_bool(Less)
|
||||
instantiate_binary_types_bool(LessEqual)
|
||||
instantiate_binary_types_bool(NotEqual)
|
||||
instantiate_binary_float(LogAddExp)
|
||||
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
|
||||
instantiate_binary_types(Maximum)
|
||||
instantiate_binary_types(Minimum)
|
||||
instantiate_binary_types(Multiply)
|
||||
|
@@ -130,6 +130,24 @@ struct LogAddExp {
|
||||
? maxval
|
||||
: (maxval + log1p(metal::exp(minval - maxval)));
|
||||
};
|
||||
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) ||
|
||||
metal::isnan(y.imag)) {
|
||||
return metal::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
constexpr float inf = metal::numeric_limits<float>::infinity();
|
||||
complex64_t maxval = x > y ? x : y;
|
||||
complex64_t minval = x < y ? x : y;
|
||||
if (minval.real == -inf || maxval.real == inf)
|
||||
return maxval;
|
||||
float m = metal::exp(minval.real - maxval.real);
|
||||
complex64_t dexp{
|
||||
m * metal::cos(minval.imag - maxval.imag),
|
||||
m * metal::sin(minval.imag - maxval.imag),
|
||||
};
|
||||
return maxval + log1p(dexp);
|
||||
}
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
|
@@ -12,82 +12,103 @@ template <typename T, typename U, typename Op>
|
||||
d[index] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_sv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto out = Op()(a[0], b[index]);
|
||||
c[index] = out[0];
|
||||
d[index] = out[1];
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
auto out = Op()(a[0], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vs(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto out = Op()(a[index], b[0]);
|
||||
c[index] = out[0];
|
||||
d[index] = out[1];
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
auto out = Op()(a[index + i], b[0]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto out = Op()(a[index], b[index]);
|
||||
c[index] = out[0];
|
||||
d[index] = out[1];
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
auto out = Op()(a[index + i], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_sv2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
auto out = Op()(a[0], b[offset]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
auto out = Op()(a[0], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vs2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
auto out = Op()(a[offset], b[0]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
auto out = Op()(a[offset + i], b[0]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void binary_vv2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
auto out = Op()(a[offset], b[offset]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
auto out = Op()(a[offset + i], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, typename IdxT = int64_t>
|
||||
|
@@ -1,39 +1,53 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void copy_s(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
dst[index] = static_cast<U>(src[0]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
dst[index + i] = static_cast<U>(src[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void copy_v(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
dst[index] = static_cast<U>(src[index]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
dst[index + i] = static_cast<U>(src[index + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void copy_s2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
dst[offset] = static_cast<U>(src[0]);
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
dst[offset + i] = static_cast<U>(src[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void copy_v2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
dst[offset] = static_cast<U>(src[offset]);
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
dst[offset + i] = static_cast<U>(src[offset + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
|
@@ -10,7 +10,7 @@ For many sizes, GPU FFTs are memory bandwidth bound so
|
||||
read/write performance is important.
|
||||
|
||||
Where possible, we read 128 bits sequentially in each thread,
|
||||
coalesced with accesses from adajcent threads for optimal performance.
|
||||
coalesced with accesses from adjacent threads for optimal performance.
|
||||
|
||||
We implement specialized reading/writing for:
|
||||
- FFT
|
||||
|
@@ -26,7 +26,7 @@ METAL_FUNC void radix_func(thread float* x) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N, int max_radix, int read_width>
|
||||
template <typename T, int N, int max_radix, int read_width, int stride = 1>
|
||||
[[kernel]] void hadamard_n(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
@@ -46,18 +46,25 @@ template <typename T, int N, int max_radix, int read_width>
|
||||
constexpr short logFinal = logN % logR;
|
||||
constexpr short final_radix = 1 << (logFinal);
|
||||
|
||||
int batch_idx = elem.x * N;
|
||||
short i = elem.y;
|
||||
int batch_idx = elem.y * N * stride + elem.z;
|
||||
short i = elem.x;
|
||||
|
||||
threadgroup T buf[N];
|
||||
|
||||
// Read values from device
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < max_radix / read_width; j++) {
|
||||
short index = j * read_width * num_threads + i * read_width;
|
||||
if (stride == 1) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < read_width; r++) {
|
||||
buf[index + r] = in[batch_idx + index + r];
|
||||
for (short j = 0; j < max_radix / read_width; j++) {
|
||||
short index = j * read_width * num_threads + i * read_width;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < read_width; r++) {
|
||||
buf[index + r] = in[batch_idx + index + r];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < max_radix; j++) {
|
||||
buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,12 +120,20 @@ template <typename T, int N, int max_radix, int read_width>
|
||||
}
|
||||
|
||||
// Write values to device
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < max_radix / read_width; j++) {
|
||||
short index = j * read_width * num_threads + i * read_width;
|
||||
if (stride == 1) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < read_width; r++) {
|
||||
out[batch_idx + index + r] = T(buf[index + r] * scale);
|
||||
for (short j = 0; j < max_radix / read_width; j++) {
|
||||
short index = j * read_width * num_threads + i * read_width;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < read_width; r++) {
|
||||
out[batch_idx + index + r] = T(buf[index + r] * scale);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < max_radix; j++) {
|
||||
out[batch_idx + (j * num_threads + i) * stride] =
|
||||
buf[j * num_threads + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1008,11 +1008,11 @@ METAL_FUNC void qmm_t_impl(
|
||||
|
||||
auto wl = (const device uint8_t*)w;
|
||||
|
||||
x += y_row * K;
|
||||
x += y_row * static_cast<int64_t>(K);
|
||||
wl += y_col * K_w;
|
||||
scales += y_col * K_g;
|
||||
biases += y_col * K_g;
|
||||
y += y_row * N + y_col;
|
||||
y += y_row * static_cast<int64_t>(N) + y_col;
|
||||
|
||||
// Make the x loader and mma operation
|
||||
const short num_els = min(BM, M - y_row);
|
||||
@@ -1132,11 +1132,11 @@ METAL_FUNC void qmm_n_impl(
|
||||
// Set the block
|
||||
const int y_row = tid.y * BM;
|
||||
const int y_col = tid.x * BN;
|
||||
x += y_row * K;
|
||||
x += y_row * static_cast<int64_t>(K);
|
||||
wl += y_col * bytes_per_pack / pack_factor;
|
||||
scales += y_col / group_size;
|
||||
biases += y_col / group_size;
|
||||
y += y_row * N + y_col;
|
||||
y += y_row * static_cast<int64_t>(N) + y_col;
|
||||
|
||||
// Make the x loader and mma operation
|
||||
const short num_els = min(BM, M - y_row);
|
||||
|
@@ -104,4 +104,5 @@ instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMi
|
||||
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
|
||||
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
|
||||
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
|
||||
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on
|
||||
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4)
|
||||
instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on
|
||||
|
@@ -56,9 +56,9 @@ template <typename T, int D, int V = D>
|
||||
const int head_idx = tid.x;
|
||||
const int q_seq_idx = tid.y;
|
||||
const int kv_head_idx = head_idx / gqa_factor;
|
||||
const int o_offset = tpg.x * q_seq_idx + head_idx;
|
||||
const int o_offset = head_idx * tpg.y + q_seq_idx;
|
||||
const int q_offset =
|
||||
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
|
||||
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
|
||||
queries += q_offset * D + simd_lid * qk_per_thread;
|
||||
keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
|
||||
simd_lid * qk_per_thread;
|
||||
@@ -213,9 +213,9 @@ template <typename T, int D, int V = D>
|
||||
const int block_idx = tid.z;
|
||||
const int head_idx = tid.x;
|
||||
const int q_seq_idx = tid.y;
|
||||
const int o_offset = tpg.x * q_seq_idx + head_idx;
|
||||
const int o_offset = head_idx * tpg.y + q_seq_idx;
|
||||
const int q_offset =
|
||||
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
|
||||
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
|
||||
const int kv_head_idx = head_idx / gqa_factor;
|
||||
|
||||
queries += q_offset * D + simd_lid * qk_per_thread;
|
||||
@@ -358,8 +358,8 @@ template <typename T, int D>
|
||||
// Adjust positions
|
||||
const int head_idx = tid.x;
|
||||
const int q_seq_idx = tid.y;
|
||||
const int n_heads = tpg.x;
|
||||
const int q_offset = n_heads * q_seq_idx + head_idx;
|
||||
const int q_offset = head_idx * tpg.y + q_seq_idx;
|
||||
;
|
||||
partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
|
||||
sums += q_offset * blocks;
|
||||
maxs += q_offset * blocks;
|
||||
|
@@ -95,7 +95,7 @@ template <
|
||||
|
||||
Q += tidl.z * params->Q_strides[0] + // Batch
|
||||
tidl.y * params->Q_strides[1] + // Head
|
||||
tidl.x * BQ * params->Q_strides[2]; // Seqeunce
|
||||
tidl.x * BQ * params->Q_strides[2]; // Sequence
|
||||
|
||||
ulong kv_head_idx = int(tid.y) / params->gqa_factor;
|
||||
K += tidl.z * params->K_strides[0] + // Batch
|
||||
@@ -106,7 +106,7 @@ template <
|
||||
|
||||
O += tidl.z * params->O_strides[0] + // Batch
|
||||
tidl.y * params->O_strides[1] + // Head
|
||||
tidl.x * BQ * params->O_strides[2]; // Seqeunce
|
||||
tidl.x * BQ * params->O_strides[2]; // Sequence
|
||||
|
||||
if (has_mask) {
|
||||
mask += tidl.z * mask_params->M_strides[0] + // Batch
|
||||
|
@@ -113,7 +113,7 @@ struct BlockLoader {
|
||||
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
||||
}
|
||||
|
||||
// Zero out uneeded values
|
||||
// Zero out unneeded values
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
||||
@@ -240,7 +240,7 @@ struct BlockLoaderT {
|
||||
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
||||
}
|
||||
|
||||
// Zero out uneeded values
|
||||
// Zero out unneeded values
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
||||
|
@@ -141,7 +141,7 @@ implicit_gemm_conv_2d_general(
|
||||
|
||||
// Store results to device memory
|
||||
{
|
||||
// Adjust for simdgroup and thread locatio
|
||||
// Adjust for simdgroup and thread location
|
||||
int offset_m = c_row + mma_op.sm;
|
||||
int offset_n = c_col + mma_op.sn;
|
||||
C += offset_n;
|
||||
|
@@ -113,7 +113,7 @@ struct BlockLoader {
|
||||
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
||||
}
|
||||
|
||||
// Zero out uneeded values
|
||||
// Zero out unneeded values
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
||||
|
@@ -1,25 +1,32 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
template <typename T, typename Op>
|
||||
template <typename T, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void ternary_v(
|
||||
device const bool* a,
|
||||
device const T* b,
|
||||
device const T* c,
|
||||
device T* d,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
d[index] = Op()(a[index], b[index], c[index]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
template <typename T, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void ternary_v2(
|
||||
device const bool* a,
|
||||
device const T* b,
|
||||
device const T* c,
|
||||
device T* d,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
d[offset] = Op()(a[offset], b[offset], c[offset]);
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op, typename IdxT = int64_t>
|
||||
|
@@ -1,21 +1,28 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void unary_v(
|
||||
device const T* in,
|
||||
device U* out,
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
out[index] = Op()(in[index]);
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
out[index + i] = Op()(in[index + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
[[kernel]] void unary_v2(
|
||||
device const T* in,
|
||||
device U* out,
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = index.x + grid_dim.x * int64_t(index.y);
|
||||
out[offset] = Op()(in[offset]);
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
out[offset + i] = Op()(in[offset + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
|
@@ -77,6 +77,7 @@ instantiate_unary_all_same(Cos, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Cosh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Exp, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log1p, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log2, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log10, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Negative, complex64, complex64_t)
|
||||
|
@@ -15,6 +15,14 @@
|
||||
|
||||
typedef half float16_t;
|
||||
|
||||
// Work per thread values for different types. The values here are expected to
|
||||
// match get_work_per_thread in mlx/backend/metal/utils.h
|
||||
template <typename U>
|
||||
struct WorkPerThread {
|
||||
static_assert(sizeof(U) <= 8, "Type too large");
|
||||
static constexpr int constant n = 8 / sizeof(U);
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Type limits utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -328,6 +336,23 @@ inline bfloat16_t log1p(bfloat16_t x) {
|
||||
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
||||
}
|
||||
|
||||
inline complex64_t log1p(complex64_t in) {
|
||||
float x = in.real;
|
||||
float y = in.imag;
|
||||
float zabs = metal::precise::sqrt(x * x + y * y);
|
||||
float theta = metal::atan2(y, x + 1);
|
||||
if (zabs < 0.5f) {
|
||||
float r = x * (2 + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return {x, theta};
|
||||
}
|
||||
return {0.5f * log1p(r), theta};
|
||||
} else {
|
||||
auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y);
|
||||
return {metal::log(z0), theta};
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// SIMD shuffle ops
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
@@ -1,7 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <algorithm>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
@@ -7,7 +7,7 @@
|
||||
|
||||
#include "mlx/backend/common/broadcasting.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
|
@@ -1,11 +1,11 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <memory>
|
||||
|
||||
#include <sys/sysctl.h>
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
@@ -13,85 +13,6 @@ bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
inline void check_error(MTL::CommandBuffer* cbuf) {
|
||||
if (cbuf->status() == MTL::CommandBufferStatusError) {
|
||||
std::ostringstream msg;
|
||||
msg << "[METAL] Command buffer execution failed: "
|
||||
<< cbuf->error()->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
void eval(array& arr) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto s = arr.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
|
||||
auto outputs = arr.outputs();
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
// to its inputs so they don't get donated
|
||||
std::vector<array> inputs;
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
|
||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.insert(in.data_shared_ptr());
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.insert(s.data_shared_ptr());
|
||||
}
|
||||
// Remove the output if it was donated to by an input
|
||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
|
||||
if (d.command_buffer_needs_commit(s.index)) {
|
||||
d.end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||
scheduler::notify_task_completion(s);
|
||||
check_error(cbuf);
|
||||
});
|
||||
d.commit_command_buffer(s.index);
|
||||
d.get_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||
check_error(cbuf);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void finalize(Stream s) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
d.end_encoding(s.index);
|
||||
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
|
||||
d.commit_command_buffer(s.index);
|
||||
d.get_command_buffer(s.index);
|
||||
}
|
||||
|
||||
void synchronize(Stream s) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
cb->retain();
|
||||
d.end_encoding(s.index);
|
||||
d.commit_command_buffer(s.index);
|
||||
cb->waitUntilCompleted();
|
||||
check_error(cb);
|
||||
cb->release();
|
||||
}
|
||||
|
||||
void start_capture(std::string path, id object) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
@@ -128,4 +49,36 @@ void stop_capture() {
|
||||
manager->stopCapture();
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
||||
device_info() {
|
||||
auto init_device_info = []()
|
||||
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto raw_device = device(default_device()).mtl_device();
|
||||
auto name = std::string(raw_device->name()->utf8String());
|
||||
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
||||
|
||||
size_t memsize = 0;
|
||||
size_t length = sizeof(memsize);
|
||||
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
|
||||
|
||||
size_t rsrc_limit = 0;
|
||||
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
|
||||
if (rsrc_limit == 0) {
|
||||
rsrc_limit = 499000;
|
||||
}
|
||||
|
||||
return {
|
||||
{"device_name", name},
|
||||
{"architecture", arch},
|
||||
{"max_buffer_length", raw_device->maxBufferLength()},
|
||||
{"max_recommended_working_set_size",
|
||||
raw_device->recommendedMaxWorkingSetSize()},
|
||||
{"memory_size", memsize},
|
||||
{"resource_limit", rsrc_limit}};
|
||||
};
|
||||
static auto device_info_ = init_device_info();
|
||||
return device_info_;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -2,11 +2,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <variant>
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
/* Check if the Metal backend is available. */
|
||||
|
22
mlx/backend/metal/no_metal.cpp
Normal file
22
mlx/backend/metal/no_metal.cpp
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
bool is_available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
void start_capture(std::string) {}
|
||||
void stop_capture() {}
|
||||
|
||||
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
||||
device_info() {
|
||||
throw std::runtime_error(
|
||||
"[metal::device_info] Cannot get device info without metal backend");
|
||||
};
|
||||
|
||||
} // namespace mlx::core::metal
|
@@ -1,7 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <algorithm>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/reduce.h"
|
||||
|
@@ -7,10 +7,10 @@
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/slicing.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
@@ -25,25 +25,6 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
|
||||
enc.set_bytes(step, 1);
|
||||
}
|
||||
|
||||
void reshape(const array& in, array& out, Stream s) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
make_contiguous_strides(in.shape()),
|
||||
0,
|
||||
0,
|
||||
CopyType::General,
|
||||
s);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
static array compute_dynamic_offset(
|
||||
const array& indices,
|
||||
const Strides& strides,
|
||||
@@ -226,105 +207,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
CopyType ctype =
|
||||
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy_gpu(inputs[0], out, ctype);
|
||||
}
|
||||
|
||||
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
concatenate_gpu(inputs, out, axis_, stream());
|
||||
}
|
||||
|
||||
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
constexpr size_t extra_bytes = 16384;
|
||||
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
|
||||
(in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy_gpu(in, out, CopyType::General);
|
||||
}
|
||||
}
|
||||
|
||||
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void CustomTransforms::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void Depends::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto in = inputs[0];
|
||||
CopyType ctype;
|
||||
if (in.data_size() == 1) {
|
||||
ctype = CopyType::Scalar;
|
||||
} else if (in.flags().contiguous) {
|
||||
ctype = CopyType::Vector;
|
||||
} else {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu(in, out, ctype);
|
||||
}
|
||||
|
||||
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error("[Load::eval_gpu] Not implemented.");
|
||||
}
|
||||
|
||||
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Inputs must be base input array and scalar val array
|
||||
assert(inputs.size() == 2);
|
||||
auto& in = inputs[0];
|
||||
auto& val = inputs[1];
|
||||
|
||||
// Padding value must be a scalar
|
||||
assert(val.size() == 1);
|
||||
|
||||
// Padding value, input and output must be of the same type
|
||||
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
|
||||
|
||||
pad_gpu(in, val, out, axes_, low_pad_size_, stream());
|
||||
}
|
||||
|
||||
void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
@@ -370,27 +256,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void Split::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
slice_gpu(in, out, start_indices_, strides_, stream());
|
||||
}
|
||||
|
||||
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
@@ -492,18 +357,6 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
/* const Stream& s = */ stream());
|
||||
}
|
||||
|
||||
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void QRF::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
@@ -537,35 +390,4 @@ void LUF::eval_gpu(
|
||||
throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
|
||||
}
|
||||
|
||||
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
auto ibytes = size_of(in.dtype());
|
||||
auto obytes = size_of(out.dtype());
|
||||
// Conditions for buffer copying (disjunction):
|
||||
// - type size is the same
|
||||
// - type size is smaller and the last axis is contiguous
|
||||
// - the entire array is row contiguous
|
||||
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
|
||||
in.flags().row_contiguous) {
|
||||
auto strides = in.strides();
|
||||
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
|
||||
strides[i] *= ibytes;
|
||||
strides[i] /= obytes;
|
||||
}
|
||||
out.copy_shared_buffer(
|
||||
in, strides, in.flags(), in.data_size() * ibytes / obytes);
|
||||
} else {
|
||||
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc(tmp.nbytes()));
|
||||
copy_gpu_inplace(in, tmp, CopyType::General, stream());
|
||||
|
||||
auto flags = out.flags();
|
||||
flags.contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -4,7 +4,7 @@
|
||||
|
||||
#include "mlx/backend/common/broadcasting.h"
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/reduce.h"
|
||||
|
@@ -3,7 +3,7 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
|
@@ -1,7 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/resident.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
|
@@ -1,5 +1,5 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
|
@@ -2,7 +2,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/attn/params.h"
|
||||
@@ -154,9 +154,9 @@ void sdpa_vector(
|
||||
int gqa_factor = q.shape(1) / k.shape(1);
|
||||
int N = k.shape(2);
|
||||
int B = q.shape(0) * q.shape(1);
|
||||
size_t k_head_stride = k.strides()[1];
|
||||
size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
|
||||
size_t k_seq_stride = k.strides()[2];
|
||||
size_t v_head_stride = v.strides()[1];
|
||||
size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
|
||||
size_t v_seq_stride = v.strides()[2];
|
||||
|
||||
MTL::Size group_dims(1024, 1, 1);
|
||||
@@ -199,11 +199,10 @@ void sdpa_vector(
|
||||
if (has_mask) {
|
||||
auto& m = *mask;
|
||||
compute_encoder.set_input_array(m, 11 + float_mask);
|
||||
auto nd = m.ndim();
|
||||
int32_t kv_seq_stride =
|
||||
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
||||
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
||||
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
||||
int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
|
||||
int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
|
||||
int32_t head_stride =
|
||||
m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
|
||||
compute_encoder.set_bytes(kv_seq_stride, 13);
|
||||
compute_encoder.set_bytes(q_seq_stride, 14);
|
||||
compute_encoder.set_bytes(head_stride, 15);
|
||||
@@ -238,9 +237,10 @@ void sdpa_vector_2pass(
|
||||
int N = k.shape(2);
|
||||
int blocks = 32;
|
||||
int B = q.shape(0) * q.shape(1);
|
||||
size_t k_head_stride = k.strides()[1];
|
||||
|
||||
size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
|
||||
size_t k_seq_stride = k.strides()[2];
|
||||
size_t v_head_stride = v.strides()[1];
|
||||
size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
|
||||
size_t v_seq_stride = v.strides()[2];
|
||||
MTL::Size group_dims(8 * 32, 1, 1);
|
||||
MTL::Size grid_dims(B, q.shape(2), blocks);
|
||||
@@ -302,11 +302,10 @@ void sdpa_vector_2pass(
|
||||
if (has_mask) {
|
||||
auto& m = *mask;
|
||||
compute_encoder.set_input_array(m, 13 + float_mask);
|
||||
auto nd = m.ndim();
|
||||
int32_t kv_seq_stride =
|
||||
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
||||
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
||||
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
||||
int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
|
||||
int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
|
||||
int32_t head_stride =
|
||||
m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
|
||||
compute_encoder.set_bytes(kv_seq_stride, 15);
|
||||
compute_encoder.set_bytes(q_seq_stride, 16);
|
||||
compute_encoder.set_bytes(head_stride, 17);
|
||||
@@ -368,18 +367,6 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
}
|
||||
};
|
||||
|
||||
// Checks if arr is row contiguous or the sequence and head dimension are
|
||||
// transposed
|
||||
auto is_contiguous_or_head_seq_transposed = [](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return true;
|
||||
}
|
||||
auto& strides = arr.strides();
|
||||
auto& shape = arr.shape();
|
||||
return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) &&
|
||||
(strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
|
||||
};
|
||||
|
||||
// Checks that the headdim dimension has stride 1.
|
||||
auto is_matrix_contiguous = [](const array& arr) {
|
||||
return arr.strides(-1) == 1;
|
||||
@@ -387,30 +374,58 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
|
||||
// We are in vector mode ie single query
|
||||
if (q_pre.shape(2) <= 8) {
|
||||
const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre);
|
||||
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
|
||||
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
|
||||
auto q_copy_unless = [](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return true;
|
||||
}
|
||||
auto& strides = arr.strides();
|
||||
auto& shape = arr.shape();
|
||||
if (shape[0] == 1 || shape[1] == 1) {
|
||||
// If either the batch or head dimension is a singleton, the other can
|
||||
// be transposed with the sequence dimension
|
||||
auto bidx = shape[0] == 1 ? 1 : 0;
|
||||
return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
|
||||
(strides[bidx] == shape[3]);
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
auto kv_copy_unless = [](const array& arr) {
|
||||
// keys and values should be copied if:
|
||||
// - the last dimension is not contiguous
|
||||
// - the batch and head dim are not contiguous
|
||||
auto& strides = arr.strides();
|
||||
auto& shape = arr.shape();
|
||||
if (strides.back() != 1) {
|
||||
return false;
|
||||
}
|
||||
if (shape[0] == 1 || shape[1] == 1) {
|
||||
return true;
|
||||
}
|
||||
return (strides[0] == strides[1] * shape[1]);
|
||||
};
|
||||
|
||||
const auto& q = copy_unless(q_copy_unless, q_pre);
|
||||
const auto& k = copy_unless(kv_copy_unless, k_pre);
|
||||
const auto& v = copy_unless(kv_copy_unless, v_pre);
|
||||
|
||||
// Donate the query if possible
|
||||
if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
|
||||
q.size() == o.size()) {
|
||||
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
|
||||
o.copy_shared_buffer(q);
|
||||
} else {
|
||||
if (o.shape(2) == 1) {
|
||||
o.set_data(allocator::malloc(o.nbytes()));
|
||||
} else {
|
||||
auto strides = o.strides();
|
||||
strides[2] = o.shape(1) * o.shape(3);
|
||||
strides[1] = o.shape(3);
|
||||
auto flags = q.flags();
|
||||
flags.row_contiguous = q.shape(1) == 1;
|
||||
o.set_data(
|
||||
allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags);
|
||||
}
|
||||
o.set_data(allocator::malloc(o.nbytes()));
|
||||
}
|
||||
|
||||
auto mask =
|
||||
inputs.size() > 3 ? std::optional<array>{inputs[3]} : std::nullopt;
|
||||
auto mask_copy_unless = [&q](const array& arr) {
|
||||
auto& strides = arr.strides();
|
||||
auto& shape = arr.shape();
|
||||
return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 ||
|
||||
(strides[0] == strides[1] * shape[1]);
|
||||
};
|
||||
|
||||
auto mask = inputs.size() > 3
|
||||
? std::optional<array>{copy_unless(mask_copy_unless, inputs[3])}
|
||||
: std::nullopt;
|
||||
|
||||
// We route to the 2 pass fused attention if
|
||||
// - The device is large and the sequence length long
|
||||
|
@@ -3,7 +3,7 @@
|
||||
#include <cassert>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
@@ -2,21 +2,12 @@
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void slice_gpu(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& start_indices,
|
||||
const Shape& strides,
|
||||
const Stream& s) {
|
||||
slice(in, out, start_indices, strides);
|
||||
}
|
||||
|
||||
void concatenate_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
@@ -48,30 +39,4 @@ void concatenate_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
void pad_gpu(
|
||||
const array& in,
|
||||
const array& val,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
const Shape& low_pad_size,
|
||||
const Stream& s) {
|
||||
// Fill output with val
|
||||
fill_gpu(val, out, s);
|
||||
|
||||
// Find offset for start of input values
|
||||
size_t data_offset = 0;
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
|
||||
data_offset += out.strides()[ax] * low_pad_size[i];
|
||||
}
|
||||
|
||||
// Extract slice from output where input will be pasted
|
||||
array out_slice(in.shape(), out.dtype(), nullptr, {});
|
||||
out_slice.copy_shared_buffer(
|
||||
out, out.strides(), out.flags(), out_slice.size(), data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,7 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <algorithm>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
|
@@ -2,7 +2,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
@@ -45,7 +45,7 @@ void ternary_op_gpu_inplace(
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
large = out.data_size() > INT32_MAX;
|
||||
work_per_thread = 1;
|
||||
work_per_thread = get_work_per_thread(b.dtype());
|
||||
}
|
||||
std::string kernel_name;
|
||||
if (topt == TernaryOpType::General) {
|
||||
@@ -106,13 +106,19 @@ void ternary_op_gpu_inplace(
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D or 2D grid of threads
|
||||
size_t nthreads = out.data_size();
|
||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims;
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(out.data_size(), 4);
|
||||
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(out.data_size(), 4);
|
||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
}
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
@@ -34,18 +34,19 @@ void unary_op_gpu_inplace(
|
||||
};
|
||||
auto [shape, strides] = maybe_collapse();
|
||||
int ndim = shape.size();
|
||||
size_t nthreads = contig ? in.data_size() : in.size();
|
||||
bool large;
|
||||
if (!contig) {
|
||||
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
|
||||
} else {
|
||||
large = in.data_size() > UINT32_MAX;
|
||||
}
|
||||
int work_per_thread = !contig && large ? 4 : 1;
|
||||
int work_per_thread;
|
||||
std::string kernel_name;
|
||||
if (contig) {
|
||||
work_per_thread = get_work_per_thread(in.dtype());
|
||||
kernel_name = (large ? "v2" : "v");
|
||||
} else {
|
||||
work_per_thread = large ? 4 : 1;
|
||||
kernel_name = "gn" + std::to_string(work_per_thread);
|
||||
if (large) {
|
||||
kernel_name += "large";
|
||||
@@ -75,12 +76,20 @@ void unary_op_gpu_inplace(
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t nthreads = ceildiv(in.data_size(), work_per_thread);
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims;
|
||||
if (large) {
|
||||
compute_encoder.set_bytes<int64_t>(in.data_size(), 2);
|
||||
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
||||
} else {
|
||||
compute_encoder.set_bytes<int>(in.data_size(), 2);
|
||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
}
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
@@ -84,4 +84,12 @@ void concatenate(std::string& acc, T first, Args... args) {
|
||||
concatenate(acc, args...);
|
||||
}
|
||||
|
||||
inline int get_work_per_thread(Dtype dtype) {
|
||||
return std::max(1, 8 / dtype.size());
|
||||
}
|
||||
|
||||
inline size_t ceildiv(size_t n, size_t m) {
|
||||
return (n + m - 1) / m;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,6 +1,7 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../cpu/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../cpu/encoder.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp)
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user