Merge branch 'main' into metal-thread-safe

acsweet 2025-05-27 09:40:36 -07:00 committed by GitHub
commit 992eac905a
117 changed files with 4160 additions and 1177 deletions


@ -34,6 +34,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
@ -83,6 +84,10 @@ if(MLX_BUILD_METAL)
set(QUARTZ_LIB "-framework QuartzCore")
endif()
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
@ -226,6 +231,9 @@ target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>)
# Do not add mlx_EXPORTS define for shared library.
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git


@ -11,13 +11,14 @@ include(CMakeParseArguments)
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# files (like headers)
# files (like headers) DEBUG: Boolean, if true, enables debug compile options
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
#
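# Example (hypothetical target name and paths, for illustration only):
#   mlx_build_metallib(TARGET my_metallib TITLE my_kernels OUTPUT_DIRECTORY
#                      ${CMAKE_BINARY_DIR} SOURCES my_kernels.metal DEBUG ON)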
# clang format on
macro(mlx_build_metallib)
# Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
@ -26,6 +27,10 @@ macro(mlx_build_metallib)
# Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
-frecord-sources)
endif()
# Prepare metallib build command
add_custom_command(


@ -10,7 +10,7 @@ import mlx.core as mx
# -- Project information -----------------------------------------------------
project = "MLX"
copyright = "2023, MLX Contributors"
copyright = "2023, Apple"
author = "MLX Contributors"
version = ".".join(mx.__version__.split(".")[:3])
release = version


@ -19,6 +19,8 @@ Array
array.ndim
array.shape
array.size
array.real
array.imag
array.abs
array.all
array.any


@ -16,6 +16,8 @@ Linear Algebra
cross
qr
svd
eigvals
eig
eigvalsh
eigh
lu


@ -21,7 +21,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
# Define MLX_VERSION only in the version.cpp file.
add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
@ -51,5 +51,14 @@ if(MLX_BUILD_METAL)
else()
target_sources(mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
endif()
if(MLX_BUILD_CUDA)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
endif()
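# The Metal and CUDA backends share the common sources in backend/gpu; builds
# with neither backend fall back to the stubs in backend/no_gpu.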
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
endif()


@ -224,6 +224,10 @@ class array {
// Not copyable
Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete;
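// Movable: the move constructor transfers buffer ownership and leaves the
// moved-from Data with a null buffer and a no-op deleter, so its destructor
// does nothing.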
Data(Data&& o) : buffer(o.buffer), d(o.d) {
o.buffer = allocator::Buffer(nullptr);
o.d = [](allocator::Buffer) {};
}
~Data() {
d(buffer);
}


@ -99,7 +99,11 @@ inline std::pair<int, int> decompose_hadamard(int n) {
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
}
}
if (n > (1 << 26)) {
throw std::invalid_argument(
"[hadamard] Only supports n = m*2^k where k <= 26");
}
return {n, m};
}
} // namespace mlx::core
} // namespace mlx::core


@ -165,4 +165,11 @@ void shared_buffer_reshape(
const array& in,
const Strides& out_strides,
array& out);
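// Return a copy of `vec` with the element at `index` erased, e.g. for
// dropping a reduced axis from a shape or strides vector.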
template <typename T>
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index));
return vec;
}
} // namespace mlx::core


@ -46,6 +46,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp


@ -14,10 +14,8 @@ template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis];
Strides strides = in.strides();
Shape shape = in.shape();
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
Strides strides = remove_index(in.strides(), axis);
Shape shape = remove_index(in.shape(), axis);
auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>();


@ -22,7 +22,8 @@ void slow_conv_1D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@ -60,7 +61,8 @@ void slow_conv_1D(
out_stride_O = out.strides()[2],
flip,
padding = padding[0],
padding_lo = padding_lo[0],
padding_hi = padding_hi[0],
wt_stride = wt_strides[0],
wt_dilation = wt_dilation[0],
in_dilation = in_dilation[0]]() mutable {
@ -77,7 +79,7 @@ void slow_conv_1D(
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_stride - padding + wh_flip * wt_dilation;
int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation;
auto ih_div = std::div(ih, in_dilation);
@ -109,7 +111,8 @@ void slow_conv_2D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@ -120,230 +123,235 @@ void slow_conv_2D(
encoder.set_input_array(wt);
encoder.set_output_array(out);
encoder.dispatch([st_wt_ptr = wt.data<T>(),
st_in_ptr = in.data<T>(),
st_out_ptr = out.data<T>(),
encoder.dispatch(
[st_wt_ptr = wt.data<T>(),
st_in_ptr = in.data<T>(),
st_out_ptr = out.data<T>(),
N = in.shape(
0), // Batch size, should be the same as out.shape(0)
iH = 1 +
in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
iW = 1 +
in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
C = in.shape(3), // In channels
oH = out.shape(1), // Output spatial dim
oW = out.shape(2), // Output spatial dim
O = wt.shape(0), // Out channels
wH = wt.shape(1), // Weight spatial dim
wW = wt.shape(2), // Weight spatial dim
N = in.shape(0), // Batch size, should be the same as out.shape(0)
iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
C = in.shape(3), // In channels
oH = out.shape(1), // Output spatial dim
oW = out.shape(2), // Output spatial dim
O = wt.shape(0), // Out channels
wH = wt.shape(1), // Weight spatial dim
wW = wt.shape(2), // Weight spatial dim
groups = in.shape(3) / wt.shape(3),
C_per_group = wt.shape(3),
groups = in.shape(3) / wt.shape(3),
C_per_group = wt.shape(3),
in_stride_N = in.strides()[0],
in_stride_H = in.strides()[1],
in_stride_W = in.strides()[2],
in_stride_C = in.strides()[3],
in_stride_N = in.strides()[0],
in_stride_H = in.strides()[1],
in_stride_W = in.strides()[2],
in_stride_C = in.strides()[3],
wt_stride_O = wt.strides()[0],
wt_stride_H = wt.strides()[1],
wt_stride_W = wt.strides()[2],
wt_stride_C = wt.strides()[3],
wt_stride_O = wt.strides()[0],
wt_stride_H = wt.strides()[1],
wt_stride_W = wt.strides()[2],
wt_stride_C = wt.strides()[3],
out_stride_N = out.strides()[0],
out_stride_H = out.strides()[1],
out_stride_W = out.strides()[2],
out_stride_O = out.strides()[3],
out_stride_N = out.strides()[0],
out_stride_H = out.strides()[1],
out_stride_W = out.strides()[2],
out_stride_O = out.strides()[3],
padding,
wt_strides,
wt_dilation,
in_dilation,
flip]() mutable {
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip]() mutable {
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
const int O_per_group = O / groups;
auto pt_conv_no_checks = [&](const T* in_ptr,
const T* wt_ptr,
T* out_ptr,
int oh,
int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
const int O_per_group = O / groups;
auto pt_conv_no_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding_lo[0];
int iw_base = ow * wt_strides[1] - padding_lo[1];
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt =
in_ptr + ih * in_stride_H + iw * in_stride_W;
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
} // ww
} // wh
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
} // ww
} // wh
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
int f_wgt_jump_h =
std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
int f_wgt_jump_w =
std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
int f_wgt_jump_h =
std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
int f_wgt_jump_w =
std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
int f_out_jump_h =
std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
int f_out_jump_w =
std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
std::vector<int> base_h(f_out_jump_h);
std::vector<int> base_w(f_out_jump_w);
std::vector<int> base_h(f_out_jump_h);
std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[0] - padding[0] + init_h;
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h;
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
wh_base++;
ih_loop += jump_h;
}
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
wh_base++;
ih_loop += jump_h;
}
base_h[i] = wh_base;
}
base_h[i] = wh_base;
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[1] - padding[1] + init_w;
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w;
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
ww_base++;
iw_loop += jump_w;
}
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
ww_base++;
iw_loop += jump_w;
}
base_w[j] = ww_base;
}
base_w[j] = ww_base;
}
auto pt_conv_all_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
auto pt_conv_all_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
int ih_base = oh * wt_strides[0] - padding_lo[0];
int iw_base = ow * wt_strides[1] - padding_lo[1];
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
const T* in_ptr_pt =
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H +
iw_dil * in_stride_W;
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
} // ih, iw check
} // ww
} // wh
} // ih, iw check
} // ww
} // wh
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
int oH_border_0 = 0;
int oH_border_1 =
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH;
int oH_border_2 = std::max(
oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]);
int oH_border_3 = oH;
int oH_border_0 = 0;
int oH_border_1 = is_idil_one
? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
: oH;
int oH_border_2 = std::max(
oH_border_1,
(iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]);
int oH_border_3 = oH;
int oW_border_0 = 0;
int oW_border_1 =
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW;
int oW_border_2 = std::max(
oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]);
int oW_border_3 = oW;
int oW_border_0 = 0;
int oW_border_1 = is_idil_one
? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
: oW;
int oW_border_2 = std::max(
oW_border_1,
(iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]);
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) {
// Case 1: oh might put us out of bounds
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
for (int n = 0; n < N; ++n) {
// Case 1: oh might put us out of bounds
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
// Case 2: oh in bounds
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
// Case a: ow might put us out of bounds
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case 2: oh in bounds
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
// Case a: ow might put us out of bounds
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case b: ow in bounds
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case b: ow in bounds
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case c: ow might put us out of bounds
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case c: ow might put us out of bounds
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
} // oh
// Case 3: oh might put us out of bounds
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
// Case 3: oh might put us out of bounds
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
st_in_ptr += in_stride_N;
st_out_ptr += out_stride_N;
st_in_ptr += in_stride_N;
st_out_ptr += out_stride_N;
} // n
});
} // n
});
}
template <typename T>
@ -351,7 +359,8 @@ void slow_conv_3D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@ -400,7 +409,8 @@ void slow_conv_3D(
out_stride_H = out.strides()[2],
out_stride_W = out.strides()[3],
out_stride_O = out.strides()[4],
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@ -415,9 +425,9 @@ void slow_conv_3D(
int oh,
int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding[0];
int ih_base = oh * wt_strides[1] - padding[1];
int iw_base = ow * wt_strides[2] - padding[2];
int id_base = od * wt_strides[0] - padding_lo[0];
int ih_base = oh * wt_strides[1] - padding_lo[1];
int iw_base = ow * wt_strides[2] - padding_lo[2];
for (int o = 0; o < O; ++o) {
float r = 0.;
@ -478,7 +488,7 @@ void slow_conv_3D(
std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_d; ++i) {
int id_loop = i * wt_strides[0] - padding[0] + init_d;
int id_loop = i * wt_strides[0] - padding_lo[0] + init_d;
int wd_base = 0;
while (wd_base < wD && id_loop % in_dilation[0] != 0) {
@ -490,7 +500,7 @@ void slow_conv_3D(
}
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[1] - padding[1] + init_h;
int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h;
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
@ -502,7 +512,7 @@ void slow_conv_3D(
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[2] - padding[2] + init_w;
int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w;
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
@ -521,9 +531,9 @@ void slow_conv_3D(
int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding[0];
int ih_base = oh * wt_strides[1] - padding[1];
int iw_base = ow * wt_strides[2] - padding[2];
int id_base = od * wt_strides[0] - padding_lo[0];
int ih_base = oh * wt_strides[1] - padding_lo[1];
int iw_base = ow * wt_strides[2] - padding_lo[2];
int wd_base = base_d[od % f_out_jump_d];
int wh_base = base_h[oh % f_out_jump_h];
@ -573,24 +583,30 @@ void slow_conv_3D(
};
int oD_border_0 = 0;
int oD_border_1 =
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
int oD_border_1 = is_idil_one
? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
: oD;
int oD_border_2 = std::max(
oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
oD_border_1,
(iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]);
int oD_border_3 = oD;
int oH_border_0 = 0;
int oH_border_1 =
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
int oH_border_1 = is_idil_one
? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
: oH;
int oH_border_2 = std::max(
oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
oH_border_1,
(iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]);
int oH_border_3 = oH;
int oW_border_0 = 0;
int oW_border_1 =
is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
int oW_border_1 = is_idil_one
? ((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2])
: oW;
int oW_border_2 = std::max(
oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
oW_border_1,
(iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]);
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) {
@ -658,7 +674,8 @@ void dispatch_slow_conv_1D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@ -669,7 +686,8 @@ void dispatch_slow_conv_1D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@ -680,7 +698,8 @@ void dispatch_slow_conv_1D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@ -691,7 +710,8 @@ void dispatch_slow_conv_1D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@ -707,7 +727,8 @@ void dispatch_slow_conv_2D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@ -718,7 +739,8 @@ void dispatch_slow_conv_2D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@ -729,7 +751,8 @@ void dispatch_slow_conv_2D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@ -740,7 +763,8 @@ void dispatch_slow_conv_2D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@ -756,7 +780,8 @@ void dispatch_slow_conv_3D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@ -767,7 +792,8 @@ void dispatch_slow_conv_3D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@ -778,7 +804,8 @@ void dispatch_slow_conv_3D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@ -789,7 +816,8 @@ void dispatch_slow_conv_3D(
in,
wt,
out,
padding,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
@ -829,7 +857,8 @@ void explicit_gemm_conv_1D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
Stream stream) {
@ -848,7 +877,7 @@ void explicit_gemm_conv_1D_cpu(
auto& encoder = cpu::get_command_encoder(stream);
// Pad input
Shape padded_shape = {N, iH + 2 * padding[0], C};
Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
@ -857,7 +886,7 @@ void explicit_gemm_conv_1D_cpu(
copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = padding[0] * in_padded.strides()[1];
size_t data_offset = padding_lo[0] * in_padded.strides()[1];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
@ -971,7 +1000,8 @@ void explicit_gemm_conv_2D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
Stream stream) {
@ -989,7 +1019,11 @@ void explicit_gemm_conv_2D_cpu(
auto& encoder = cpu::get_command_encoder(stream);
// Pad input
Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
Shape padded_shape = {
N,
iH + padding_lo[0] + padding_hi[0],
iW + padding_lo[1] + padding_hi[1],
C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
@ -998,8 +1032,8 @@ void explicit_gemm_conv_2D_cpu(
copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset =
padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2];
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
padding_lo[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
@ -1091,7 +1125,8 @@ void explicit_gemm_conv_ND_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const bool flip,
@ -1114,7 +1149,7 @@ void explicit_gemm_conv_ND_cpu(
Shape padded_shape(in.shape().size());
padded_shape.front() = N;
for (size_t i = 0; i < iDim.size(); i++) {
padded_shape[i + 1] = iDim[i] + 2 * padding[i];
padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
}
padded_shape.back() = C;
array in_padded(padded_shape, conv_dtype, nullptr, {});
@ -1125,9 +1160,10 @@ void explicit_gemm_conv_ND_cpu(
// Pick input slice from padded
size_t data_offset = 0;
for (size_t i = 0; i < padding.size(); i++) {
data_offset += padding[i] * in_padded.strides()[i + 1];
for (size_t i = 0; i < padding_lo.size(); i++) {
data_offset += padding_lo[i] * in_padded.strides()[i + 1];
}
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
@ -1261,7 +1297,8 @@ void conv_1D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@ -1270,22 +1307,40 @@ void conv_1D_cpu(
const int groups = in.shape().back() / wt.shape().back();
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
return explicit_gemm_conv_1D_cpu(
in, wt, out, padding, wt_strides, wt_dilation, stream);
in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream);
}
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
flip,
stream);
}
return dispatch_slow_conv_1D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip,
stream);
}
void conv_2D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@ -1295,18 +1350,35 @@ void conv_2D_cpu(
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
in_dilation[1] == 1 && groups == 1) {
return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
flip,
stream);
}
return dispatch_slow_conv_2D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip,
stream);
}
void conv_3D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
@ -1317,11 +1389,28 @@ void conv_3D_cpu(
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
groups == 1) {
return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
flip,
stream);
}
return dispatch_slow_conv_3D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip,
stream);
}
} // namespace
@ -1338,7 +1427,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_,
padding_lo_,
padding_hi_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
@ -1351,7 +1441,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_,
padding_lo_,
padding_hi_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
@ -1364,7 +1455,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_,
padding_lo_,
padding_hi_,
kernel_strides_,
kernel_dilation_,
input_dilation_,

mlx/backend/cpu/eig.cpp (new file, 174 lines)

@ -0,0 +1,174 @@
// Copyright © 2025 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T>
void eig_impl(
array& a,
array& vectors,
array& values,
bool compute_eigenvectors,
Stream stream) {
using OT = std::complex<T>;
auto a_ptr = a.data<T>();
auto eig_ptr = values.data<OT>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(values);
OT* vec_ptr = nullptr;
if (compute_eigenvectors) {
encoder.set_output_array(vectors);
vec_ptr = vectors.data<OT>();
}
encoder.dispatch([a_ptr,
vec_ptr,
eig_ptr,
compute_eigenvectors,
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
char jobr = 'N';
char jobl = compute_eigenvectors ? 'V' : 'N';
int n_vecs_r = 1;
int n_vecs_l = compute_eigenvectors ? N : 1;
int lwork = -1;
int info;
{
T work;
int iwork;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
}
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
auto vec_tmp_data =
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
for (size_t i = 0; i < size / (N * N); ++i) {
geev<T>(
&jobl,
&jobr,
&N,
a_ptr,
&N,
eig_tmp,
eig_tmp + N,
vec_tmp,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
&info);
for (int i = 0; i < N; ++i) {
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
}
if (vec_ptr) {
for (int i = 0; i < N; ++i) {
if (eig_ptr[i].imag() != 0) {
// This vector and the next are a pair
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
vec_ptr[(i + 1) * N + j] = {
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
}
}
}
vec_ptr += N * N;
}
a_ptr += N * N;
eig_ptr += N;
if (info != 0) {
std::stringstream msg;
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
});
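// Retain the copied input until the dispatched work has consumed it.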
encoder.add_temporary(a);
}
} // namespace
void Eig::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const auto& a = inputs[0];
auto& values = outputs[0];
auto vectors = compute_eigenvectors_
? outputs[1]
: array(a.shape(), complex64, nullptr, {});
auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
copy(
a,
a_copy,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());
values.set_data(allocator::malloc(values.nbytes()));
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
// are in the columns of the output
auto flags = vectors.flags();
auto strides = vectors.strides();
auto ndim = a.ndim();
std::swap(strides[ndim - 1], strides[ndim - 2]);
if (a.size() > 1) {
flags.row_contiguous = false;
if (ndim > 2) {
flags.col_contiguous = false;
} else {
flags.col_contiguous = true;
}
}
vectors.set_data(
allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
}
switch (a.dtype()) {
case float32:
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
}
}
} // namespace mlx::core


@ -12,6 +12,133 @@ namespace mlx::core {
namespace {
template <typename T, class Enable = void>
struct EighWork {};
template <typename T>
struct EighWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using R = T;
char jobz;
char uplo;
int N;
int lwork;
int liwork;
int info;
std::vector<array::Data> buffers;
EighWork(char jobz_, char uplo_, int N_)
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
T work;
int iwork;
syevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work);
liwork = iwork;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
}
void run(T* vectors, T* values) {
syevd<T>(
&jobz,
&uplo,
&N,
vectors,
&N,
values,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<int*>(buffers[1].buffer.raw_ptr()),
&liwork,
&info);
}
};
template <>
struct EighWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
char jobz;
char uplo;
int N;
int lwork;
int lrwork;
int liwork;
int info;
std::vector<array::Data> buffers;
EighWork(char jobz_, char uplo_, int N_)
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
T work;
R rwork;
int iwork;
heevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&rwork,
&lrwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work.real());
lrwork = static_cast<int>(rwork);
liwork = iwork;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
}
void run(T* vectors, R* values) {
heevd<T>(
&jobz,
&uplo,
&N,
vectors,
&N,
values,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<R*>(buffers[1].buffer.raw_ptr()),
&lrwork,
static_cast<int*>(buffers[2].buffer.raw_ptr()),
&liwork,
&info);
if (jobz == 'V') {
// We have pre-transposed the vectors but we also must conjugate them
// when they are complex.
//
// We could vectorize this but it is so fast in comparison to heevd that
// it doesn't really matter.
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
*vectors = std::conj(*vectors);
vectors++;
}
}
}
}
};
template <typename T>
void eigh_impl(
array& vectors,
@ -19,8 +146,10 @@ void eigh_impl(
const std::string& uplo,
bool compute_eigenvectors,
Stream stream) {
using R = typename EighWork<T>::R;
auto vec_ptr = vectors.data<T>();
auto eig_ptr = values.data<T>();
auto eig_ptr = values.data<R>();
char jobz = compute_eigenvectors ? 'V' : 'N';
auto& encoder = cpu::get_command_encoder(stream);
@ -33,49 +162,17 @@ void eigh_impl(
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
int lwork = -1;
int liwork = -1;
int info;
{
T work;
int iwork;
syevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work);
liwork = iwork;
}
EighWork<T> work(jobz, uplo, N);
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
// Work loop
for (size_t i = 0; i < size / (N * N); ++i) {
syevd<T>(
&jobz,
&uplo,
&N,
vec_ptr,
&N,
eig_ptr,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
&liwork,
&info);
work.run(vec_ptr, eig_ptr);
vec_ptr += N * N;
eig_ptr += N;
if (info != 0) {
if (work.info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
<< work.info;
throw std::runtime_error(msg.str());
}
}
@ -131,6 +228,10 @@ void Eigh::eval_cpu(
eigh_impl<double>(
vectors, values, uplo_, compute_eigenvectors_, stream());
break;
case complex64:
eigh_impl<std::complex<float>>(
vectors, values, uplo_, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error(
"[Eigh::eval_cpu] only supports float32 or float64.");


@ -257,15 +257,11 @@ void gather_axis(
const array& ind,
array& out,
const int axis) {
auto strides = ind.strides();
strides.erase(strides.begin() + axis);
auto shape = ind.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
strides = src.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
auto shape = remove_index(ind.shape(), axis);
ContiguousIterator ind_it(
shape, remove_index(ind.strides(), axis), src.ndim() - 1);
ContiguousIterator src_it(
shape, remove_index(src.strides(), axis), src.ndim() - 1);
auto ind_ptr = ind.data<IdxT>();
auto src_ptr = src.data<T>();
@ -585,15 +581,11 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
template <typename T, typename IdxT, typename OpT>
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
auto strides = idx.strides();
strides.erase(strides.begin() + axis);
auto shape = idx.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
strides = upd.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
auto shape = remove_index(idx.shape(), axis);
ContiguousIterator idx_it(
shape, remove_index(idx.strides(), axis), upd.ndim() - 1);
ContiguousIterator upd_it(
shape, remove_index(upd.strides(), axis), upd.ndim() - 1);
auto idx_ptr = idx.data<IdxT>();
auto upd_ptr = upd.data<T>();


@ -2,14 +2,14 @@
#pragma once
// Required for Visual Studio.
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
#ifdef _MSC_VER
#include <complex>
#define LAPACK_COMPLEX_CUSTOM
#define lapack_complex_float std::complex<float>
#define lapack_complex_double std::complex<double>
#endif
#define lapack_complex_float_real(z) ((z).real())
#define lapack_complex_float_imag(z) ((z).imag())
#define lapack_complex_double_real(z) ((z).real())
#define lapack_complex_double_imag(z) ((z).imag())
#ifdef MLX_USE_ACCELERATE
#include <Accelerate/Accelerate.h>
@ -32,7 +32,7 @@
#endif
#define INSTANTIATE_LAPACK_TYPES(FUNC) \
#define INSTANTIATE_LAPACK_REAL(FUNC) \
template <typename T, typename... Args> \
void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, float>) { \
@ -42,11 +42,24 @@
} \
}
INSTANTIATE_LAPACK_TYPES(geqrf)
INSTANTIATE_LAPACK_TYPES(orgqr)
INSTANTIATE_LAPACK_TYPES(syevd)
INSTANTIATE_LAPACK_TYPES(potrf)
INSTANTIATE_LAPACK_TYPES(gesvdx)
INSTANTIATE_LAPACK_TYPES(getrf)
INSTANTIATE_LAPACK_TYPES(getri)
INSTANTIATE_LAPACK_TYPES(trtri)
INSTANTIATE_LAPACK_REAL(geqrf)
INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_REAL(gesvdx)
INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_REAL(trtri)
#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \
template <typename T, typename... Args> \
void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, std::complex<float>>) { \
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
} \
}
INSTANTIATE_LAPACK_COMPLEX(heevd)
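For reference, a minimal sketch of how these instantiations dispatch, assuming
the elided real-typed branches use LAPACK's usual s/d prefixes (only the
complex c/z branches are shown above):
//   syevd<float>(...)                -> ssyevd
//   syevd<double>(...)               -> dsyevd
//   heevd<std::complex<float>>(...)  -> cheevd
//   heevd<std::complex<double>>(...) -> zheevd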


@ -132,6 +132,10 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
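// An empty output has nothing to compute; allocate its (empty) buffer and
// return early.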
if (out.size() == 0) {
out.set_data(allocator::malloc(out.nbytes()));
return;
}
// Fill output with C
auto& c = inputs[2];
@ -139,7 +143,9 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(c, out, ctype, stream());
if (inputs[0].shape(-1) == 0) {
return;
}
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
}


@ -0,0 +1,55 @@
# Filename rules in cuda backend:
#
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
# * Device-only kernel code should be put in kernels/ subdir.
# * Files in kernels/ subdir should not include files outside.
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
# Enable defining device lambda functions.
target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES
"70;80"
CACHE STRING "CUDA architectures")
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
"${MLX_CUDA_ARCHITECTURES}")
# Use fixed version of CCCL.
FetchContent_Declare(
cccl
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
FetchContent_MakeAvailable(cccl)
target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")
# Use fixed version of NVTX.
FetchContent_Declare(
nvtx3
GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git
GIT_TAG v3.1.1
GIT_SHALLOW TRUE
SOURCE_SUBDIR c EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(nvtx3)
target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
# Make cuda runtime APIs available in non-cuda files.
find_package(CUDAToolkit REQUIRED)
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)


@ -0,0 +1,154 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/worker.h"
#include <cuda_runtime.h>
#include <fmt/format.h>
#include <cassert>
namespace mlx::core {
namespace cu {
CudaAllocator::CudaAllocator() {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.8;
}
Buffer CudaAllocator::malloc(size_t size) {
// TODO: Check memory limit.
auto* buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(
fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
std::lock_guard lock(mutex_);
active_memory_ += size;
peak_memory_ = std::max(active_memory_, peak_memory_);
return Buffer{buf};
}
void CudaAllocator::free(Buffer buffer) {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
return;
}
// If free() is called from an unregistered thread, reschedule the call to
// the worker.
{
std::lock_guard lock(worker_mutex_);
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
if (!worker_) {
worker_.reset(new Worker);
}
worker_->add_task([buffer]() { allocator().free(buffer); });
worker_->end_batch();
worker_->commit();
return;
}
}
size_t size = buf->size;
cudaFree(buf->data);
delete buf;
std::lock_guard lock(mutex_);
active_memory_ -= size;
}
size_t CudaAllocator::size(Buffer buffer) const {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
return 0;
}
return buf->size;
}
void CudaAllocator::register_this_thread() {
std::lock_guard lock(worker_mutex_);
allowed_threads_.insert(std::this_thread::get_id());
}
size_t CudaAllocator::get_active_memory() const {
return active_memory_;
}
size_t CudaAllocator::get_peak_memory() const {
return peak_memory_;
}
void CudaAllocator::reset_peak_memory() {
std::lock_guard lock(mutex_);
peak_memory_ = 0;
}
size_t CudaAllocator::get_memory_limit() {
return memory_limit_;
}
size_t CudaAllocator::set_memory_limit(size_t limit) {
std::lock_guard lock(mutex_);
std::swap(limit, memory_limit_);
return limit;
}
CudaAllocator& allocator() {
// By creating |allocator_| on the heap, the destructor of CudaAllocator
// will not be called on exit and buffers in the cache will be leaked. This
// can save some time at program exit.
static CudaAllocator* allocator_ = new CudaAllocator;
return *allocator_;
}
} // namespace cu
namespace allocator {
Allocator& allocator() {
return cu::allocator();
}
void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
return static_cast<cu::CudaBuffer*>(ptr_)->data;
}
} // namespace allocator
size_t get_active_memory() {
return cu::allocator().get_active_memory();
}
size_t get_peak_memory() {
return cu::allocator().get_peak_memory();
}
void reset_peak_memory() {
return cu::allocator().reset_peak_memory();
}
size_t set_memory_limit(size_t limit) {
return cu::allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
return cu::allocator().get_memory_limit();
}
// TODO: Implement buffer cache.
size_t get_cache_memory() {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
}
size_t set_wired_limit(size_t) {
return 0;
}
void clear_cache() {}
} // namespace mlx::core

View File

@ -0,0 +1,58 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include <mutex>
#include <set>
#include <thread>
#include <utility>
namespace mlx::core::cu {
class Worker;
using allocator::Buffer;
// Stores cuda-managed unified memory.
struct CudaBuffer {
void* data;
size_t size;
};
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
// Register the current thread as safe to free buffers.
// In CUDA, freeing a buffer implicitly synchronizes the stream, so threads
// that the GPU stream may wait on (for example CPU stream threads) must not
// free buffers directly or a deadlock can result.
void register_this_thread();
size_t get_active_memory() const;
size_t get_peak_memory() const;
void reset_peak_memory();
size_t get_memory_limit();
size_t set_memory_limit(size_t limit);
private:
CudaAllocator();
friend CudaAllocator& allocator();
std::mutex worker_mutex_;
std::unique_ptr<Worker> worker_;
std::set<std::thread::id> allowed_threads_;
std::mutex mutex_;
size_t memory_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
};
CudaAllocator& allocator();
} // namespace mlx::core::cu

mlx/backend/cuda/copy.cpp (new file, 26 lines)

@ -0,0 +1,26 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/gpu/copy.h"
namespace mlx::core {
void copy_gpu_inplace(
const array& in,
array& out,
const Shape& data_shape,
const Strides& strides_in_pre,
const Strides& strides_out_pre,
int64_t inp_offset,
int64_t out_offset,
CopyType ctype,
const Stream& s,
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend.");
}
void fill_gpu(const array& val, array& out, const Stream& s) {
throw std::runtime_error("fill_gpu not implemented in CUDA backend.");
}
} // namespace mlx::core

mlx/backend/cuda/device.cpp (new file, 117 lines)

@ -0,0 +1,117 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/backend/metal/metal.h"
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {}
void DeviceStream::synchronize() {
cudaStreamSynchronize(stream_);
}
cudaStream_t DeviceStream::schedule_cuda_stream() {
// TODO: Return a stream that maximizes parallelism.
return stream_;
}
cudaStream_t DeviceStream::last_cuda_stream() {
return stream_;
}
CommandEncoder& DeviceStream::get_encoder() {
if (!encoder_) {
encoder_ = std::make_unique<CommandEncoder>(*this);
}
return *encoder_;
}
Device::Device(int device) : device_(device) {
// Validate the requirements of device.
int attr = 0;
cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_);
if (attr != 1) {
throw std::runtime_error(fmt::format(
"Device {} does not support synchronization in managed memory.",
device_));
}
}
void Device::make_current() {
// We need to set/get the current CUDA device very frequently; cache it to
// reduce actual CUDA API calls. This function assumes a single host thread.
static int current = 0;
if (current != device_) {
CHECK_CUDA_ERROR(cudaSetDevice(device_));
current = device_;
}
}
DeviceStream& Device::get_stream(Stream s) {
auto it = streams_.find(s.index);
if (it == streams_.end()) {
it = streams_.try_emplace(s.index, *this).first;
}
return it->second;
}
CommandEncoder::CommandEncoder(DeviceStream& s)
: device_(s.device()), stream_(s) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));
}
void CommandEncoder::end_encoding() {
if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {});
}
// If no kernel was launched, run the completion handlers immediately.
if (!has_gpu_work_) {
worker_.consume_in_this_thread();
return;
}
has_gpu_work_ = false;
// Put completion handlers in a batch.
worker_.end_batch();
// Signaling kernel completion is expensive; delay it until enough batches
// accumulate.
// TODO: This number is arbitrarily picked; profile for a better strategy.
if (worker_.uncommited_batches() > 8) {
commit();
}
}
void CommandEncoder::commit() {
worker_.commit(stream_.last_cuda_stream());
}
Device& device(mlx::core::Device device) {
static std::unordered_map<int, Device> devices;
auto it = devices.find(device.index);
if (it == devices.end()) {
it = devices.try_emplace(device.index, device.index).first;
}
return it->second;
}
DeviceStream& get_stream(Stream s) {
return device(s.device).get_stream(s);
}
CommandEncoder& get_command_encoder(Stream s) {
return get_stream(s).get_encoder();
}
} // namespace cu
} // namespace mlx::core

mlx/backend/cuda/device.h (new file, 131 lines)

@ -0,0 +1,131 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h"
#include <thrust/execution_policy.h>
#include <unordered_map>
namespace mlx::core::cu {
class Device;
class CommandEncoder;
class DeviceStream {
public:
explicit DeviceStream(Device& device);
DeviceStream(const DeviceStream&) = delete;
DeviceStream& operator=(const DeviceStream&) = delete;
// Wait until kernels in the stream complete.
void synchronize();
// Return a cuda stream for launching kernels.
cudaStream_t schedule_cuda_stream();
// Return the last cuda stream used.
cudaStream_t last_cuda_stream();
CommandEncoder& get_encoder();
Device& device() {
return device_;
}
private:
Device& device_;
CudaStream stream_;
std::unique_ptr<CommandEncoder> encoder_;
};
class Device {
public:
explicit Device(int device);
Device(const Device&) = delete;
Device& operator=(const Device&) = delete;
// Make this device the current cuda device, required by some cuda calls.
void make_current();
DeviceStream& get_stream(Stream s);
int cuda_device() const {
return device_;
}
private:
int device_;
std::unordered_map<int, DeviceStream> streams_;
};
class CommandEncoder {
public:
explicit CommandEncoder(DeviceStream& stream);
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
void set_input_array(const array& arr) {}
void set_output_array(const array& arr) {}
void add_temporary(const array& arr) {
temporaries_.push_back(arr.data_shared_ptr());
}
void add_completed_handler(std::function<void()> task);
void end_encoding();
void commit();
// Schedule a cuda stream for |fun| to launch kernels on, and check for
// errors afterwards.
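// A hypothetical call site, for illustration:
//   encoder.launch_kernel([&](cudaStream_t stream) {
//     my_kernel<<<grid, block, 0, stream>>>(args...);
//   });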
template <typename F>
void launch_kernel(F&& fun) {
launch_kernel(stream_.schedule_cuda_stream(), std::forward<F>(fun));
}
template <typename F>
void launch_kernel(cudaStream_t stream, F&& fun) {
device_.make_current();
fun(stream);
check_cuda_error("kernel launch", cudaGetLastError());
has_gpu_work_ = true;
}
Device& device() {
return device_;
}
DeviceStream& stream() {
return stream_;
}
bool has_gpu_work() const {
return has_gpu_work_;
}
private:
Device& device_;
DeviceStream& stream_;
Worker worker_;
bool has_gpu_work_{false};
std::vector<std::shared_ptr<array::Data>> temporaries_;
};
Device& device(mlx::core::Device device);
DeviceStream& get_stream(Stream s);
CommandEncoder& get_command_encoder(Stream s);
// Return an execution policy that does not synchronize on the result.
// Note that not all thrust APIs support the async policy; confirm before using.
inline auto thrust_policy(cudaStream_t stream) {
// TODO: Connect thrust's custom allocator with mlx's allocator.
return thrust::cuda::par_nosync.on(stream);
}
} // namespace mlx::core::cu


@ -0,0 +1,35 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace mlx::core {
// Maps CPU types to CUDA types.
template <typename T>
struct CTypeToCudaType {
using type = T;
};
template <>
struct CTypeToCudaType<float16_t> {
using type = __half;
};
template <>
struct CTypeToCudaType<bfloat16_t> {
using type = __nv_bfloat16;
};
template <>
struct CTypeToCudaType<complex64_t> {
using type = cuComplex;
};
template <typename T>
using cuda_type_t = typename CTypeToCudaType<T>::type;
} // namespace mlx::core

mlx/backend/cuda/eval.cpp (new file, 68 lines)

@ -0,0 +1,68 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/gpu/eval.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/available.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core::gpu {
bool is_available() {
return true;
}
void new_stream(Stream s) {
// Force initialization of CUDA so that the CUDA runtime is destroyed last.
cudaFree(nullptr);
// Ensure the static stream objects get created.
cu::get_command_encoder(s);
// The main thread is safe to free buffers.
cu::allocator().register_this_thread();
}
void eval(array& arr) {
nvtx3::scoped_range r("gpu::eval");
auto outputs = arr.outputs();
{
// If the array is a tracer, hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
if (encoder.has_gpu_work()) {
// Keep used buffers alive until kernel finishes running.
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input.
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
}
encoder.end_encoding();
}
void finalize(Stream s) {
nvtx3::scoped_range r("gpu::finalize");
cu::get_command_encoder(s).commit();
}
void synchronize(Stream s) {
nvtx3::scoped_range r("gpu::synchronize");
cu::get_stream(s).synchronize();
}
} // namespace mlx::core::gpu

mlx/backend/cuda/event.cu Normal file
View File

@ -0,0 +1,265 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/event.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
///////////////////////////////////////////////////////////////////////////////
// CudaEvent implementations
///////////////////////////////////////////////////////////////////////////////
// Cuda event managed with RAII.
class CudaEventHandle {
public:
CudaEventHandle() {
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
&event_, cudaEventDisableTiming | cudaEventBlockingSync));
}
~CudaEventHandle() {
CHECK_CUDA_ERROR(cudaEventDestroy(event_));
}
CudaEventHandle(const CudaEventHandle&) = delete;
CudaEventHandle& operator=(const CudaEventHandle&) = delete;
operator cudaEvent_t() const {
return event_;
}
private:
cudaEvent_t event_;
};
CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
void CudaEvent::wait() {
nvtx3::scoped_range r("cu::CudaEvent::wait");
if (!recorded_) {
throw std::runtime_error("Should not wait on a CudaEvent before record.");
}
cudaEventSynchronize(*event_);
}
void CudaEvent::wait(cudaStream_t stream) {
if (!recorded_) {
throw std::runtime_error("Should not wait on a CudaEvent before record.");
}
cudaStreamWaitEvent(stream, *event_);
}
void CudaEvent::wait(Stream s) {
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this]() mutable { wait(); });
} else {
wait(cu::get_stream(s).last_cuda_stream());
}
}
void CudaEvent::record(cudaStream_t stream) {
cudaEventRecord(*event_, stream);
recorded_ = true;
}
void CudaEvent::record(Stream s) {
if (s.device == mlx::core::Device::cpu) {
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
} else {
record(cu::get_stream(s).last_cuda_stream());
}
}
bool CudaEvent::completed() const {
return cudaEventQuery(*event_) == cudaSuccess;
}
///////////////////////////////////////////////////////////////////////////////
// SharedEvent implementations
///////////////////////////////////////////////////////////////////////////////
namespace {
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
uint64_t current;
while ((current = ac->load()) < value) {
ac->wait(current);
}
}
__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
ac->store(value);
ac->notify_all();
}
__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_wait(ac, value);
}
__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_signal(ac, value);
}
} // namespace
SharedEvent::SharedEvent() {
// Allocate cuda::atomic on managed memory.
allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
new (ac) Atomic(0);
ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
ptr->~Atomic();
allocator::free(buffer);
});
}
void SharedEvent::wait(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait");
event_wait(ac_.get(), value);
}
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
}
void SharedEvent::wait(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
} else {
auto& encoder = get_command_encoder(s);
encoder.launch_kernel(
encoder.stream().last_cuda_stream(),
[this, value](cudaStream_t stream) { wait(stream, value); });
encoder.add_completed_handler([ac = ac_]() {});
encoder.end_encoding();
}
}
void SharedEvent::signal(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal");
event_signal(ac_.get(), value);
}
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
}
void SharedEvent::signal(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this, value]() mutable { signal(value); });
} else {
auto& encoder = get_command_encoder(s);
encoder.launch_kernel(
encoder.stream().last_cuda_stream(),
[this, value](cudaStream_t stream) { signal(stream, value); });
encoder.add_completed_handler([ac = ac_]() {});
encoder.end_encoding();
}
}
bool SharedEvent::is_signaled(uint64_t value) const {
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
return ac_->load() >= value;
}
uint64_t SharedEvent::value() const {
nvtx3::scoped_range r("cu::SharedEvent::value");
return ac_->load();
}
} // namespace cu
///////////////////////////////////////////////////////////////////////////////
// Event implementations
///////////////////////////////////////////////////////////////////////////////
namespace {
struct EventImpl {
// CudaEvent is preferred when possible because it is fast; however we have
// to fall back to SharedEvent in the following cases:
// 1. the event is used to wait/signal a cpu stream;
// 2. a signal value other than 1 has been specified.
std::unique_ptr<cu::CudaEvent> cuda;
std::unique_ptr<cu::SharedEvent> shared;
bool is_created() const {
return cuda || shared;
}
void ensure_created(Stream s, uint64_t signal_value) {
if (is_created()) {
return;
}
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
nvtx3::mark("Using slow SharedEvent");
shared = std::make_unique<cu::SharedEvent>();
} else {
cuda = std::make_unique<cu::CudaEvent>();
}
}
};
} // namespace
Event::Event(Stream s) : stream_(s) {
event_ = std::shared_ptr<void>(
new EventImpl(), [](void* ptr) { delete static_cast<EventImpl*>(ptr); });
}
void Event::wait() {
auto* event = static_cast<EventImpl*>(event_.get());
assert(event->is_created());
if (event->cuda) {
assert(value() == 1);
event->cuda->wait();
} else {
event->shared->wait(value());
}
}
void Event::wait(Stream s) {
auto* event = static_cast<EventImpl*>(event_.get());
assert(event->is_created());
if (event->cuda) {
assert(value() == 1);
event->cuda->wait(s);
} else {
event->shared->wait(s, value());
}
}
void Event::signal(Stream s) {
auto* event = static_cast<EventImpl*>(event_.get());
event->ensure_created(s, value());
if (event->cuda) {
assert(value() == 1);
event->cuda->record(s);
} else {
event->shared->signal(s, value());
}
}
bool Event::is_signaled() const {
auto* event = static_cast<EventImpl*>(event_.get());
if (!event->is_created()) {
return false;
}
if (event->cuda) {
assert(value() == 1);
return event->cuda->recorded() && event->cuda->completed();
} else {
return event->shared->is_signaled(value());
}
}
} // namespace mlx::core
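
A hedged usage sketch of the fallback logic described in the EventImpl comment above; the streams are assumed to exist, and set_value()/value() are assumed to come from mlx/event.h.

// Illustrative only: a signal value above 1 plus a CPU-side wait forces the
// SharedEvent path, since CudaEvent only supports value 1 and GPU recording.
mlx::core::Event e(gpu_stream);
e.set_value(2);          // value > 1 makes ensure_created() pick SharedEvent
e.signal(gpu_stream);    // enqueued after the kernels recorded on gpu_stream
e.wait(cpu_stream);      // CPU-side wait blocks on the managed cuda::atomic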

mlx/backend/cuda/event.h Normal file
View File

@ -0,0 +1,66 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/stream.h"
#include <cuda_runtime.h>
#include <cuda/atomic>
#include <memory>
namespace mlx::core::cu {
class CudaEventHandle;
// Wrapper of a native cuda event. It can synchronize between GPU streams, or let
// a CPU stream wait on GPU work, but it cannot wait on a CPU stream.
class CudaEvent {
public:
CudaEvent();
void wait();
void wait(cudaStream_t stream);
void wait(Stream s);
void record(cudaStream_t stream);
void record(Stream s);
// Return whether the recorded kernels have completed. Note that this method
// returns true if record() has not been called.
bool completed() const;
bool recorded() const {
return recorded_;
}
private:
bool recorded_{false};
std::shared_ptr<CudaEventHandle> event_;
};
// Event that can synchronize between CPU and GPU. It is much slower than
// CudaEvent so the latter should always be preferred when possible.
class SharedEvent {
public:
using Atomic = cuda::atomic<uint64_t>;
SharedEvent();
void wait(uint64_t value);
void wait(cudaStream_t stream, uint64_t value);
void wait(Stream s, uint64_t value);
void signal(uint64_t value);
void signal(cudaStream_t stream, uint64_t value);
void signal(Stream s, uint64_t value);
bool is_signaled(uint64_t value) const;
uint64_t value() const;
const std::shared_ptr<Atomic>& atomic() const {
return ac_;
}
private:
std::shared_ptr<Atomic> ac_;
};
} // namespace mlx::core::cu

mlx/backend/cuda/fence.cu Normal file
View File

@ -0,0 +1,70 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/fence.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace {
__host__ __device__ void busy_wait(cuda::atomic<uint64_t>* ac, uint64_t value) {
while (true) {
// In theory the atomic_thread_fence is not needed, but on CUDA 11 the
// load() may never observe the new value without it.
cuda::atomic_thread_fence(cuda::memory_order_seq_cst);
uint64_t current = ac->load();
if (current >= value) {
break;
}
}
}
__global__ void busy_wait_kernel(cuda::atomic<uint64_t>* ac, uint64_t value) {
busy_wait(ac, value);
}
} // namespace
struct FenceImpl {
uint32_t count;
cu::SharedEvent event;
};
Fence::Fence(Stream s) {
fence_ = std::shared_ptr<void>(
new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
}
void Fence::wait(Stream s, const array&) {
auto* fence = static_cast<FenceImpl*>(fence_.get());
// We can't use SharedEvent::wait because it could hang in CUDA 11, see also:
// https://github.com/ml-explore/mlx/issues/2137
const auto& ac = fence->event.atomic();
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [ac, count = fence->count]() {
nvtx3::scoped_range r("Fence::wait()");
busy_wait(ac.get(), count);
});
} else {
nvtx3::scoped_range r("Fence::wait(s)");
auto& encoder = cu::get_command_encoder(s);
encoder.launch_kernel(
encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) {
busy_wait_kernel<<<1, 1, 0, stream>>>(ac.get(), fence->count);
});
encoder.add_completed_handler([ac]() {});
encoder.end_encoding();
}
}
void Fence::update(Stream s, const array&) {
auto* fence = static_cast<FenceImpl*>(fence_.get());
fence->count++;
fence->event.signal(s, fence->count);
}
} // namespace mlx::core
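
A hedged sketch of the protocol the Fence above implements; the producer/consumer streams and the array are placeholders.

// Illustrative only: each update() bumps the count and signals it; each wait()
// spins (on the GPU via a kernel, or on a CPU scheduler task) until the
// shared atomic reaches that count.
mlx::core::Fence f(producer_stream);
f.update(producer_stream, out);   // count becomes 1, signaled on the GPU stream
f.wait(consumer_stream, out);     // returns once the atomic reaches 1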

View File

@ -0,0 +1,15 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::cu {
template <typename T>
struct Arange {
const T start;
const T step;
__device__ T operator()(uint32_t i) const {
return start + i * step;
}
};
} // namespace mlx::core::cu

View File

@ -0,0 +1,76 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda/std/limits>
#include <cuda/std/type_traits>
namespace mlx::core::cu {
///////////////////////////////////////////////////////////////////////////////
// Additional C++ operator overloads between half types and native types.
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U>
constexpr bool is_integral_except =
cuda::std::is_integral_v<T> && !cuda::std::is_same_v<T, U>;
template <typename T, typename U>
constexpr bool is_arithmetic_except =
cuda::std::is_arithmetic_v<T> && !cuda::std::is_same_v<T, U>;
#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \
template < \
typename T, \
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
__forceinline__ __device__ HALF operator OP(HALF x, T y) { \
return FLOAT2HALF(HALF2FLOAT(x) OP static_cast<float>(y)); \
} \
template < \
typename T, \
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
__forceinline__ __device__ HALF operator OP(T x, HALF y) { \
return FLOAT2HALF(static_cast<float>(x) OP HALF2FLOAT(y)); \
}
#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \
template < \
typename T, \
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
__forceinline__ __device__ bool operator OP(HALF x, T y) { \
return HALF2FLOAT(x) OP static_cast<float>(y); \
} \
template < \
typename T, \
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
__forceinline__ __device__ bool operator OP(T x, HALF y) { \
    return static_cast<float>(x) OP HALF2FLOAT(y); \
}
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +)
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -)
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *)
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /)
MLX_DEFINE_HALF_CMP(__half, __half2float, <)
MLX_DEFINE_HALF_CMP(__half, __half2float, >)
MLX_DEFINE_HALF_CMP(__half, __half2float, <=)
MLX_DEFINE_HALF_CMP(__half, __half2float, >=)
MLX_DEFINE_HALF_CMP(__half, __half2float, ==)
MLX_DEFINE_HALF_CMP(__half, __half2float, !=)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=)
#undef MLX_DEFINE_HALF_OP
#undef MLX_DEFINE_HALF_CMP
} // namespace mlx::core::cu
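
A small device-code sketch of what the overloads above enable; scale_and_compare is illustrative and not part of the diff.

// Without these overloads, mixing __half with plain int/float operands would
// need explicit conversions on every use.
__device__ bool scale_and_compare(__half x, int n) {
  __half y = x * n;   // resolved by MLX_DEFINE_HALF_OP(__half, ..., *)
  return y >= 0.5f;   // resolved by MLX_DEFINE_HALF_CMP(__half, ..., >=)
}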

View File

@ -0,0 +1,164 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/dtype_utils.cuh"
#include "mlx/backend/cuda/kernels/arange.cuh"
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/distributed/primitives.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <cassert>
namespace mlx::core {
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Arange::eval_gpu");
assert(inputs.size() == 0);
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
encoder.set_output_array(out);
encoder.launch_kernel([&, this](cudaStream_t stream) {
MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, {
using OutType = cuda_type_t<CTYPE>;
CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
thrust::transform(
cu::thrust_policy(stream),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(out.data_size()),
thrust::device_pointer_cast(out.data<OutType>()),
cu::Arange<OutType>{
static_cast<OutType>(start_), static_cast<OutType>(step)});
});
});
}
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
#define NO_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
NO_GPU(Abs)
NO_GPU(Add)
NO_GPU(AddMM)
NO_GPU(ArcCos)
NO_GPU(ArcCosh)
NO_GPU(ArcSin)
NO_GPU(ArcSinh)
NO_GPU(ArcTan)
NO_GPU(ArcTan2)
NO_GPU(ArcTanh)
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(ArgSort)
NO_GPU(BitwiseBinary)
NO_GPU(BitwiseInvert)
NO_GPU(BlockMaskedMM)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
NO_GPU(Conjugate)
NO_GPU(Convolution)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(Exp)
NO_GPU(Expm1)
NO_GPU(FFT)
NO_GPU(Floor)
NO_GPU(Gather)
NO_GPU(GatherAxis)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Greater)
NO_GPU(GreaterEqual)
NO_GPU(Hadamard)
NO_GPU(Imag)
NO_GPU(Less)
NO_GPU(LessEqual)
NO_GPU(Load)
NO_GPU(Log)
NO_GPU(Log1p)
NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU(LogSumExp)
NO_GPU_MULTI(LUF)
NO_GPU(Matmul)
NO_GPU(Maximum)
NO_GPU(Minimum)
NO_GPU(Multiply)
NO_GPU(Negative)
NO_GPU(NotEqual)
NO_GPU(Partition)
NO_GPU(Power)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(RandomBits)
NO_GPU(Real)
NO_GPU(Reduce)
NO_GPU(Round)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(ScatterAxis)
NO_GPU(Select)
NO_GPU(Sigmoid)
NO_GPU(Sign)
NO_GPU(Sin)
NO_GPU(Sinh)
NO_GPU(SliceUpdate)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU(Square)
NO_GPU(Sqrt)
NO_GPU(Subtract)
NO_GPU_MULTI(SVD)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU_MULTI(LayerNorm)
NO_GPU_MULTI(LayerNormVJP)
NO_GPU_MULTI(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
} // namespace fast
namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed
} // namespace mlx::core

View File

@ -0,0 +1,15 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/gpu/slicing.h"
namespace mlx::core {
void concatenate_gpu(
const std::vector<array>& inputs,
array& out,
int axis,
const Stream& s) {
throw std::runtime_error("concatenate_gpu not implemented in CUDA backend.");
}
} // namespace mlx::core

View File

@ -0,0 +1,26 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/device.h"
#include <fmt/format.h>
namespace mlx::core {
CudaStream::CudaStream(cu::Device& device) {
device.make_current();
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
}
CudaStream::~CudaStream() {
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
}
void check_cuda_error(const char* name, cudaError_t err) {
if (err != cudaSuccess) {
throw std::runtime_error(
fmt::format("{} failed: {}", name, cudaGetErrorString(err)));
}
}
} // namespace mlx::core

mlx/backend/cuda/utils.h Normal file
View File

@ -0,0 +1,36 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuda_runtime.h>
namespace mlx::core {
namespace cu {
class Device;
}
// Cuda stream managed with RAII.
class CudaStream {
public:
explicit CudaStream(cu::Device& device);
~CudaStream();
CudaStream(const CudaStream&) = delete;
CudaStream& operator=(const CudaStream&) = delete;
operator cudaStream_t() const {
return stream_;
}
private:
cudaStream_t stream_;
};
// Throw exception if the cuda API does not succeed.
void check_cuda_error(const char* name, cudaError_t err);
// The macro version that prints the command that failed.
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
} // namespace mlx::core

View File

@ -0,0 +1,90 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/worker.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
Worker::Worker()
: signal_stream_(device(mlx::core::Device::gpu)),
worker_(&Worker::thread_fn, this) {}
Worker::~Worker() {
{
std::lock_guard lock(worker_mutex_);
stop_ = true;
}
worker_event_.signal(batch_ + 1);
worker_.join();
}
void Worker::add_task(std::function<void()> task) {
pending_tasks_.push_back(std::move(task));
}
void Worker::consume_in_this_thread() {
for (auto& task : pending_tasks_) {
task();
}
pending_tasks_.clear();
}
void Worker::end_batch() {
batch_++;
{
std::lock_guard lock(worker_mutex_);
worker_tasks_[batch_] = std::move(pending_tasks_);
}
uncommited_batches_++;
}
void Worker::commit() {
if (uncommited_batches_ == 0) {
return;
}
uncommited_batches_ = 0;
worker_event_.signal(batch_);
}
void Worker::commit(cudaStream_t stream) {
if (uncommited_batches_ == 0) {
return;
}
uncommited_batches_ = 0;
// Signal the |worker_event_| in |signal_stream_| after the kernels in
// |stream| finish running.
signal_event_.record(stream);
signal_event_.wait(signal_stream_);
worker_event_.signal(signal_stream_, batch_);
}
void Worker::thread_fn() {
// It is safe for the worker thread to free buffers.
allocator().register_this_thread();
while (!stop_) {
uint64_t batch = worker_event_.value();
Tasks tasks;
{
std::lock_guard lock(worker_mutex_);
// Move tasks in signaled batches.
auto end = worker_tasks_.upper_bound(batch);
for (auto it = worker_tasks_.begin(); it != end; ++it) {
if (tasks.empty()) {
tasks = std::move(it->second);
} else {
std::move(
it->second.begin(), it->second.end(), std::back_inserter(tasks));
}
}
worker_tasks_.erase(worker_tasks_.begin(), end);
}
for (auto& task : tasks) {
task();
}
worker_event_.wait(batch + 1);
}
}
} // namespace mlx::core::cu

mlx/backend/cuda/worker.h Normal file
View File

@ -0,0 +1,68 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
#include <functional>
#include <map>
#include <mutex>
#include <thread>
namespace mlx::core::cu {
// Run tasks in worker thread, synchronized with cuda stream.
class Worker {
public:
Worker();
~Worker();
Worker(const Worker&) = delete;
Worker& operator=(const Worker&) = delete;
// Add a pending |task| that will run when consumed or committed.
void add_task(std::function<void()> task);
// Run pending tasks immediately in current thread.
void consume_in_this_thread();
// Put pending tasks in a batch.
void end_batch();
// Inform worker thread to run current batches now.
void commit();
// Inform worker thread to run current batches after kernels in |stream|
// finish running.
void commit(cudaStream_t stream);
// Return how many batches have been added but not committed yet.
size_t uncommited_batches() const {
return uncommited_batches_;
}
private:
void thread_fn();
uint64_t batch_{0};
size_t uncommited_batches_{0};
// Cuda stream and event for signaling kernel completion.
CudaStream signal_stream_;
CudaEvent signal_event_;
// Worker thread.
SharedEvent worker_event_;
std::thread worker_;
std::mutex worker_mutex_;
bool stop_{false};
// Tasks are put in |pending_tasks_| first, and then moved to
// |worker_tasks_| when end_batch() is called.
using Tasks = std::vector<std::function<void()>>;
Tasks pending_tasks_;
std::map<uint64_t, Tasks> worker_tasks_;
};
} // namespace mlx::core::cu
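
A hedged sketch of the batching protocol the comments above describe; `stream` stands for whatever cudaStream_t the encoder is currently recording on.

// Illustrative only: tasks become visible to the worker thread batch by batch.
mlx::core::cu::Worker w;
w.add_task([] { /* e.g. release a completion handler's buffers */ });
w.add_task([] { /* another host-side cleanup */ });
w.end_batch();     // pending tasks move into batch #1
w.commit(stream);  // batch #1 runs only after kernels on `stream` finish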

View File

@ -0,0 +1,5 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)

mlx/backend/gpu/copy.cpp Normal file
View File

@ -0,0 +1,49 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h"
#include <cassert>
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
bool donated = set_copy_output_data(in, out, ctype);
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}
void copy_gpu_inplace(
const array& in,
array& out,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
const Strides& i_strides,
int64_t i_offset,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
}
} // namespace mlx::core

View File

@ -5,6 +5,8 @@
#include "mlx/backend/common/copy.h"
#include "mlx/stream.h"
#include <optional>
namespace mlx::core {
// Generic copy inplace

View File

@ -0,0 +1,217 @@
// Copyright © 2025 Apple Inc.
#include "mlx/primitives.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include <cassert>
#define MLX_PROFILER_RANGE(message)
namespace mlx::core {
namespace {
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
} // namespace
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("AsStrided::eval_gpu");
eval(inputs, out);
}
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("AsType::eval_gpu");
CopyType ctype =
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
copy_gpu(inputs[0], out, ctype);
}
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Broadcast::eval_gpu");
eval(inputs, out);
}
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("BroadcastAxes::eval_gpu");
eval(inputs, out);
}
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Concatenate::eval_gpu");
concatenate_gpu(inputs, out, axis_, stream());
}
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Contiguous::eval_gpu");
assert(inputs.size() == 1);
auto& in = inputs[0];
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy_gpu(in, out, CopyType::General);
}
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Copy::eval_gpu");
eval(inputs, out);
}
void CustomTransforms::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
MLX_PROFILER_RANGE("CustomTransforms::eval_gpu");
eval(inputs, outputs);
}
void Depends::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
MLX_PROFILER_RANGE("Depends::eval_gpu");
eval(inputs, outputs);
}
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
eval(inputs, out);
}
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Full::eval_gpu");
auto in = inputs[0];
CopyType ctype;
if (in.data_size() == 1) {
ctype = CopyType::Scalar;
} else if (in.flags().contiguous) {
ctype = CopyType::Vector;
} else {
ctype = CopyType::General;
}
copy_gpu(in, out, ctype);
}
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Flatten::eval_gpu");
reshape(inputs[0], out, stream());
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("NumberOfElements::eval_gpu");
eval(inputs, out);
}
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];
auto& val = inputs[1];
// Padding value must be a scalar
assert(val.size() == 1);
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
pad_gpu(in, val, out, axes_, low_pad_size_, stream());
}
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Reshape::eval_gpu");
reshape(inputs[0], out, stream());
}
void Split::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
MLX_PROFILER_RANGE("Split::eval_gpu");
eval(inputs, outputs);
}
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Slice::eval_gpu");
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
slice_gpu(in, out, start_indices_, strides_, stream());
}
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Squeeze::eval_gpu");
eval(inputs, out);
}
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("StopGradient::eval_gpu");
eval(inputs, out);
}
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Transpose::eval_gpu");
eval(inputs, out);
}
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Unflatten::eval_gpu");
reshape(inputs[0], out, stream());
}
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("View::eval_gpu");
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for buffer copying (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc(tmp.nbytes()));
copy_gpu_inplace(in, tmp, CopyType::General, stream());
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core
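
A worked example of the stride rescaling in View::eval_gpu above, with illustrative numbers.

// Viewing a row-contiguous float32 array of shape (4, 8) as int16:
//   ibytes = 4, obytes = 2, strides {8, 1} -> {16, 1} (last axis untouched),
//   and the shared data_size scales from 32 to 32 * 4 / 2 = 64 elements,
// so no copy is needed and the buffer is reinterpreted in place.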

View File

@ -0,0 +1,44 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
namespace mlx::core {
void slice_gpu(
const array& in,
array& out,
const Shape& start_indices,
const Shape& strides,
const Stream& s) {
slice(in, out, start_indices, strides);
}
void pad_gpu(
const array& in,
const array& val,
array& out,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Stream& s) {
// Fill output with val
fill_gpu(val, out, s);
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes.size(); i++) {
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
data_offset += out.strides()[ax] * low_pad_size[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
}
} // namespace mlx::core
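
A worked example of the pad offset computed above, with illustrative shapes.

// Padding a (2, 3) input into a (4, 7) row-contiguous output with
// low_pad_size = {1, 2} and output strides {7, 1}:
//   data_offset = 1 * 7 + 2 * 1 = 9,
// so the input slice starts at row 1, column 2 of the val-filled output.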

View File

@ -90,7 +90,7 @@ void binary_op_gpu_inplace(
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > UINT32_MAX;
work_per_thread = 1;
work_per_thread = get_work_per_thread(a.dtype());
}
std::string kernel_name =
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
@ -137,13 +137,20 @@ void binary_op_gpu_inplace(
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(out.data_size(), arg_idx++);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(out.data_size(), arg_idx++);
grid_dims = MTL::Size(nthreads, 1, 1);
}
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}

View File

@ -64,6 +64,7 @@ inline void build_kernel(
cnt++);
}
std::string idx_type = use_big_index ? "int64_t" : "uint";
if (add_indices) {
os += fmt::format(
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
@ -83,6 +84,9 @@ inline void build_kernel(
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
os += fmt::format(
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
} else {
os += fmt::format(
" constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++);
}
if (dynamic_dims) {
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
@ -92,13 +96,14 @@ inline void build_kernel(
os += " uint3 pos [[thread_position_in_grid]],\n";
os += " uint3 grid [[threads_per_grid]]) {\n";
std::string idx_type = use_big_index ? "int64_t" : "uint";
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
if (contiguous && use_big_index) {
// This is only used for contiguous kernels which don't have
// a third grid dimension
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n";
} else if (contiguous) {
os += " uint index = N_ * pos.x;\n";
} else if (work_per_thread > 1) {
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
os += fmt::format(
" int xshape = output_shape[{0}];\n",
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
@ -110,6 +115,9 @@ inline void build_kernel(
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
}
if (work_per_thread > 1 && contiguous) {
os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
}
// Read constant / contiguous inputs in tmps
std::vector<array> nc_inputs;
@ -193,7 +201,7 @@ inline void build_kernel(
}
// Open per-thread loop
if (work_per_thread > 1) {
if (work_per_thread > 1 && !contiguous) {
os +=
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
}
@ -272,6 +280,7 @@ void Compiled::eval_gpu(
auto& s = stream();
auto& d = metal::device(s.device);
auto lib = d.get_library(kernel_lib_, [&]() {
int work_per_thread = get_work_per_thread(outputs_[0].dtype());
std::string kernel = metal::utils();
concatenate(
kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
@ -284,7 +293,9 @@ void Compiled::eval_gpu(
constant_ids_,
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false);
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ work_per_thread);
build_kernel(
kernel,
kernel_lib_ + "_contiguous_large",
@ -295,7 +306,8 @@ void Compiled::eval_gpu(
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ true);
/* use_big_index = */ true,
/* work_per_thread = */ work_per_thread);
for (int i = 1; i < 8; i++) {
build_kernel(
kernel,
@ -468,6 +480,13 @@ void Compiled::eval_gpu(
if (!contiguous) {
compute_encoder.set_vector_bytes(strides[0], cnt++);
compute_encoder.set_vector_bytes(shape, cnt++);
} else {
auto size = outputs[0].data_size();
if (large) {
compute_encoder.set_bytes<int64_t>(size, cnt++);
} else {
compute_encoder.set_bytes<int>(size, cnt++);
}
}
// Put the number of dims in if it is dynamic
@ -477,12 +496,13 @@ void Compiled::eval_gpu(
// Launch the kernel
if (contiguous) {
size_t nthreads = outputs[0].data_size();
int work_per_thread = get_work_per_thread(outputs[0].dtype());
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
MTL::Size grid_dims = large
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
? get_2d_grid_dims(
outputs[0].shape(), outputs[0].strides(), work_per_thread)
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {

View File

@ -1,11 +1,10 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
@ -178,83 +177,6 @@ void explicit_gemm_conv_group_ND_gpu(
/*copies = */ copies);
}
void conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
int groups,
bool flip) {
// Make conv params
MLXConvParams<1> conv_params{
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ static_cast<int>(in.shape(2)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
/* const int str[NDIM] = */ {wt_strides[0]},
/* const int pad[NDIM] = */ {padding[0]},
/* const int kdil[NDIM] = */ {wt_dilation[0]},
/* const int idil[NDIM] = */ {in_dilation[0]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2]},
/* const int groups = */ groups,
/* const bool flip = */ flip};
// Direct to explicit gemm conv
if (groups > 1) {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
}
void slow_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
int bm = 16, bn = 8;
int tm = 4, tn = 4;
std::ostringstream kname;
kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn
<< "_tm" << tm << "_tn" << tn;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm);
size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn);
size_t grid_dim_z = conv_params.N;
MTL::Size group_dims = MTL::Size(bm, bn, 1);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(conv_params, 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_gpu(
const Stream& s,
metal::Device& d,
@ -771,6 +693,141 @@ void depthwise_conv_2D_gpu(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void dispatch_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params,
std::vector<array>& copies) {
bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
if (is_idil_one && conv_params.groups > 1) {
const int C_per_group = conv_params.C / conv_params.groups;
const int O_per_group = conv_params.O / conv_params.groups;
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
conv_params.wt_strides[1] == conv_params.wS[1] &&
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
}
}
// Direct to winograd conv
bool inp_large =
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
bool channels_large = (conv_params.C + conv_params.O) >= 256;
if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
channels_large) {
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
// Direct to implicit gemm conv
if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
(conv_params.O <= 16 || conv_params.O % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
}
// Direct to explicit gemm conv
else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
}
void conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
int groups,
bool flip,
std::vector<array>& copies) {
bool is_idil_one = in_dilation[0] == 1;
int C = in.shape(2);
int O = wt.shape(0);
const int C_per_group = in.shape(2) / groups;
const int O_per_group = wt.shape(0) / groups;
// Direct to implicit gemm conv
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
MLXConvParams<2> conv_params{
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ C,
/* const int O = */ O,
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1)), 1},
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1)), 1},
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1)), 1},
/* const int str[NDIM] = */ {wt_strides[0], 1},
/* const int pad[NDIM] = */ {padding[0], 0},
/* const int kdil[NDIM] = */ {wt_dilation[0], 1},
/* const int idil[NDIM] = */ {in_dilation[0], 1},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], 0, in.strides()[2]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], 0, wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], 0, out.strides()[2]},
/* const int groups = */ groups,
/* const bool flip = */ flip};
dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
return;
}
// Make conv params
MLXConvParams<1> conv_params{
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ static_cast<int>(in.shape(2)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
/* const int str[NDIM] = */ {wt_strides[0]},
/* const int pad[NDIM] = */ {padding[0]},
/* const int kdil[NDIM] = */ {wt_dilation[0]},
/* const int idil[NDIM] = */ {in_dilation[0]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2]},
/* const int groups = */ groups,
/* const bool flip = */ flip};
// Direct to explicit gemm conv
if (groups > 1) {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
}
void conv_2D_gpu(
const Stream& s,
metal::Device& d,
@ -808,57 +865,7 @@ void conv_2D_gpu(
/* const int groups = */ groups,
/* const bool flip = */ flip,
};
bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
if (is_idil_one && groups > 1) {
const int C_per_group = conv_params.C / groups;
const int O_per_group = conv_params.O / groups;
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
conv_params.wt_strides[1] == conv_params.wS[1] &&
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
}
}
// Direct to winograd conv
bool inp_large =
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
bool channels_large = (conv_params.C + conv_params.O) >= 256;
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
channels_large) {
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
// Direct to implicit gemm conv
if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
(conv_params.O <= 16 || conv_params.O % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
}
// Direct to explicit gemm conv
else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
void conv_3D_gpu(
@ -952,7 +959,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_,
padding_lo_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
@ -967,7 +974,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_,
padding_lo_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
@ -983,12 +990,13 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
in,
wt,
out,
padding_,
padding_lo_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
groups_,
flip_);
flip_,
copies);
}
// Throw error
else {

View File

@ -1,35 +1,15 @@
// Copyright © 2023-2024 Apple Inc.
#include <sstream>
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
bool donated = set_copy_output_data(in, out, ctype);
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}
void copy_gpu_inplace(
const array& in,
array& out,
@ -104,6 +84,8 @@ void copy_gpu_inplace(
"[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy");
}
}
} else {
work_per_thread = get_work_per_thread(in.dtype());
}
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
@ -165,39 +147,23 @@ void copy_gpu_inplace(
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
size_t nthreads = out.data_size();
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(out.data_size(), 2);
grid_dims = MTL::Size(nthreads, 1, 1);
}
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}
void copy_gpu_inplace(
const array& in,
array& out,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
const Strides& i_strides,
int64_t i_offset,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
}
void fill_gpu(const array& val, array& out, const Stream& s) {
if (out.size() == 0) {
return;
@ -214,14 +180,21 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
compute_encoder.set_input_array(val, 0);
compute_encoder.set_output_array(out, 1);
int work_per_thread = get_work_per_thread(val.dtype());
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
size_t nthreads = out.data_size();
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(out.data_size(), 2);
grid_dims = MTL::Size(nthreads, 1, 1);
}
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

View File

@ -1,6 +1,6 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"

View File

@ -95,6 +95,10 @@ struct CommandEncoder {
return enc_->setBytes(&v, sizeof(T), idx);
}
void set_threadgroup_memory_length(size_t length, int idx) {
enc_->setThreadgroupMemoryLength(length, idx);
}
ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
}

View File

@ -4,7 +4,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/distributed/ops.h"

View File

@ -141,7 +141,7 @@ void Fence::update(Stream stream, const array& x) {
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_bytes(nthreads, 1);
compute_encoder.dispatch_threadgroups(group_dims, grid_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
// Barrier on previous kernels
compute_encoder.barrier();

View File

@ -7,10 +7,10 @@
#include "mlx/3rdparty/pocketfft.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include "mlx/backend/metal/binary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/unary.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"
@ -632,7 +632,7 @@ void fft_op(
func_consts.push_back(make_int(&rader_m, 3));
// The overall number of FFTs we're going to compute for this input
int size = out.dtype() == float32 ? out.size() : in.size();
size_t size = out.dtype() == float32 ? out.size() : in.size();
if (real && inverse && four_step_params.required) {
size = out.size();
}
@ -659,8 +659,6 @@ void fft_op(
// We can perform 2 RFFTs at once so the batch size is halved.
batch_size = (batch_size + 2 - 1) / 2;
}
int out_buffer_size = out.size();
auto& compute_encoder = d.get_command_encoder(s.index);
auto in_type_str = in.dtype() == float32 ? "float" : "float2";
auto out_type_str = out.dtype() == float32 ? "float" : "float2";

View File

@ -1,11 +1,9 @@
// Copyright © 2024 Apple Inc.
#include <map>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/kernels.h"
@ -15,7 +13,6 @@
namespace mlx::core {
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
std::string gen_hadamard_codelet(int m) {
// Generate an O(m^2) Hadamard codelet for a given m
@ -60,121 +57,142 @@ std::string gen_hadamard_codelet(int m) {
return source.str();
}
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
void hadamard_mn_contiguous(
const array& x,
array& y,
int m,
int n1,
int n2,
float scale,
metal::Device& d,
const Stream& s) {
int n = n1 * n2;
int read_width_n1 = n1 == 2 ? 2 : 4;
int read_width_n2 = n2 == 2 ? 2 : 4;
int read_width_m = (n == 2 || m == 28) ? 2 : 4;
int max_radix_1 = std::min(n1, 16);
int max_radix_2 = std::min(n2, 16);
float scale_n1 = 1.0;
float scale_n2 = (m == 1) ? scale : 1.0;
float scale_m = scale;
auto& in = inputs[0];
// n2 is a row contiguous power of 2 hadamard transform
MTL::Size group_dims_n2(n2 / max_radix_2, 1, 1);
MTL::Size grid_dims_n2(n2 / max_radix_2, x.size() / n2, 1);
std::vector<array> copies;
// Only support the last axis for now
int axis = in.ndim() - 1;
auto check_input = [&copies, &s](const array& x) {
// TODO(alexbarron) pass strides to kernel to relax this constraint
bool no_copy = x.flags().row_contiguous;
if (no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
// n1 is a strided power of 2 hadamard transform with stride n2
MTL::Size group_dims_n1(n1 / max_radix_1, 1, 1);
MTL::Size grid_dims_n1(n1 / max_radix_1, x.size() / n, n2);
// m is a strided hadamard transform with stride n = n1 * n2
MTL::Size group_dims_m(
std::min(n / read_width_m, MAX_HADAMARD_THREADS_PER_GROUP), 1, 1);
MTL::Size grid_dims_m(
group_dims_m.width, x.size() / m / read_width_m / group_dims_m.width, 1);
// Make the kernel
std::string kname;
kname.reserve(32);
concatenate(kname, "hadamard_", n * m, "_", type_to_name(x));
auto lib = d.get_library(kname, [&]() {
std::string kernel;
concatenate(
kernel,
metal::utils(),
gen_hadamard_codelet(m),
metal::hadamard(),
get_template_definition(
"n2" + kname,
"hadamard_n",
get_type_string(x.dtype()),
n2,
max_radix_2,
read_width_n2));
if (n1 > 1) {
kernel += get_template_definition(
"n1" + kname,
"hadamard_n",
get_type_string(x.dtype()),
n1,
max_radix_1,
read_width_n1,
n2);
}
};
const array& in_contiguous = check_input(in);
if (in_contiguous.is_donatable()) {
out.copy_shared_buffer(in_contiguous);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
int n, m;
std::tie(n, m) = decompose_hadamard(in.shape(axis));
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
throw std::invalid_argument(
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
}
int max_radix = std::min(n, 16);
// Use read_width 2 for m = 28 to avoid register spilling
int read_width = (n == 2 || m == 28) ? 2 : 4;
std::ostringstream kname;
kname << "hadamard_" << n * m << "_" << type_to_name(out);
auto kernel_name = kname.str();
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto codelet = gen_hadamard_codelet(m);
kernel_source << metal::utils() << codelet << metal::hadamard();
kernel_source << get_template_definition(
"n" + kernel_name,
"hadamard_n",
get_type_string(in.dtype()),
n,
max_radix,
read_width);
kernel_source << get_template_definition(
"m" + kernel_name,
"hadamard_m",
get_type_string(in.dtype()),
n,
m,
read_width);
return kernel_source.str();
if (m > 1) {
kernel += get_template_definition(
"m" + kname,
"hadamard_m",
get_type_string(x.dtype()),
n,
m,
read_width_m);
}
return kernel;
});
int batch_size = in.size() / n;
int threads_per = n / max_radix;
auto& compute_encoder = d.get_command_encoder(s.index);
auto launch_hadamard = [&](const array& in,
array& out,
const std::string& kernel_name,
float scale) {
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
// Launch the strided transform for n1
if (n1 > 1) {
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel("n1" + kname, lib);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(scale, 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
};
if (m > 1) {
// When m is greater than 1, we decompose the
// computation into two uploads to the GPU:
//
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
//
// y = h48 @ x
//
// Upload 1:
// tmp = a.reshape(12, 4) @ h4
//
// Upload 2:
// y = h12 @ tmp
array temp(in.shape(), in.dtype(), nullptr, {});
temp.set_data(allocator::malloc(temp.nbytes()));
copies.push_back(temp);
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
batch_size = in.size() / m / read_width / threads_per;
launch_hadamard(temp, out, "m" + kernel_name, scale_);
} else {
launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_output_array(y, 1);
compute_encoder.set_bytes(scale_n1, 2);
compute_encoder.dispatch_threads(grid_dims_n1, group_dims_n1);
}
d.add_temporaries(std::move(copies), s.index);
// Launch the transform for n2
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel("n2" + kname, lib);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(n1 > 1 ? y : x, 0);
compute_encoder.set_output_array(y, 1);
compute_encoder.set_bytes(scale_n2, 2);
compute_encoder.dispatch_threads(grid_dims_n2, group_dims_n2);
// Launch the strided transform for m
if (m > 1) {
auto kernel = d.get_kernel("m" + kname, lib);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(y, 0);
compute_encoder.set_output_array(y, 1);
compute_encoder.set_bytes(scale_m, 2);
compute_encoder.dispatch_threads(grid_dims_m, group_dims_m);
}
}
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
// Split the hadamard transform so that each sub-transform works on vectors
// smaller than 8192 elements.
//
// We decompose it in the following way:
//
// n = m * n1 * n2 = m * 2^k1 * 2^k2
//
// where m is in (1, 12, 20, 28) and n1 and n2 <= 8192
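//
// For example, a last axis of 16384 = 2^14 gives m = 1 and n1 = n2 = 128, so
// every pass stays well under the 8192-element limit.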
auto [n, m] = decompose_hadamard(in.shape().back());
int n1 = 1, n2 = n;
if (n > 8192) {
for (n2 = 2; n2 * n2 < n; n2 *= 2) {
}
n1 = n / n2;
}
if (in.flags().row_contiguous) {
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
hadamard_mn_contiguous(in, out, m, n1, n2, scale_, d, s);
} else {
copy_gpu(in, out, CopyType::General, s);
hadamard_mn_contiguous(out, out, m, n1, n2, scale_, d, s);
}
}
} // namespace mlx::core

View File

@ -2,7 +2,8 @@
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/indexing.h"
@ -458,17 +459,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 2);
// Set source info
auto shape = idx.shape();
shape.erase(shape.begin() + axis_);
compute_encoder.set_vector_bytes(shape, 3);
auto strides = src.strides();
strides.erase(strides.begin() + axis_);
compute_encoder.set_vector_bytes(strides, 4);
strides = idx.strides();
strides.erase(strides.begin() + axis_);
compute_encoder.set_vector_bytes(strides, 5);
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
compute_encoder.set_vector_bytes(remove_index(src.strides(), axis_), 4);
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
compute_encoder.set_bytes(ndim - 1, 6);
compute_encoder.set_bytes(axis_, 7);
compute_encoder.set_bytes(src.shape(axis_), 8);
@ -582,17 +575,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 2);
// Set source info
auto shape = idx.shape();
shape.erase(shape.begin() + axis_);
compute_encoder.set_vector_bytes(shape, 3);
auto strides = upd.strides();
strides.erase(strides.begin() + axis_);
compute_encoder.set_vector_bytes(strides, 4);
strides = idx.strides();
strides.erase(strides.begin() + axis_);
compute_encoder.set_vector_bytes(strides, 5);
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
compute_encoder.set_bytes(ndim - 1, 6);
compute_encoder.set_bytes(axis_, 7);
compute_encoder.set_bytes(out.shape(axis_), 8);

View File

@ -80,9 +80,10 @@ template <typename T, typename Op, int N_READS = 4>
const constant size_t& ndim [[buffer(5)]],
const constant int64_t& axis_stride [[buffer(6)]],
const constant size_t& axis_size [[buffer(7)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 gsize [[threads_per_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
@ -104,17 +105,18 @@ template <typename T, typename Op, int N_READS = 4>
// Compute the input/output index. There is one beginning and one output for
// the whole threadgroup.
auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);
int64_t row_idx = gid.y + static_cast<int64_t>(gsize.y) * gid.z;
auto in_idx = elem_to_loc(row_idx, shape, in_strides, ndim);
auto out_idx = elem_to_loc(row_idx, shape, out_strides, ndim);
IndexValPair<T> best{0, Op::init};
threadgroup IndexValPair<T> local_data[32];
// Loop over the reduction axis in lsize*N_READS buckets
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) {
// Read the current value
uint32_t current_index = r * lsize * N_READS + lid * N_READS;
uint32_t current_index = r * lsize.x * N_READS + lid.x * N_READS;
uint32_t offset = current_index;
const device T* current_in = in + in_idx + current_index * axis_stride;
T vals[N_READS];
@ -144,7 +146,7 @@ template <typename T, typename Op, int N_READS = 4>
}
// Read the appropriate value from local data and perform one simd reduction
uint simd_groups = ceildiv(lsize, simd_size);
uint simd_groups = ceildiv(lsize.x, simd_size);
if (simd_lane_id < simd_groups) {
best = local_data[simd_lane_id];
}
@ -154,7 +156,7 @@ template <typename T, typename Op, int N_READS = 4>
}
// Finally write the output
if (lid == 0) {
if (lid.x == 0) {
out[out_idx] = best.index;
}
}

View File

@ -9,64 +9,85 @@ template <typename T, typename U, typename Op>
c[index] = Op()(a[0], b[0]);
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_sv(
device const T* a,
device const T* b,
device U* c,
constant uint& size,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[index]);
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
c[index + i] = Op()(a[0], b[index + i]);
}
}
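// Each thread now handles N contiguous elements, with N = WorkPerThread<T>::n
// (8 bytes / sizeof(T)); the (index + i) < size check covers the ragged tail.
// The same vectorization is applied to the binary_two, copy, ternary, and unary
// kernels further below.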
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vs(
device const T* a,
device const T* b,
device U* c,
constant uint& size,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[0]);
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
c[index + i] = Op()(a[index + i], b[0]);
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vv(
device const T* a,
device const T* b,
device U* c,
constant uint& size,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[index]);
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
c[index + i] = Op()(a[index + i], b[index + i]);
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_sv2(
device const T* a,
device const T* b,
device U* c,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[0], b[offset]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset + i] = Op()(a[0], b[offset + i]);
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vs2(
device const T* a,
device const T* b,
device U* c,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[offset], b[0]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[0]);
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vv2(
device const T* a,
device const T* b,
device U* c,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[offset], b[offset]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[offset + i]);
}
}
template <typename T, typename U, typename Op, typename IdxT = int64_t>

View File

@ -12,82 +12,103 @@ template <typename T, typename U, typename Op>
d[index] = out[1];
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_sv(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant uint& size,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[0], b[index]);
c[index] = out[0];
d[index] = out[1];
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
auto out = Op()(a[0], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vs(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant uint& size,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[index], b[0]);
c[index] = out[0];
d[index] = out[1];
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
auto out = Op()(a[index + i], b[0]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vv(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant uint& size,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[index], b[index]);
c[index] = out[0];
d[index] = out[1];
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
auto out = Op()(a[index + i], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_sv2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[0], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
auto out = Op()(a[0], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vs2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[offset], b[0]);
c[offset] = out[0];
d[offset] = out[1];
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
auto out = Op()(a[offset + i], b[0]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vv2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[offset], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
auto out = Op()(a[offset + i], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
template <typename T, typename U, typename Op, typename IdxT = int64_t>

View File

@ -1,39 +1,53 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename U>
template <typename T, typename U, int N = WorkPerThread<T>::n>
[[kernel]] void copy_s(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant uint& size,
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[0]);
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
dst[index + i] = static_cast<U>(src[0]);
}
}
template <typename T, typename U>
template <typename T, typename U, int N = WorkPerThread<T>::n>
[[kernel]] void copy_v(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant uint& size,
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[index]);
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
dst[index + i] = static_cast<U>(src[index + i]);
}
}
template <typename T, typename U>
template <typename T, typename U, int N = WorkPerThread<T>::n>
[[kernel]] void copy_s2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
dst[offset] = static_cast<U>(src[0]);
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
dst[offset + i] = static_cast<U>(src[0]);
}
}
template <typename T, typename U>
template <typename T, typename U, int N = WorkPerThread<T>::n>
[[kernel]] void copy_v2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
dst[offset] = static_cast<U>(src[offset]);
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
dst[offset + i] = static_cast<U>(src[offset + i]);
}
}
template <typename T, typename U, typename IdxT = int64_t>

View File

@ -98,7 +98,7 @@ struct ReadWriter {
}
METAL_FUNC void load() const {
int batch_idx = elem.x * grid.y * n;
size_t batch_idx = size_t(elem.x * grid.y) * n;
short tg_idx = elem.y * grid.z + elem.z;
short max_index = grid.y * n - 2;
@ -121,7 +121,7 @@ struct ReadWriter {
}
METAL_FUNC void write() const {
int batch_idx = elem.x * grid.y * n;
size_t batch_idx = size_t(elem.x * grid.y) * n;
short tg_idx = elem.y * grid.z + elem.z;
short max_index = grid.y * n - 2;
@ -144,7 +144,7 @@ struct ReadWriter {
// Padded IO for Bluestein's algorithm
METAL_FUNC void load_padded(int length, const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length + elem.y * length;
size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
int fft_idx = elem.z;
int m = grid.z;
@ -161,7 +161,7 @@ struct ReadWriter {
}
METAL_FUNC void write_padded(int length, const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length + elem.y * length;
size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
int fft_idx = elem.z;
int m = grid.z;
float2 inv_factor = {1.0f / n, -1.0f / n};
@ -261,7 +261,7 @@ METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
template <>
METAL_FUNC void ReadWriter<float, float2>::load() const {
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
@ -283,7 +283,8 @@ template <>
METAL_FUNC void ReadWriter<float, float2>::write() const {
short n_over_2 = (n / 2) + 1;
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
size_t batch_idx =
size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
int grid_index = elem.x * grid.y + elem.y;
@ -317,7 +318,7 @@ template <>
METAL_FUNC void ReadWriter<float, float2>::load_padded(
int length,
const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
@ -345,8 +346,8 @@ METAL_FUNC void ReadWriter<float, float2>::write_padded(
int length,
const device float2* w_k) const {
int length_over_2 = (length / 2) + 1;
int batch_idx =
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
size_t batch_idx =
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
int grid_index = elem.x * grid.y + elem.y;
@ -397,7 +398,8 @@ METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
template <>
METAL_FUNC void ReadWriter<float2, float>::load() const {
short n_over_2 = (n / 2) + 1;
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
size_t batch_idx =
size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
@ -458,8 +460,8 @@ METAL_FUNC void ReadWriter<float2, float>::load_padded(
int n_over_2 = (n / 2) + 1;
int length_over_2 = (length / 2) + 1;
int batch_idx =
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
size_t batch_idx =
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
@ -503,7 +505,7 @@ template <>
METAL_FUNC void ReadWriter<float2, float>::write_padded(
int length,
const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
int grid_index = elem.x * grid.y + elem.y;

View File

@ -26,7 +26,7 @@ METAL_FUNC void radix_func(thread float* x) {
}
}
template <typename T, int N, int max_radix, int read_width>
template <typename T, int N, int max_radix, int read_width, int stride = 1>
[[kernel]] void hadamard_n(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
@ -46,18 +46,25 @@ template <typename T, int N, int max_radix, int read_width>
constexpr short logFinal = logN % logR;
constexpr short final_radix = 1 << (logFinal);
int batch_idx = elem.x * N;
short i = elem.y;
int batch_idx = elem.y * N * stride + elem.z;
short i = elem.x;
threadgroup T buf[N];
// Read values from device
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
if (stride == 1) {
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
buf[index + r] = in[batch_idx + index + r];
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
buf[index + r] = in[batch_idx + index + r];
}
}
} else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix; j++) {
buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride];
}
}
@ -113,12 +120,20 @@ template <typename T, int N, int max_radix, int read_width>
}
// Write values to device
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
if (stride == 1) {
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
out[batch_idx + index + r] = T(buf[index + r] * scale);
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
out[batch_idx + index + r] = T(buf[index + r] * scale);
}
}
} else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix; j++) {
out[batch_idx + (j * num_threads + i) * stride] =
buf[j * num_threads + i];
}
}
}

View File

@ -103,8 +103,8 @@ template <typename T, typename AccT = float, int N_READS = 4>
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
: Limits<AccT>::finite_min;
vals[i] =
(offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
}
}
prevmax = maxval;
@ -134,10 +134,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
if (lid == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
}

View File

@ -224,7 +224,7 @@ template <
if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
// Simple loop over non_row_reductions and reduce the row in the thread.
IdxT out_idx = tid.x + tsize.y * IdxT(tid.y);
IdxT out_idx = tid.x + tsize.x * IdxT(tid.y);
in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
for (uint r = 0; r < non_row_reductions; r++) {

View File

@ -56,9 +56,9 @@ template <typename T, int D, int V = D>
const int head_idx = tid.x;
const int q_seq_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
const int o_offset = tpg.x * q_seq_idx + head_idx;
const int o_offset = head_idx * tpg.y + q_seq_idx;
const int q_offset =
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
simd_lid * qk_per_thread;
@ -213,9 +213,9 @@ template <typename T, int D, int V = D>
const int block_idx = tid.z;
const int head_idx = tid.x;
const int q_seq_idx = tid.y;
const int o_offset = tpg.x * q_seq_idx + head_idx;
const int o_offset = head_idx * tpg.y + q_seq_idx;
const int q_offset =
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
const int kv_head_idx = head_idx / gqa_factor;
queries += q_offset * D + simd_lid * qk_per_thread;
@ -358,8 +358,8 @@ template <typename T, int D>
// Adjust positions
const int head_idx = tid.x;
const int q_seq_idx = tid.y;
const int n_heads = tpg.x;
const int q_offset = n_heads * q_seq_idx + head_idx;
const int q_offset = head_idx * tpg.y + q_seq_idx;
partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
sums += q_offset * blocks;
maxs += q_offset * blocks;

View File

@ -128,8 +128,8 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
: Limits<AccT>::finite_min;
vals[i] =
(offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
}
}
prevmax = maxval;

View File

@ -381,6 +381,7 @@ struct Conv2DWeightBlockLoader {
const constant MLXConvParams<2>* params;
int weight_hw;
int weight_step;
const int read_n;
const bool do_read;
@ -402,6 +403,7 @@ struct Conv2DWeightBlockLoader {
src(src_ + bi * src_ld + bj),
params(params_),
weight_hw(0),
weight_step(params->C / params->groups),
read_n(offsets.y + bi),
do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}
@ -435,15 +437,15 @@ struct Conv2DWeightBlockLoader {
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_hw < (params->wS[1] * params->wS[0])) {
src += params->wt_strides[2];
src += weight_step;
return;
}
weight_hw = 0;
src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2];
src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step;
}
};
} // namespace steel
} // namespace mlx
} // namespace mlx

View File

@ -272,7 +272,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
return;
}
const device T* curr_src = src + weight_hw * params->wt_strides[2];
const device T* curr_src = src + weight_hw * (params->C / params->groups);
if (BN != 8 || do_read) {
STEEL_PRAGMA_UNROLL
@ -316,4 +316,4 @@ struct Conv2DWeightBlockLoaderSmallChannels {
};
} // namespace steel
} // namespace mlx
} // namespace mlx

View File

@ -1,25 +1,32 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename Op>
template <typename T, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void ternary_v(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant uint& size,
uint index [[thread_position_in_grid]]) {
d[index] = Op()(a[index], b[index], c[index]);
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
}
}
template <typename T, typename Op>
template <typename T, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void ternary_v2(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
d[offset] = Op()(a[offset], b[offset], c[offset]);
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
}
}
template <typename T, typename Op, typename IdxT = int64_t>

View File

@ -1,21 +1,28 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void unary_v(
device const T* in,
device U* out,
constant uint& size,
uint index [[thread_position_in_grid]]) {
out[index] = Op()(in[index]);
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
out[index + i] = Op()(in[index + i]);
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void unary_v2(
device const T* in,
device U* out,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
out[offset] = Op()(in[offset]);
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
out[offset + i] = Op()(in[offset + i]);
}
}
template <

View File

@ -15,6 +15,14 @@
typedef half float16_t;
// Work per thread values for different types. The values here are expected to
// match get_work_per_thread in mlx/backend/metal/utils.h
template <typename U>
struct WorkPerThread {
static_assert(sizeof(U) <= 8, "Type too large");
static constexpr int constant n = 8 / sizeof(U);
};
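// For example, n is 2 for 4-byte types (float, int32_t), 4 for half/bfloat16,
// and 8 for 1-byte types, so each thread moves roughly 8 bytes per pass.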
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////

View File

@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"

View File

@ -7,7 +7,7 @@
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
@ -716,6 +716,23 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
// Return 0s if either input is empty
if (out.size() == 0) {
out.set_data(allocator::malloc(out.nbytes()));
return;
}
// Copy c into out and return
if (inputs[0].shape(-1) == 0) {
copy_gpu(
inputs[2],
out,
inputs[2].flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);

View File

@ -1,7 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/reduce.h"

View File

@ -7,10 +7,10 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
@ -25,25 +25,6 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
enc.set_bytes(step, 1);
}
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
static array compute_dynamic_offset(
const array& indices,
const Strides& strides,
@ -201,8 +182,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
(thread_group_size + simd_size - 1) / simd_size * simd_size;
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = out.size() * thread_group_size;
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
auto gd = get_2d_grid_dims(out.shape(), out.strides());
MTL::Size grid_dims = MTL::Size(thread_group_size, gd.width, gd.height);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
@ -226,105 +207,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
CopyType ctype =
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
copy_gpu(inputs[0], out, ctype);
}
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
concatenate_gpu(inputs, out, axis_, stream());
}
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy_gpu(in, out, CopyType::General);
}
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void CustomTransforms::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Depends::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0];
CopyType ctype;
if (in.data_size() == 1) {
ctype = CopyType::Scalar;
} else if (in.flags().contiguous) {
ctype = CopyType::Vector;
} else {
ctype = CopyType::General;
}
copy_gpu(in, out, ctype);
}
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("[Load::eval_gpu] Not implemented.");
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];
auto& val = inputs[1];
// Padding value must be a scalar
assert(val.size() == 1);
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
pad_gpu(in, val, out, axes_, low_pad_size_, stream());
}
void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
@ -370,27 +256,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Split::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
slice_gpu(in, out, start_indices_, strides_, stream());
}
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
@ -492,18 +357,6 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const Stream& s = */ stream());
}
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void QRF::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@ -525,10 +378,16 @@ void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
"[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
}
void Eig::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[Eig::eval_gpu] Metal Eig NYI.");
}
void Eigh::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI.");
throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI.");
}
void LUF::eval_gpu(
@ -537,35 +396,4 @@ void LUF::eval_gpu(
throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
}
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for buffer copying (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc(tmp.nbytes()));
copy_gpu_inplace(in, tmp, CopyType::General, stream());
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core

View File

@ -4,7 +4,7 @@
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/reduce.h"

View File

@ -3,7 +3,7 @@
#include <algorithm>
#include <cassert>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"

View File

@ -1,5 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"

View File

@ -2,7 +2,7 @@
#include <sstream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
@ -154,9 +154,9 @@ void sdpa_vector(
int gqa_factor = q.shape(1) / k.shape(1);
int N = k.shape(2);
int B = q.shape(0) * q.shape(1);
size_t k_head_stride = k.strides()[1];
size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
size_t k_seq_stride = k.strides()[2];
size_t v_head_stride = v.strides()[1];
size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
size_t v_seq_stride = v.strides()[2];
MTL::Size group_dims(1024, 1, 1);
@ -199,11 +199,10 @@ void sdpa_vector(
if (has_mask) {
auto& m = *mask;
compute_encoder.set_input_array(m, 11 + float_mask);
auto nd = m.ndim();
int32_t kv_seq_stride =
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
int32_t head_stride =
m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
compute_encoder.set_bytes(kv_seq_stride, 13);
compute_encoder.set_bytes(q_seq_stride, 14);
compute_encoder.set_bytes(head_stride, 15);
@ -238,9 +237,10 @@ void sdpa_vector_2pass(
int N = k.shape(2);
int blocks = 32;
int B = q.shape(0) * q.shape(1);
size_t k_head_stride = k.strides()[1];
size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
size_t k_seq_stride = k.strides()[2];
size_t v_head_stride = v.strides()[1];
size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
size_t v_seq_stride = v.strides()[2];
MTL::Size group_dims(8 * 32, 1, 1);
MTL::Size grid_dims(B, q.shape(2), blocks);
@ -302,11 +302,10 @@ void sdpa_vector_2pass(
if (has_mask) {
auto& m = *mask;
compute_encoder.set_input_array(m, 13 + float_mask);
auto nd = m.ndim();
int32_t kv_seq_stride =
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
int32_t head_stride =
m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
compute_encoder.set_bytes(kv_seq_stride, 15);
compute_encoder.set_bytes(q_seq_stride, 16);
compute_encoder.set_bytes(head_stride, 17);
@ -368,18 +367,6 @@ void ScaledDotProductAttention::eval_gpu(
}
};
// Checks if arr is row contiguous or the sequence and head dimension are
// transposed
auto is_contiguous_or_head_seq_transposed = [](const array& arr) {
if (arr.flags().row_contiguous) {
return true;
}
auto& strides = arr.strides();
auto& shape = arr.shape();
return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) &&
(strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
};
// Checks that the headdim dimension has stride 1.
auto is_matrix_contiguous = [](const array& arr) {
return arr.strides(-1) == 1;
@ -387,30 +374,58 @@ void ScaledDotProductAttention::eval_gpu(
// We are in vector mode ie single query
if (q_pre.shape(2) <= 8) {
const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre);
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
auto q_copy_unless = [](const array& arr) {
if (arr.flags().row_contiguous) {
return true;
}
auto& strides = arr.strides();
auto& shape = arr.shape();
if (shape[0] == 1 || shape[1] == 1) {
// If either the batch or head dimension is a singleton, the other can
// be transposed with the sequence dimension
auto bidx = shape[0] == 1 ? 1 : 0;
return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
(strides[bidx] == shape[3]);
}
return false;
};
auto kv_copy_unless = [](const array& arr) {
// keys and values should be copied if:
// - the last dimension is not contiguous
// - the batch and head dim are not contiguous
auto& strides = arr.strides();
auto& shape = arr.shape();
if (strides.back() != 1) {
return false;
}
if (shape[0] == 1 || shape[1] == 1) {
return true;
}
return (strides[0] == strides[1] * shape[1]);
};
const auto& q = copy_unless(q_copy_unless, q_pre);
const auto& k = copy_unless(kv_copy_unless, k_pre);
const auto& v = copy_unless(kv_copy_unless, v_pre);
// Donate the query if possible
if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
q.size() == o.size()) {
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
o.copy_shared_buffer(q);
} else {
if (o.shape(2) == 1) {
o.set_data(allocator::malloc(o.nbytes()));
} else {
auto strides = o.strides();
strides[2] = o.shape(1) * o.shape(3);
strides[1] = o.shape(3);
auto flags = q.flags();
flags.row_contiguous = q.shape(1) == 1;
o.set_data(
allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags);
}
o.set_data(allocator::malloc(o.nbytes()));
}
auto mask =
inputs.size() > 3 ? std::optional<array>{inputs[3]} : std::nullopt;
auto mask_copy_unless = [&q](const array& arr) {
auto& strides = arr.strides();
auto& shape = arr.shape();
return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 ||
(strides[0] == strides[1] * shape[1]);
};
auto mask = inputs.size() > 3
? std::optional<array>{copy_unless(mask_copy_unless, inputs[3])}
: std::nullopt;
// We route to the 2 pass fused attention if
// - The device is large and the sequence length long

View File

@ -3,7 +3,7 @@
#include <cassert>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"

View File

@ -2,21 +2,12 @@
#include <numeric>
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include "mlx/backend/metal/device.h"
namespace mlx::core {
void slice_gpu(
const array& in,
array& out,
const Shape& start_indices,
const Shape& strides,
const Stream& s) {
slice(in, out, start_indices, strides);
}
void concatenate_gpu(
const std::vector<array>& inputs,
array& out,
@ -48,30 +39,4 @@ void concatenate_gpu(
}
}
void pad_gpu(
const array& in,
const array& val,
array& out,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Stream& s) {
// Fill output with val
fill_gpu(val, out, s);
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes.size(); i++) {
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
data_offset += out.strides()[ax] * low_pad_size[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
}
} // namespace mlx::core

View File

@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"

View File

@ -2,7 +2,7 @@
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"

View File

@ -45,7 +45,7 @@ void ternary_op_gpu_inplace(
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > INT32_MAX;
work_per_thread = 1;
work_per_thread = get_work_per_thread(b.dtype());
}
std::string kernel_name;
if (topt == TernaryOpType::General) {
@ -106,13 +106,19 @@ void ternary_op_gpu_inplace(
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(out.data_size(), 4);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(out.data_size(), 4);
grid_dims = MTL::Size(nthreads, 1, 1);
}
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}

View File

@ -34,18 +34,19 @@ void unary_op_gpu_inplace(
};
auto [shape, strides] = maybe_collapse();
int ndim = shape.size();
size_t nthreads = contig ? in.data_size() : in.size();
bool large;
if (!contig) {
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
} else {
large = in.data_size() > UINT32_MAX;
}
int work_per_thread = !contig && large ? 4 : 1;
int work_per_thread;
std::string kernel_name;
if (contig) {
work_per_thread = get_work_per_thread(in.dtype());
kernel_name = (large ? "v2" : "v");
} else {
work_per_thread = large ? 4 : 1;
kernel_name = "gn" + std::to_string(work_per_thread);
if (large) {
kernel_name += "large";
@ -75,12 +76,20 @@ void unary_op_gpu_inplace(
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
size_t nthreads = ceildiv(in.data_size(), work_per_thread);
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(in.data_size(), 2);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(in.data_size(), 2);
grid_dims = MTL::Size(nthreads, 1, 1);
}
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}

View File

@ -84,4 +84,12 @@ void concatenate(std::string& acc, T first, Args... args) {
concatenate(acc, args...);
}
inline int get_work_per_thread(Dtype dtype) {
return std::max(1, 8 / dtype.size());
}
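// e.g. 2 for float32, 4 for float16/bfloat16, and 1 for 8-byte types such as
// complex64; this must match WorkPerThread in the Metal kernel headers.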
inline size_t ceildiv(size_t n, size_t m) {
return (n + m - 1) / m;
}
} // namespace mlx::core

View File

@ -55,6 +55,7 @@ NO_CPU(DynamicSlice)
NO_CPU(DynamicSliceUpdate)
NO_CPU(NumberOfElements)
NO_CPU(Remainder)
NO_CPU_MULTI(Eig)
NO_CPU_MULTI(Eigh)
NO_CPU(Equal)
NO_CPU(Erf)

View File

@ -126,6 +126,7 @@ NO_GPU(Unflatten)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eigh)
NO_GPU_MULTI(Eig)
NO_GPU(View)
namespace fast {

View File

@ -168,6 +168,15 @@ void merge_one(array& dst, array& src, ParentsMap& parents_map) {
parent.first.inputs()[parent.second] = dst;
pairs.push_back(parent);
}
// If src is a parent of dst, remove it from dst's parents
for (auto it = pairs.begin(); it != pairs.end();) {
if (it->first.id() == src.id()) {
it = pairs.erase(it);
} else {
it++;
}
}
// Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents);
}

View File

@ -331,6 +331,7 @@ struct PrimitiveFactory {
SERIALIZE_PRIMITIVE(SVD),
SERIALIZE_PRIMITIVE(Inverse),
SERIALIZE_PRIMITIVE(Cholesky),
SERIALIZE_PRIMITIVE(Eig),
SERIALIZE_PRIMITIVE(Eigh),
SERIALIZE_PRIMITIVE(AffineQuantize),
SERIALIZE_PRIMITIVE(RMSNorm),
@ -470,6 +471,9 @@ bool FunctionTable::match(
if (x.dtype() != y.dtype()) {
return false;
}
if (x.ndim() != y.ndim()) {
return false;
}
if (!shapeless && x.shape() != y.shape()) {
return false;
}

View File

@ -27,6 +27,15 @@ void check_float(Dtype dtype, const std::string& prefix) {
}
}
void check_float_or_complex(Dtype dtype, const std::string& prefix) {
if (dtype != float32 && dtype != float64 && dtype != complex64) {
std::ostringstream msg;
msg << prefix << " Arrays must have type float32, float64 or complex64. "
<< "Received array with type " << dtype << ".";
throw std::invalid_argument(msg.str());
}
}
Dtype at_least_float(const Dtype& d) {
return issubdtype(d, inexact) ? d : promote_types(d, float32);
}
@ -488,12 +497,12 @@ array cross(
return concatenate(outputs, axis, s);
}
void validate_eigh(
void validate_eig(
const array& a,
const StreamOrDevice& stream,
const std::string fname) {
check_cpu_stream(stream, fname);
check_float(a.dtype(), fname);
check_float_or_complex(a.dtype(), fname);
if (a.ndim() < 2) {
std::ostringstream msg;
@ -511,11 +520,12 @@ array eigvalsh(
const array& a,
std::string UPLO /* = "L" */,
StreamOrDevice s /* = {} */) {
validate_eigh(a, s, "[linalg::eigvalsh]");
validate_eig(a, s, "[linalg::eigvalsh]");
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype();
return array(
std::move(out_shape),
a.dtype(),
eigval_type,
std::make_shared<Eigh>(to_stream(s), UPLO, false),
{a});
}
@ -524,15 +534,36 @@ std::pair<array, array> eigh(
const array& a,
std::string UPLO /* = "L" */,
StreamOrDevice s /* = {} */) {
validate_eigh(a, s, "[linalg::eigh]");
validate_eig(a, s, "[linalg::eigh]");
Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype();
auto out = array::make_arrays(
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
{a.dtype(), a.dtype()},
{eigval_type, a.dtype()},
std::make_shared<Eigh>(to_stream(s), UPLO, true),
{a});
return std::make_pair(out[0], out[1]);
}
array eigvals(const array& a, StreamOrDevice s /* = {} */) {
validate_eig(a, s, "[linalg::eigvals]");
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
return array(
std::move(out_shape),
complex64,
std::make_shared<Eig>(to_stream(s), false),
{a});
}
std::pair<array, array> eig(const array& a, StreamOrDevice s /* = {} */) {
validate_eig(a, s, "[linalg::eig]");
auto out = array::make_arrays(
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
{complex64, complex64},
std::make_shared<Eig>(to_stream(s), true),
{a});
return std::make_pair(out[0], out[1]);
}
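// Illustrative usage (not part of this change): for a square float32 matrix a,
//   auto [w, v] = linalg::eig(a);      // w and v have dtype complex64
//   auto w_only = linalg::eigvals(a);  // eigenvalues only
// Both require a CPU stream; see validate_eig above.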
void validate_lu(
const array& a,
const StreamOrDevice& stream,

View File

@ -99,6 +99,10 @@ array cross(
int axis = -1,
StreamOrDevice s = {});
std::pair<array, array> eig(const array& a, StreamOrDevice s = {});
array eigvals(const array& a, StreamOrDevice s = {});
array eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
std::pair<array, array>

View File

@ -472,9 +472,24 @@ array hadamard_transform(
const array& a,
std::optional<float> scale_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
if (a.size() == 0) {
throw std::invalid_argument(
"[hadamard_transform] Does not support empty arrays.");
}
// Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N)
float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1));
int n = a.ndim() > 0 ? a.shape(-1) : 1;
float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(n);
auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32;
// Nothing to do for a scalar
if (n == 1) {
if (scale == 1) {
return a;
}
return multiply(a, array(scale, dtype), s);
}
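// With the default scale this is the orthonormal transform; e.g. a last axis
// of length 4 uses scale = 1 / sqrt(4) = 0.5.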
return array(
a.shape(),
dtype,
@ -3160,6 +3175,10 @@ array scatter_axis(
throw std::invalid_argument(msg.str());
}
if (a.size() == 0) {
return a;
}
auto upd = astype(values, a.dtype(), s);
// Squeeze leading singletons out of update
@ -3565,21 +3584,21 @@ Shape conv_out_shape(
if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << "for "
msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}
if (kernel_dilation.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid kernel dilation " << kernel_dilation << "for "
msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}
if (input_dilation.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid input dilation " << input_dilation << "for "
msg << "[conv] Invalid input dilation " << input_dilation << " for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}
@ -3963,6 +3982,7 @@ array conv_general(
to_stream(s),
stride,
padding_lo,
padding_hi,
kernel_dilation,
input_dilation,
groups,
@ -4314,6 +4334,10 @@ array addmm(
c = reshape(c, c_reshape, s);
}
if (c.shape() != out_shape) {
throw std::invalid_argument(
"[addmm] input c must broadcast to the output shape");
}
auto out = array(
std::move(out_shape),

View File

@ -875,6 +875,43 @@ std::pair<std::vector<array>, std::vector<int>> Cholesky::vmap(
return {{linalg::cholesky(a, upper_, stream())}, {ax}};
}
std::pair<std::vector<array>, std::vector<int>> Eig::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 1);
assert(axes.size() == 1);
bool needs_move = axes[0] >= (inputs[0].ndim() - 2);
auto a = needs_move ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
auto ax = needs_move ? 0 : axes[0];
std::vector<array> outputs;
if (compute_eigenvectors_) {
auto [values, vectors] = linalg::eig(a, stream());
outputs = {values, vectors};
} else {
outputs = {linalg::eigvals(a, stream())};
}
return {outputs, std::vector<int>(outputs.size(), ax)};
}
std::vector<Shape> Eig::output_shapes(const std::vector<array>& inputs) {
auto shape = inputs[0].shape();
shape.pop_back(); // Remove last dimension for eigenvalues
if (compute_eigenvectors_) {
return {
std::move(shape), inputs[0].shape()}; // Eigenvalues and eigenvectors
} else {
return {std::move(shape)}; // Only eigenvalues
}
}
bool Eig::is_equivalent(const Primitive& other) const {
auto& e_other = static_cast<const Eig&>(other);
return compute_eigenvectors_ == e_other.compute_eigenvectors_;
}
std::pair<std::vector<array>, std::vector<int>> Eigh::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@ -1055,7 +1092,8 @@ array conv_weight_backward_patches(
const array& wt,
const array& cotan,
const std::vector<int>& kernel_strides,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
StreamOrDevice s) {
// Resolve Padded input shapes and strides
Shape padding_starts(in.ndim(), 0);
@ -1064,9 +1102,9 @@ array conv_weight_backward_patches(
// padded shape
for (int i = 1; i < in.ndim() - 1; i++) {
in_padded_shape[i] += 2 * padding[i - 1];
padding_ends[i] += padding[i - 1];
padding_starts[i] += padding[i - 1];
in_padded_shape[i] += padding_lo[i - 1] + padding_hi[i - 1];
padding_ends[i] += padding_lo[i - 1];
padding_starts[i] += padding_lo[i - 1];
}
// padded strides (contiguous)
@ -1078,9 +1116,14 @@ array conv_weight_backward_patches(
// Pad input
std::vector<int> padded_axes(in.ndim() - 2, 0);
std::iota(padded_axes.begin(), padded_axes.end(), 1);
Shape padding_(padding.begin(), padding.end());
auto in_padded = pad(
in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s);
auto in_padded =
pad(in,
padded_axes,
Shape(padding_lo),
Shape(padding_hi),
array(0, in.dtype()),
"constant",
s);
// Resolve strided patches
@ -1147,16 +1190,16 @@ std::vector<array> Convolution::vjp(
for (int a : argnums) {
// Grads for input
if (a == 0) {
std::vector<int> padding_lo = padding_;
std::vector<int> padding_hi = padding_;
std::vector<int> padding_lo = padding_lo_;
std::vector<int> padding_hi = padding_hi_;
for (int i = 0; i < padding_lo.size(); ++i) {
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
padding_lo[i] = wt_size - padding_[i] - 1;
padding_lo[i] = wt_size - padding_lo_[i] - 1;
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
padding_hi[i] = in_size - out_size + padding_[i];
padding_hi[i] = in_size - out_size + padding_hi_[i];
}
// Check for negative padding
@ -1226,18 +1269,18 @@ std::vector<array> Convolution::vjp(
if (no_dilation && !flip_ && groups_ == 1) {
auto grad = conv_weight_backward_patches(
in, wt, cotan, kernel_strides_, padding_, stream());
in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream());
grads.push_back(grad);
} else {
std::vector<int> padding_lo = padding_;
std::vector<int> padding_hi = padding_;
auto padding_hi = padding_lo_;
for (int i = 0; i < padding_hi.size(); ++i) {
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
padding_hi[i] = out_size - in_size + wt_size - padding_hi[i] - 1;
}
auto cotan_trans = swapaxes(cotan, 0, -1, stream());
auto in_trans = group_transpose(in, -1, 0, -1);
@ -1245,7 +1288,7 @@ std::vector<array> Convolution::vjp(
/* const array& input = */ in_trans,
/* const array& weight = */ cotan_trans,
/* std::vector<int> stride = */ kernel_dilation_,
/* std::vector<int> padding_lo = */ padding_lo,
/* std::vector<int> padding_lo = */ padding_lo_,
/* std::vector<int> padding_hi = */ padding_hi,
/* std::vector<int> kernel_dilation = */ kernel_strides_,
/* std::vector<int> input_dilation = */ input_dilation_,
@ -1283,7 +1326,8 @@ std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
in,
w,
kernel_strides_,
padding_,
padding_lo_,
padding_hi_,
kernel_dilation_,
input_dilation_,
groups,
@ -1332,7 +1376,8 @@ std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
bool Convolution::is_equivalent(const Primitive& other) const {
const Convolution& c_other = static_cast<const Convolution&>(other);
return padding_ == c_other.padding_ &&
return padding_lo_ == c_other.padding_lo_ &&
padding_hi_ == c_other.padding_hi_ &&
kernel_strides_ == c_other.kernel_strides_ &&
kernel_dilation_ == c_other.kernel_dilation_ &&
input_dilation_ == c_other.input_dilation_ &&
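Splitting `padding_` into `padding_lo_`/`padding_hi_` lets asymmetric padding flow through the primitive, its gradients, and `vmap`. At the Python level this corresponds, as far as I can tell, to passing a `(low, high)` pair of per-dimension lists to `mx.conv_general`'s `padding` argument:

```python
import mlx.core as mx

x = mx.random.normal((1, 16, 4))   # (batch, length, channels)
w = mx.random.normal((8, 3, 4))    # (out_channels, kernel, in_channels)

# Two elements of padding before the sequence and none after it.
y = mx.conv_general(x, w, stride=1, padding=([2], [0]))
```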
@ -1484,14 +1529,16 @@ std::vector<array> Divide::vjp(
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
array denominator_bar = conjugate(primals[1], stream());
for (auto arg : argnums) {
if (arg == 0) {
vjps.push_back(divide(cotangents[0], primals[1], stream()));
vjps.push_back(divide(cotangents[0], denominator_bar, stream()));
} else {
vjps.push_back(negative(
divide(
multiply(cotangents[0], primals[0], stream()),
square(primals[1], stream()),
multiply(
cotangents[0], conjugate(primals[0], stream()), stream()),
square(denominator_bar, stream()),
stream()),
stream()));
}
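The conjugates added here follow the standard convention for complex VJPs: for `out = a / b` the cotangent pulls back as `cot / conj(b)` for `a` and `-cot * conj(a) / conj(b)^2` for `b`, which reduces to the familiar real-valued rule when nothing is complex (the `Multiply::vjp` change further down uses the same convention). A quick numerical sanity check, assuming `mx.vjp` and `mx.conj` behave as usual on complex64 inputs:

```python
import mlx.core as mx

a = mx.array([1.0 + 2.0j])
b = mx.array([3.0 - 1.0j])
cot = mx.array([1.0 + 0.5j])

_, (da, db) = mx.vjp(lambda a, b: a / b, [a, b], [cot])

# Expected values under the conjugate convention above:
expected_da = cot / mx.conj(b)
expected_db = -cot * mx.conj(a) / (mx.conj(b) * mx.conj(b))
```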
@ -1946,30 +1993,74 @@ std::vector<array> FFT::vjp(
assert(argnums.size() == 1);
auto& in = primals[0];
std::vector<int> axes(axes_.begin(), axes_.end());
// TODO: Add it as an option to do an unnormalized or scaled fft so that this
// isn't part of the graph.
double n_elements = 1;
for (auto ax : axes) {
n_elements *= inverse_ ? cotangents[0].shape(ax) : primals[0].shape(ax);
}
if (real_ && inverse_) {
auto out = fft::fftn(cotangents[0], axes, stream());
auto start = Shape(out.ndim(), 0);
auto stop = in.shape();
out = slice(out, start, stop, stream());
auto mask_shape = out.shape();
mask_shape[axes_.back()] -= 2;
auto mask = full(mask_shape, 2.0f, stream());
auto pad_shape = out.shape();
pad_shape[axes_.back()] = 1;
auto pad = full(pad_shape, 1.0f, stream());
mask = concatenate({pad, mask, pad}, axes_.back(), stream());
return {multiply(mask, out, stream())};
// Make a mask to account for the double use in the forward pass.
// Everything except the DC and Nyquist frequencies gets doubled.
int N = in.shape(axes_.back());
bool odd = cotangents[0].shape(axes_.back()) % 2;
Shape c(in.ndim(), 1);
c[axes_.back()] = N;
array indices = reshape(arange(N, stream()), std::move(c), stream());
array first(0, indices.dtype());
array last(N - 1 + odd, indices.dtype());
array one(1 / n_elements, in.dtype());
array two(2 / n_elements, in.dtype());
array mask = where(
logical_and(
greater(indices, first, stream()),
less(indices, last, stream()),
stream()),
two,
one,
stream());
return {
multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())};
} else if (real_) {
Shape n;
for (auto ax : axes_) {
n.push_back(in.shape()[ax]);
n.push_back(in.shape(ax));
}
return {astype(
fft::fftn(cotangents[0], n, axes, stream()), in.dtype(), stream())};
// Make a mask to account for the double use in the forward pass.
// Everything except the DC and Nyquist frequencies gets halved.
int N = cotangents[0].shape(axes_.back());
bool odd = in.shape(axes_.back()) % 2;
Shape c(in.ndim(), 1);
c[axes_.back()] = N;
array indices = reshape(arange(N, stream()), std::move(c), stream());
array first(0, indices.dtype());
array last(N - 1 + odd, indices.dtype());
array one(1, complex64);
array half(0.5, complex64);
array mask = where(
logical_and(
greater(indices, first, stream()),
less(indices, last, stream()),
stream()),
half,
one,
stream());
return {multiply(
fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()),
array(n_elements, in.dtype()),
stream())};
} else if (inverse_) {
return {fft::ifftn(cotangents[0], axes, stream())};
return {multiply(
fft::fftn(cotangents[0], axes, stream()),
array(1 / n_elements, complex64),
stream())};
} else {
return {fft::fftn(cotangents[0], axes, stream())};
return {multiply(
fft::ifftn(cotangents[0], axes, stream()),
array(n_elements, complex64),
stream())};
}
}
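Both real branches build the same kind of mask: in a half-spectrum, every bin except DC (and, when the full length is even, Nyquist) stands in for a conjugate pair, so its contribution is doubled in one direction and halved in the other, while the `1/n_elements` factors restore the normalization the forward transform does not apply. A standalone sketch of that mask for a 1-D transform of full length `n` (illustration only, not the code path above):

```python
def rfft_grad_mask(n):
    # Number of bins in the half-spectrum produced by an rfft of length n.
    m = n // 2 + 1
    mask = [2.0] * m
    mask[0] = 1.0            # DC appears once
    if n % 2 == 0:
        mask[-1] = 1.0       # Nyquist appears once only when n is even
    return mask

# rfft_grad_mask(6) -> [1.0, 2.0, 2.0, 1.0]
# rfft_grad_mask(5) -> [1.0, 2.0, 2.0]
```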
@ -2772,7 +2863,8 @@ std::vector<array> Multiply::vjp(
const std::vector<array>&) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream()));
vjps.push_back(multiply(
conjugate(primals[1 - arg], stream()), cotangents[0], stream()));
}
return vjps;
}
@ -3456,7 +3548,7 @@ std::vector<array> Reduce::vjp(
}
else {
throw std::runtime_error("Reduce type VJP not yet implemented.");
return {zeros_like(in, stream())};
}
}

View File

@ -689,13 +689,15 @@ class Convolution : public UnaryPrimitive {
explicit Convolution(
Stream stream,
const std::vector<int>& kernel_strides,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
const int groups = 1,
const bool flip = false)
: UnaryPrimitive(stream),
padding_(padding),
padding_lo_(padding_lo),
padding_hi_(padding_hi),
kernel_strides_(kernel_strides),
kernel_dilation_(kernel_dilation),
input_dilation_(input_dilation),
@ -716,7 +718,8 @@ class Convolution : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(
padding_,
padding_lo_,
padding_hi_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
@ -725,7 +728,8 @@ class Convolution : public UnaryPrimitive {
}
private:
std::vector<int> padding_;
std::vector<int> padding_lo_;
std::vector<int> padding_hi_;
std::vector<int> kernel_strides_;
std::vector<int> kernel_dilation_;
std::vector<int> input_dilation_;
@ -2377,6 +2381,29 @@ class Cholesky : public UnaryPrimitive {
bool upper_;
};
class Eig : public Primitive {
public:
explicit Eig(Stream stream, bool compute_eigenvectors)
: Primitive(stream), compute_eigenvectors_(compute_eigenvectors) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_VMAP()
DEFINE_PRINT(Eig)
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return compute_eigenvectors_;
}
private:
bool compute_eigenvectors_;
};
class Eigh : public Primitive {
public:
explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)

View File

@ -176,24 +176,51 @@ array uniform(
array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s));
}
inline array complex_normal(
Shape shape,
const std::optional<array>& loc,
const std::optional<array>& scale,
const std::optional<array>& key,
StreamOrDevice s) {
auto stream = to_stream(s);
auto low = above_minus_one_with_default(float32);
auto high = array(1.0f, float32);
shape.push_back(2);
auto samples =
erfinv(uniform(low, high, shape, float32, key, stream), stream);
samples = squeeze(view(samples, complex64, stream), -1, stream);
if (scale.has_value()) {
samples = multiply(*scale, samples, stream);
}
if (loc.has_value()) {
samples = add(*loc, samples, stream);
}
return samples;
}
array normal(
const Shape& shape,
Dtype dtype,
const float loc /* = 0.0 */,
const float scale /* = 1.0 */,
const std::optional<array>& key /*= nullopt */,
const std::optional<array>& loc,
const std::optional<array>& scale,
const std::optional<array>& key,
StreamOrDevice s /* = {} */) {
if (dtype == complex64) {
return complex_normal(shape, loc, scale, key, s);
}
auto stream = to_stream(s);
auto low = above_minus_one_with_default(dtype);
auto high = array(1.0f, dtype);
auto samples = uniform(low, high, shape, dtype, key, stream);
samples =
multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
if (scale != 1.0) {
samples = multiply(array(scale, dtype), samples, stream);
auto applied_scale = array(std::sqrt(2.0), dtype);
if (scale.has_value()) {
applied_scale =
multiply(applied_scale, astype(*scale, dtype, stream), stream);
}
if (loc != 0.0) {
samples = add(array(loc, dtype), samples, stream);
samples = multiply(applied_scale, erfinv(samples, stream), stream);
if (loc.has_value()) {
samples = add(astype(*loc, dtype, stream), samples, stream);
}
return samples;
}
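With `loc` and `scale` now optional arrays, the sampler can broadcast per-element means and standard deviations, and a `complex64` dtype routes through `complex_normal` above. A hedged sketch of the corresponding Python surface (assuming the binding exposes the same optional-array parameters):

```python
import mlx.core as mx

# Per-column location and scale, broadcast against the sample shape.
loc = mx.array([0.0, 10.0])
scale = mx.array([1.0, 0.1])
x = mx.random.normal((4, 2), loc=loc, scale=scale)

# complex64 samples go through the complex_normal path above.
z = mx.random.normal((3,), dtype=mx.complex64)
```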

View File

@ -94,12 +94,24 @@ inline array uniform(
/** Generate samples from the standard normal distribution. */
array normal(
const Shape& shape,
Dtype dtype,
const std::optional<array>& loc,
const std::optional<array>& scale,
const std::optional<array>& key,
StreamOrDevice s = {});
inline array normal(
const Shape& shape,
Dtype dtype,
const float loc,
const float scale,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});
StreamOrDevice s = {}) {
auto loc_ = loc == 0 ? std::nullopt : std::make_optional(array(loc, dtype));
auto scale_ =
scale == 1 ? std::nullopt : std::make_optional(array(scale, dtype));
return normal(shape, dtype, loc_, scale_, key, s);
}
inline array normal(
const Shape& shape,
const float loc,
@ -113,13 +125,13 @@ inline array normal(
const Dtype dtype,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return normal(shape, dtype, 0.0, 1.0, key, s);
return normal(shape, dtype, std::nullopt, std::nullopt, key, s);
}
inline array normal(
const Shape& shape,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return normal(shape, float32, 0.0, 1.0, key, s);
return normal(shape, float32, std::nullopt, std::nullopt, key, s);
}
/** Generate samples from a multivariate normal distribution. **/

View File

@ -4,7 +4,7 @@
#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 25
#define MLX_VERSION_PATCH 1
#define MLX_VERSION_PATCH 2
#define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@ -270,9 +270,11 @@ def launch_ring(parser, hosts, args, command):
# Repeat the stdout and stderr to the local machine
to_read = [p.stdout.fileno(), p.stderr.fileno()]
to_write = [p.stdin.fileno()]
to_write = [p.stdin.fileno(), sys.stdout.fileno(), sys.stderr.fileno()]
pidfile = ""
stdin_buffer = b""
stdout_buffer = b""
stderr_buffer = b""
while p.poll() is None:
try:
stdin_buffer += input_queue.get_nowait()
@ -280,8 +282,6 @@ def launch_ring(parser, hosts, args, command):
pass
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
for fd in rlist:
is_stdout = fd == p.stdout.fileno()
outfile = sys.stdout if is_stdout else sys.stderr
msg = os.read(fd, 8192).decode(errors="ignore")
# Fetch the PID file first if we haven't already
@ -289,12 +289,21 @@ def launch_ring(parser, hosts, args, command):
pidfile, *msg = msg.split("\n", maxsplit=1)
msg = msg[0] if msg else ""
outfile.write(msg)
outfile.flush()
is_stdout = fd == p.stdout.fileno()
if is_stdout:
stdout_buffer += msg.encode()
else:
stderr_buffer += msg.encode()
for fd in wlist:
if len(stdin_buffer) > 0:
if fd == p.stdin.fileno() and len(stdin_buffer) > 0:
n = os.write(fd, stdin_buffer)
stdin_buffer = stdin_buffer[n:]
elif fd == sys.stdout.fileno() and len(stdout_buffer) > 0:
n = os.write(fd, stdout_buffer)
stdout_buffer = stdout_buffer[n:]
elif fd == sys.stderr.fileno() and len(stderr_buffer) > 0:
n = os.write(fd, stderr_buffer)
stderr_buffer = stderr_buffer[n:]
if stop:
p.terminate()
break

View File

@ -25,7 +25,16 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
def _nearest_indices(N, scale, dim, ndims):
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32)
M = int(scale * N)
indices = mx.arange(M, dtype=mx.float32)
if M > N:
indices = (indices + 0.5) * (N / M) - 0.5
indices = indices.round()
else:
indices = indices * (N / M)
shape = [1] * ndims
shape[dim] = -1
return indices.astype(mx.uint32).reshape(shape)
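Worked through on a small case: for upsampling (M > N) the new rule centers samples before rounding, so with N = 3 and scale = 2 the indices are round((i + 0.5) * 0.5 - 0.5) = [0, 0, 1, 1, 2, 2], the usual half-pixel "nearest" convention, while downsampling simply truncates i * (N / M). A tiny plain-Python mirror of the helper above:

```python
def nearest_indices_1d(N, scale):
    M = int(scale * N)
    if M > N:
        # Upsampling: half-pixel centers, then round to the nearest source index.
        return [round((i + 0.5) * (N / M) - 0.5) for i in range(M)]
    # Downsampling: plain truncation, as the uint32 cast does above.
    return [int(i * (N / M)) for i in range(M)]

assert nearest_indices_1d(3, 2) == [0, 0, 1, 1, 2, 2]
assert nearest_indices_1d(4, 0.5) == [0, 2]
```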
def _linear_indices(N, scale, align_corners, dim, ndims):

View File

@ -319,6 +319,18 @@ void init_array(nb::module_& m) {
R"pbdoc(
The array's :class:`Dtype`.
)pbdoc")
.def_prop_ro(
"real",
[](const mx::array& a) { return mx::real(a); },
R"pbdoc(
The real part of a complex array.
)pbdoc")
.def_prop_ro(
"imag",
[](const mx::array& a) { return mx::imag(a); },
R"pbdoc(
The imaginary part of a complex array.
)pbdoc")
.def(
"item",
&to_scalar,

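The new `real`/`imag` properties expose the components of a complex array directly from Python:

```python
import mlx.core as mx

z = mx.array([1.0 + 2.0j, 3.0 - 4.0j])
z.real   # the real parts, as a float32 array
z.imag   # the imaginary parts, as a float32 array
```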
View File

@ -236,7 +236,7 @@ void init_linalg(nb::module_& parent_module) {
Returns:
Union[tuple(array, ...), array]:
If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that
If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that
``A = U @ diag(S) @ Vt``. If compute_uv is ``False`` returns singular values array ``S``.
)pbdoc");
m.def(
@ -407,6 +407,76 @@ void init_linalg(nb::module_& parent_module) {
Returns:
array: The cross product of ``a`` and ``b`` along the specified axis.
)pbdoc");
m.def(
"eigvals",
&mx::linalg::eigvals,
"a"_a,
nb::kw_only(),
"stream"_a = nb::none(),
R"pbdoc(
Compute the eigenvalues of a square matrix.
This function differs from :func:`numpy.linalg.eigvals` in that the
return type is always complex even if the eigenvalues are all real.
This function supports arrays with at least 2 dimensions. When the
input has more than two dimensions, the eigenvalues are computed for
each matrix in the last two dimensions.
Args:
a (array): The input array.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The eigenvalues (not necessarily in order).
Example:
>>> A = mx.array([[1., -2.], [-2., 1.]])
>>> eigenvalues = mx.linalg.eigvals(A, stream=mx.cpu)
>>> eigenvalues
array([3+0j, -1+0j], dtype=complex64)
)pbdoc");
m.def(
"eig",
[](const mx::array& a, mx::StreamOrDevice s) {
auto result = mx::linalg::eig(a, s);
return nb::make_tuple(result.first, result.second);
},
"a"_a,
nb::kw_only(),
"stream"_a = nb::none(),
R"pbdoc(
Compute the eigenvalues and eigenvectors of a square matrix.
This function differs from :func:`numpy.linalg.eig` in that the
return type is always complex even if the eigenvalues are all real.
This function supports arrays with at least 2 dimensions. When the input
has more than two dimensions, the eigenvalues and eigenvectors are
computed for each matrix in the last two dimensions.
Args:
a (array): The input array.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
Tuple[array, array]:
A tuple containing the eigenvalues and the normalized right
eigenvectors. The column ``v[:, i]`` is the eigenvector
corresponding to the i-th eigenvalue.
Example:
>>> A = mx.array([[1., -2.], [-2., 1.]])
>>> w, v = mx.linalg.eig(A, stream=mx.cpu)
>>> w
array([3+0j, -1+0j], dtype=complex64)
>>> v
array([[0.707107+0j, 0.707107+0j],
[-0.707107+0j, 0.707107+0j]], dtype=complex64)
)pbdoc");
m.def(
"eigvalsh",
&mx::linalg::eigvalsh,

View File

@ -49,21 +49,21 @@ void init_metal(nb::module_& m) {
metal.def(
"set_memory_limit",
[](size_t limit) {
DEPRECATE("mx.metal.set_memory_limt", "mx.set_memory_limit");
DEPRECATE("mx.metal.set_memory_limit", "mx.set_memory_limit");
return mx::set_memory_limit(limit);
},
"limit"_a);
metal.def(
"set_cache_limit",
[](size_t limit) {
DEPRECATE("mx.metal.set_cache_limt", "mx.set_cache_limit");
DEPRECATE("mx.metal.set_cache_limit", "mx.set_cache_limit");
return mx::set_cache_limit(limit);
},
"limit"_a);
metal.def(
"set_wired_limit",
[](size_t limit) {
DEPRECATE("mx.metal.set_wired_limt", "mx.set_wired_limit");
DEPRECATE("mx.metal.set_wired_limit", "mx.set_wired_limit");
return mx::set_wired_limit(limit);
},
"limit"_a);

Some files were not shown because too many files have changed in this diff