Merge branch 'ml-explore:main' into main

commit 3d79254682
Dhruv Srikanth, 2025-05-10 16:18:35 +01:00, committed by GitHub
31 changed files with 2012 additions and 289 deletions

CMakeLists.txt

@@ -34,6 +34,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
 option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
 option(MLX_BUILD_METAL "Build metal backend" ON)
 option(MLX_BUILD_CPU "Build cpu backend" ON)
+option(MLX_BUILD_CUDA "Build cuda backend" OFF)
 option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
 option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
 option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
@@ -83,6 +84,10 @@ if(MLX_BUILD_METAL)
   set(QUARTZ_LIB "-framework QuartzCore")
 endif()

+if(MLX_BUILD_CUDA)
+  enable_language(CUDA)
+endif()
+
 if(MLX_BUILD_METAL AND NOT METAL_LIB)
   message(STATUS "Metal not found. Unable to build GPU")
   set(MLX_BUILD_METAL OFF)

mlx/CMakeLists.txt

@@ -47,10 +47,18 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)

 if(MLX_BUILD_METAL)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
 else()
   target_sources(mlx
                  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
+endif()
+
+if(MLX_BUILD_CUDA)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
+endif()
+
+if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
+else()
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
 endif()

mlx/backend/cpu/conv.cpp

@@ -22,7 +22,8 @@ void slow_conv_1D(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding,
+    const std::vector<int>& padding_lo,
+    const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
@@ -60,7 +61,8 @@ void slow_conv_1D(
        out_stride_O = out.strides()[2],
        flip,
-       padding = padding[0],
+       padding_lo = padding_lo[0],
+       padding_hi = padding_hi[0],
        wt_stride = wt_strides[0],
        wt_dilation = wt_dilation[0],
        in_dilation = in_dilation[0]]() mutable {
@@ -77,7 +79,7 @@ void slow_conv_1D(
         const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
         int wh_flip = flip ? (wH - wh - 1) : wh;
-        int ih = oh * wt_stride - padding + wh_flip * wt_dilation;
+        int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation;
         auto ih_div = std::div(ih, in_dilation);
@@ -109,7 +111,8 @@ void slow_conv_2D(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding,
+    const std::vector<int>& padding_lo,
+    const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
@@ -120,230 +123,235 @@ void slow_conv_2D(
   encoder.set_input_array(wt);
   encoder.set_output_array(out);
   encoder.dispatch(
       [st_wt_ptr = wt.data<T>(),
        st_in_ptr = in.data<T>(),
        st_out_ptr = out.data<T>(),
        N = in.shape(0), // Batch size, should be the same as out.shape(0)
        iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
        iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
        C = in.shape(3), // In channels
        oH = out.shape(1), // Output spatial dim
        oW = out.shape(2), // Output spatial dim
        O = wt.shape(0), // Out channels
        wH = wt.shape(1), // Weight spatial dim
        wW = wt.shape(2), // Weight spatial dim
        groups = in.shape(3) / wt.shape(3),
        C_per_group = wt.shape(3),
        in_stride_N = in.strides()[0],
        in_stride_H = in.strides()[1],
        in_stride_W = in.strides()[2],
        in_stride_C = in.strides()[3],
        wt_stride_O = wt.strides()[0],
        wt_stride_H = wt.strides()[1],
        wt_stride_W = wt.strides()[2],
        wt_stride_C = wt.strides()[3],
        out_stride_N = out.strides()[0],
        out_stride_H = out.strides()[1],
        out_stride_W = out.strides()[2],
        out_stride_O = out.strides()[3],
-       padding,
+       padding_lo,
+       padding_hi,
        wt_strides,
        wt_dilation,
        in_dilation,
        flip]() mutable {
        bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
        const int O_per_group = O / groups;
        auto pt_conv_no_checks =
            [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
              out_ptr += oh * out_stride_H + ow * out_stride_W;
-             int ih_base = oh * wt_strides[0] - padding[0];
-             int iw_base = ow * wt_strides[1] - padding[1];
+             int ih_base = oh * wt_strides[0] - padding_lo[0];
+             int iw_base = ow * wt_strides[1] - padding_lo[1];
              for (int g = 0; g < groups; ++g) {
                for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
                  float r = 0.;
                  for (int wh = 0; wh < wH; ++wh) {
                    for (int ww = 0; ww < wW; ++ww) {
                      int wh_flip = flip ? wH - wh - 1 : wh;
                      int ww_flip = flip ? wW - ww - 1 : ww;
                      int ih = ih_base + wh_flip * wt_dilation[0];
                      int iw = iw_base + ww_flip * wt_dilation[1];
                      const T* wt_ptr_pt =
                          wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
                      const T* in_ptr_pt =
                          in_ptr + ih * in_stride_H + iw * in_stride_W;
                      for (int c = g * C_per_group; c < (g + 1) * C_per_group;
                           ++c) {
                        r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
                            static_cast<float>(
                                wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
                      } // c
                    } // ww
                  } // wh
                  out_ptr[0] = static_cast<T>(r);
                  out_ptr += out_stride_O;
                  wt_ptr += wt_stride_O;
                } // o
              } // g
            };
        int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
        int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
        int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
        int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
        int f_wgt_jump_h =
            std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
        int f_wgt_jump_w =
            std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
        int f_out_jump_h =
            std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
        int f_out_jump_w =
            std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
        std::vector<int> base_h(f_out_jump_h);
        std::vector<int> base_w(f_out_jump_w);
        for (int i = 0; i < f_out_jump_h; ++i) {
-         int ih_loop = i * wt_strides[0] - padding[0] + init_h;
+         int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h;
          int wh_base = 0;
          while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
            wh_base++;
            ih_loop += jump_h;
          }
          base_h[i] = wh_base;
        }
        for (int j = 0; j < f_out_jump_w; ++j) {
-         int iw_loop = j * wt_strides[1] - padding[1] + init_w;
+         int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w;
          int ww_base = 0;
          while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
            ww_base++;
            iw_loop += jump_w;
          }
          base_w[j] = ww_base;
        }
        auto pt_conv_all_checks =
            [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
              out_ptr += oh * out_stride_H + ow * out_stride_W;
-             int ih_base = oh * wt_strides[0] - padding[0];
-             int iw_base = ow * wt_strides[1] - padding[1];
+             int ih_base = oh * wt_strides[0] - padding_lo[0];
+             int iw_base = ow * wt_strides[1] - padding_lo[1];
              int wh_base = base_h[oh % f_out_jump_h];
              int ww_base = base_w[ow % f_out_jump_w];
              for (int g = 0; g < groups; ++g) {
                for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
                  float r = 0.;
                  for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
                    for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
                      int wh_flip = flip ? wH - wh - 1 : wh;
                      int ww_flip = flip ? wW - ww - 1 : ww;
                      int ih = ih_base + wh_flip * wt_dilation[0];
                      int iw = iw_base + ww_flip * wt_dilation[1];
                      if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
                        const T* wt_ptr_pt =
                            wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
                        int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
                        int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
                        const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H +
                            iw_dil * in_stride_W;
                        for (int c = g * C_per_group; c < (g + 1) * C_per_group;
                             ++c) {
                          r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
                              static_cast<float>(
                                  wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
                        } // c
                      } // ih, iw check
                    } // ww
                  } // wh
                  out_ptr[0] = static_cast<T>(r);
                  out_ptr += out_stride_O;
                  wt_ptr += wt_stride_O;
                } // o
              } // g
            };
        int oH_border_0 = 0;
-       int oH_border_1 =
-           is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH;
-       int oH_border_2 = std::max(
-           oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]);
+       int oH_border_1 = is_idil_one
+           ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
+           : oH;
+       int oH_border_2 = std::max(
+           oH_border_1,
+           (iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]);
        int oH_border_3 = oH;
        int oW_border_0 = 0;
-       int oW_border_1 =
-           is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW;
-       int oW_border_2 = std::max(
-           oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]);
+       int oW_border_1 = is_idil_one
+           ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
+           : oW;
+       int oW_border_2 = std::max(
+           oW_border_1,
+           (iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]);
        int oW_border_3 = oW;
        for (int n = 0; n < N; ++n) {
          // Case 1: oh might put us out of bounds
          for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
            for (int ow = 0; ow < oW; ++ow) {
              pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
            } // ow
          } // oh
          // Case 2: oh in bounds
          for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
            // Case a: ow might put us out of bounds
            for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
              pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
            } // ow
            // Case b: ow in bounds
            for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
              pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
            } // ow
            // Case c: ow might put us out of bounds
            for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
              pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
            } // ow
          } // oh
          // Case 3: oh might put us out of bounds
          for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
            for (int ow = 0; ow < oW; ++ow) {
              pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
            } // ow
          } // oh
          st_in_ptr += in_stride_N;
          st_out_ptr += out_stride_N;
        } // n
      });
 }

 template <typename T>
@@ -351,7 +359,8 @@ void slow_conv_3D(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding,
+    const std::vector<int>& padding_lo,
+    const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
@@ -400,7 +409,8 @@ void slow_conv_3D(
        out_stride_H = out.strides()[2],
        out_stride_W = out.strides()[3],
        out_stride_O = out.strides()[4],
-       padding,
+       padding_lo,
+       padding_hi,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -415,9 +425,9 @@ void slow_conv_3D(
            int oh,
            int ow) {
          out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
-         int id_base = od * wt_strides[0] - padding[0];
-         int ih_base = oh * wt_strides[1] - padding[1];
-         int iw_base = ow * wt_strides[2] - padding[2];
+         int id_base = od * wt_strides[0] - padding_lo[0];
+         int ih_base = oh * wt_strides[1] - padding_lo[1];
+         int iw_base = ow * wt_strides[2] - padding_lo[2];
          for (int o = 0; o < O; ++o) {
            float r = 0.;
@@ -478,7 +488,7 @@ void slow_conv_3D(
        std::vector<int> base_w(f_out_jump_w);
        for (int i = 0; i < f_out_jump_d; ++i) {
-         int id_loop = i * wt_strides[0] - padding[0] + init_d;
+         int id_loop = i * wt_strides[0] - padding_lo[0] + init_d;
          int wd_base = 0;
          while (wd_base < wD && id_loop % in_dilation[0] != 0) {
@@ -490,7 +500,7 @@ void slow_conv_3D(
        }
        for (int i = 0; i < f_out_jump_h; ++i) {
-         int ih_loop = i * wt_strides[1] - padding[1] + init_h;
+         int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h;
          int wh_base = 0;
          while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
@@ -502,7 +512,7 @@ void slow_conv_3D(
        }
        for (int j = 0; j < f_out_jump_w; ++j) {
-         int iw_loop = j * wt_strides[2] - padding[2] + init_w;
+         int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w;
          int ww_base = 0;
          while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
@@ -521,9 +531,9 @@ void slow_conv_3D(
            int ow) {
          out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
-         int id_base = od * wt_strides[0] - padding[0];
-         int ih_base = oh * wt_strides[1] - padding[1];
-         int iw_base = ow * wt_strides[2] - padding[2];
+         int id_base = od * wt_strides[0] - padding_lo[0];
+         int ih_base = oh * wt_strides[1] - padding_lo[1];
+         int iw_base = ow * wt_strides[2] - padding_lo[2];
          int wd_base = base_d[od % f_out_jump_d];
          int wh_base = base_h[oh % f_out_jump_h];
@@ -573,24 +583,30 @@ void slow_conv_3D(
        };
        int oD_border_0 = 0;
-       int oD_border_1 =
-           is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
-       int oD_border_2 = std::max(
-           oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
+       int oD_border_1 = is_idil_one
+           ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
+           : oD;
+       int oD_border_2 = std::max(
+           oD_border_1,
+           (iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]);
        int oD_border_3 = oD;
        int oH_border_0 = 0;
-       int oH_border_1 =
-           is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
-       int oH_border_2 = std::max(
-           oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
+       int oH_border_1 = is_idil_one
+           ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
+           : oH;
+       int oH_border_2 = std::max(
+           oH_border_1,
+           (iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]);
        int oH_border_3 = oH;
        int oW_border_0 = 0;
-       int oW_border_1 =
-           is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
-       int oW_border_2 = std::max(
-           oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
+       int oW_border_1 = is_idil_one
+           ? ((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2])
+           : oW;
+       int oW_border_2 = std::max(
+           oW_border_1,
+           (iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]);
        int oW_border_3 = oW;
        for (int n = 0; n < N; ++n) {
@@ -658,7 +674,8 @@ void dispatch_slow_conv_1D(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding,
+    const std::vector<int>& padding_lo,
+    const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
@@ -669,7 +686,8 @@ void dispatch_slow_conv_1D(
         in,
         wt,
         out,
-        padding,
+        padding_lo,
+        padding_hi,
         wt_strides,
         wt_dilation,
         in_dilation,
@@ -680,7 +698,8 @@ void dispatch_slow_conv_1D(
         in,
         wt,
         out,
-        padding,
+        padding_lo,
+        padding_hi,
         wt_strides,
         wt_dilation,
         in_dilation,
@@ -691,7 +710,8 @@ void dispatch_slow_conv_1D(
         in,
         wt,
         out,
-        padding,
+        padding_lo,
+        padding_hi,
         wt_strides,
         wt_dilation,
         in_dilation,
@@ -707,7 +727,8 @@ void dispatch_slow_conv_2D(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding,
+    const std::vector<int>& padding_lo,
+    const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
@@ -718,7 +739,8 @@ void dispatch_slow_conv_2D(
         in,
         wt,
         out,
-        padding,
+        padding_lo,
+        padding_hi,
         wt_strides,
         wt_dilation,
         in_dilation,
@@ -729,7 +751,8 @@ void dispatch_slow_conv_2D(
         in,
         wt,
         out,
-        padding,
+        padding_lo,
+        padding_hi,
         wt_strides,
         wt_dilation,
         in_dilation,
@@ -740,7 +763,8 @@ void dispatch_slow_conv_2D(
         in,
         wt,
         out,
-        padding,
+        padding_lo,
+        padding_hi,
         wt_strides,
         wt_dilation,
         in_dilation,
@@ -756,7 +780,8 @@ void dispatch_slow_conv_3D(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding,
+    const std::vector<int>& padding_lo,
+    const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
@@ -767,7 +792,8 @@ void dispatch_slow_conv_3D(
         in,
         wt,
         out,
-        padding,
+        padding_lo,
+        padding_hi,
         wt_strides,
         wt_dilation,
         in_dilation,
@@ -778,7 +804,8 @@ void dispatch_slow_conv_3D(
         in,
         wt,
         out,
-        padding,
+        padding_lo,
+        padding_hi,
         wt_strides,
         wt_dilation,
         in_dilation,
@@ -789,7 +816,8 @@ void dispatch_slow_conv_3D(
         in,
         wt,
         out,
-        padding,
+        padding_lo,
+        padding_hi,
         wt_strides,
         wt_dilation,
         in_dilation,
@@ -829,7 +857,8 @@ void explicit_gemm_conv_1D_cpu(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding,
+    const std::vector<int>& padding_lo,
+    const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     Stream stream) {
@@ -848,7 +877,7 @@ void explicit_gemm_conv_1D_cpu(
   auto& encoder = cpu::get_command_encoder(stream);

   // Pad input
-  Shape padded_shape = {N, iH + 2 * padding[0], C};
+  Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C};
   array in_padded(padded_shape, conv_dtype, nullptr, {});

   // Fill with zeros
@@ -857,7 +886,7 @@ void explicit_gemm_conv_1D_cpu(
   copy(temps.back(), in_padded, CopyType::Scalar, stream);

   // Pick input slice from padded
-  size_t data_offset = padding[0] * in_padded.strides()[1];
+  size_t data_offset = padding_lo[0] * in_padded.strides()[1];
   array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
   in_padded_slice.copy_shared_buffer(
       in_padded,
@@ -971,7 +1000,8 @@ void explicit_gemm_conv_2D_cpu(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding,
+    const std::vector<int>& padding_lo,
+    const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     Stream stream) {
@@ -989,7 +1019,11 @@ void explicit_gemm_conv_2D_cpu(
   auto& encoder = cpu::get_command_encoder(stream);

   // Pad input
-  Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
+  Shape padded_shape = {
+      N,
+      iH + padding_lo[0] + padding_hi[0],
+      iW + padding_lo[1] + padding_hi[1],
+      C};
   array in_padded(padded_shape, conv_dtype, nullptr, {});

   // Fill with zeros
@@ -998,8 +1032,8 @@ void explicit_gemm_conv_2D_cpu(
   copy(temps.back(), in_padded, CopyType::Scalar, stream);

   // Pick input slice from padded
-  size_t data_offset =
-      padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2];
+  size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
+      padding_lo[1] * in_padded.strides()[2];
   array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
   in_padded_slice.copy_shared_buffer(
       in_padded,
@@ -1091,7 +1125,8 @@ void explicit_gemm_conv_ND_cpu(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding,
+    const std::vector<int>& padding_lo,
+    const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     const bool flip,
@@ -1114,7 +1149,7 @@ void explicit_gemm_conv_ND_cpu(
   Shape padded_shape(in.shape().size());
   padded_shape.front() = N;
   for (size_t i = 0; i < iDim.size(); i++) {
-    padded_shape[i + 1] = iDim[i] + 2 * padding[i];
+    padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
   }
   padded_shape.back() = C;
   array in_padded(padded_shape, conv_dtype, nullptr, {});
@@ -1125,9 +1160,10 @@ void explicit_gemm_conv_ND_cpu(
   // Pick input slice from padded
   size_t data_offset = 0;
-  for (size_t i = 0; i < padding.size(); i++) {
-    data_offset += padding[i] * in_padded.strides()[i + 1];
+  for (size_t i = 0; i < padding_lo.size(); i++) {
+    data_offset += padding_lo[i] * in_padded.strides()[i + 1];
   }
+
   array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
   in_padded_slice.copy_shared_buffer(
       in_padded,
@@ -1261,7 +1297,8 @@ void conv_1D_cpu(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding,
+    const std::vector<int>& padding_lo,
+    const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
@@ -1270,22 +1307,40 @@ void conv_1D_cpu(
   const int groups = in.shape().back() / wt.shape().back();
   if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
     return explicit_gemm_conv_1D_cpu(
-        in, wt, out, padding, wt_strides, wt_dilation, stream);
+        in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream);
   }

   if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
     return explicit_gemm_conv_ND_cpu(
-        in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
+        in,
+        wt,
+        out,
+        padding_lo,
+        padding_hi,
+        wt_strides,
+        wt_dilation,
+        flip,
+        stream);
   }

   return dispatch_slow_conv_1D(
-      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
+      in,
+      wt,
+      out,
+      padding_lo,
+      padding_hi,
+      wt_strides,
+      wt_dilation,
+      in_dilation,
+      flip,
+      stream);
 }

 void conv_2D_cpu(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding,
+    const std::vector<int>& padding_lo,
+    const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
@@ -1295,18 +1350,35 @@ void conv_2D_cpu(
   if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
       in_dilation[1] == 1 && groups == 1) {
     return explicit_gemm_conv_ND_cpu(
-        in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
+        in,
+        wt,
+        out,
+        padding_lo,
+        padding_hi,
+        wt_strides,
+        wt_dilation,
+        flip,
+        stream);
   }

   return dispatch_slow_conv_2D(
-      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
+      in,
+      wt,
+      out,
+      padding_lo,
+      padding_hi,
+      wt_strides,
+      wt_dilation,
+      in_dilation,
+      flip,
+      stream);
 }

 void conv_3D_cpu(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding,
+    const std::vector<int>& padding_lo,
+    const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
@@ -1317,11 +1389,28 @@ void conv_3D_cpu(
       in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
       groups == 1) {
     return explicit_gemm_conv_ND_cpu(
-        in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
+        in,
+        wt,
+        out,
+        padding_lo,
+        padding_hi,
+        wt_strides,
+        wt_dilation,
+        flip,
+        stream);
   }

   return dispatch_slow_conv_3D(
-      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
+      in,
+      wt,
+      out,
+      padding_lo,
+      padding_hi,
+      wt_strides,
+      wt_dilation,
+      in_dilation,
+      flip,
+      stream);
 }

 } // namespace

@@ -1338,7 +1427,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
         in,
         wt,
         out,
-        padding_,
+        padding_lo_,
+        padding_hi_,
         kernel_strides_,
         kernel_dilation_,
         input_dilation_,
@@ -1351,7 +1441,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
         in,
         wt,
         out,
-        padding_,
+        padding_lo_,
+        padding_hi_,
         kernel_strides_,
         kernel_dilation_,
         input_dilation_,
@@ -1364,7 +1455,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
         in,
         wt,
         out,
-        padding_,
+        padding_lo_,
+        padding_hi_,
         kernel_strides_,
         kernel_dilation_,
         input_dilation_,
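The recurring change in this file splits the symmetric padding vector into padding_lo and padding_hi, so each spatial dimension can be padded by different amounts at its start and end. A minimal sketch of the arithmetic the explicit-GEMM path now performs; the helper names below are illustrative, not MLX code:

#include <cstddef>
#include <vector>

// Padded extent of one spatial dim: lo + hi replaces the old 2 * pad.
int padded_extent(int extent, int lo, int hi) {
  return extent + lo + hi;
}

// Offset of the original data inside the zero-filled padded buffer: skip
// padding_lo[i] elements along each spatial dim (stride index i + 1 because
// index 0 is the batch dimension).
size_t slice_offset(
    const std::vector<int>& padding_lo,
    const std::vector<size_t>& padded_strides) {
  size_t offset = 0;
  for (size_t i = 0; i < padding_lo.size(); i++) {
    offset += padding_lo[i] * padded_strides[i + 1];
  }
  return offset;
}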

mlx/backend/cuda/CMakeLists.txt (new file)

@@ -0,0 +1,57 @@
# Filename rules in cuda backend:
#
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
# * Device-only kernel code should be put in kernels/ subdir.
# * Files in kernels/ subdir should not include files outside.
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
target_compile_definitions(mlx PUBLIC MLX_USE_CUDA)
# Enable defining device lambda functions.
target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES
"75;80"
CACHE STRING "CUDA architectures")
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
"${MLX_CUDA_ARCHITECTURES}")
# Use fixed version of CCCL.
FetchContent_Declare(
cccl
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
FetchContent_MakeAvailable(cccl)
target_include_directories(mlx PRIVATE BEFORE "${cccl_SOURCE_DIR}/include")
# Use fixed version of NVTX.
FetchContent_Declare(
nvtx3
GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git
GIT_TAG v3.1.1
GIT_SHALLOW TRUE
SOURCE_SUBDIR c EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(nvtx3)
target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
# Make cuda runtime APIs available in non-cuda files.
find_package(CUDAToolkit REQUIRED)
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)

mlx/backend/cuda/allocator.cpp (new file)

@@ -0,0 +1,154 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/worker.h"
#include <cuda_runtime.h>
#include <fmt/format.h>
#include <cassert>
namespace mlx::core {
namespace cu {
CudaAllocator::CudaAllocator() {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.8;
}
Buffer CudaAllocator::malloc(size_t size) {
// TODO: Check memory limit.
auto* buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(
fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
std::lock_guard lock(mutex_);
active_memory_ += size;
peak_memory_ = std::max(active_memory_, peak_memory_);
return Buffer{buf};
}
void CudaAllocator::free(Buffer buffer) {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
return;
}
// If free() is called from an unregistered thread, reschedule the call to
// the worker.
{
std::lock_guard lock(worker_mutex_);
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
if (!worker_) {
worker_.reset(new Worker);
}
worker_->add_task([buffer]() { allocator().free(buffer); });
worker_->end_batch();
worker_->commit();
return;
}
}
size_t size = buf->size;
cudaFree(buf->data);
delete buf;
std::lock_guard lock(mutex_);
active_memory_ -= size;
}
size_t CudaAllocator::size(Buffer buffer) const {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
return 0;
}
return buf->size;
}
void CudaAllocator::register_this_thread() {
std::lock_guard lock(worker_mutex_);
allowed_threads_.insert(std::this_thread::get_id());
}
size_t CudaAllocator::get_active_memory() const {
return active_memory_;
}
size_t CudaAllocator::get_peak_memory() const {
return peak_memory_;
}
void CudaAllocator::reset_peak_memory() {
std::lock_guard lock(mutex_);
peak_memory_ = 0;
}
size_t CudaAllocator::get_memory_limit() {
return memory_limit_;
}
size_t CudaAllocator::set_memory_limit(size_t limit) {
std::lock_guard lock(mutex_);
std::swap(limit, memory_limit_);
return limit;
}
CudaAllocator& allocator() {
// By creating |allocator_| on the heap, the destructor of CudaAllocator
// will not be called at program exit, and buffers in the cache will be
// leaked. This can save some time when the program exits.
static CudaAllocator* allocator_ = new CudaAllocator;
return *allocator_;
}
} // namespace cu
namespace allocator {
Allocator& allocator() {
return cu::allocator();
}
void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
return static_cast<cu::CudaBuffer*>(ptr_)->data;
}
} // namespace allocator
size_t get_active_memory() {
return cu::allocator().get_active_memory();
}
size_t get_peak_memory() {
return cu::allocator().get_peak_memory();
}
void reset_peak_memory() {
return cu::allocator().reset_peak_memory();
}
size_t set_memory_limit(size_t limit) {
return cu::allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
return cu::allocator().get_memory_limit();
}
// TODO: Implement buffer cache.
size_t get_cache_memory() {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
}
size_t set_wired_limit(size_t) {
return 0;
}
void clear_cache() {}
} // namespace mlx::core
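A sketch of how the deferred-free rule above is meant to be used, assuming the CudaAllocator API introduced in this diff (the function name is illustrative):

#include "mlx/backend/cuda/allocator.h"

void worker_thread_body() {
  // Threads that a GPU stream may wait on must register themselves first;
  // otherwise free() reschedules the actual cudaFree to the worker thread.
  mlx::core::cu::allocator().register_this_thread();

  auto buf = mlx::core::cu::allocator().malloc(1024);
  // ... use the managed memory ...
  mlx::core::cu::allocator().free(buf); // freed inline, not rescheduled
}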

mlx/backend/cuda/allocator.h (new file)

@@ -0,0 +1,58 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include <mutex>
#include <set>
#include <thread>
#include <utility>
namespace mlx::core::cu {
class Worker;
using allocator::Buffer;
// Stores cuda-managed unified memory.
struct CudaBuffer {
void* data;
size_t size;
};
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
// Register the current thread as safe to free buffers.
// In CUDA, freeing a buffer implicitly synchronizes the stream, so for
// threads that a GPU stream may wait on (for example CPU stream threads),
// freeing buffers there would result in deadlock.
void register_this_thread();
size_t get_active_memory() const;
size_t get_peak_memory() const;
void reset_peak_memory();
size_t get_memory_limit();
size_t set_memory_limit(size_t limit);
private:
CudaAllocator();
friend CudaAllocator& allocator();
std::mutex worker_mutex_;
std::unique_ptr<Worker> worker_;
std::set<std::thread::id> allowed_threads_;
std::mutex mutex_;
size_t memory_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
};
CudaAllocator& allocator();
} // namespace mlx::core::cu

mlx/backend/cuda/copy.cpp (new file)

@@ -0,0 +1,26 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/gpu/copy.h"
namespace mlx::core {
void copy_gpu_inplace(
const array& in,
array& out,
const Shape& data_shape,
const Strides& strides_in_pre,
const Strides& strides_out_pre,
int64_t inp_offset,
int64_t out_offset,
CopyType ctype,
const Stream& s,
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend.");
}
void fill_gpu(const array& val, array& out, const Stream& s) {
throw std::runtime_error("fill_gpu not implemented in CUDA backend.");
}
} // namespace mlx::core

mlx/backend/cuda/device.cpp (new file)

@@ -0,0 +1,117 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/backend/metal/metal.h"
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {}
void DeviceStream::synchronize() {
cudaStreamSynchronize(stream_);
}
cudaStream_t DeviceStream::schedule_cuda_stream() {
// TODO: Return a stream that maximizes parallelism.
return stream_;
}
cudaStream_t DeviceStream::last_cuda_stream() {
return stream_;
}
CommandEncoder& DeviceStream::get_encoder() {
if (!encoder_) {
encoder_ = std::make_unique<CommandEncoder>(*this);
}
return *encoder_;
}
Device::Device(int device) : device_(device) {
// Validate the requirements of device.
int attr = 0;
cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_);
if (attr != 1) {
throw std::runtime_error(fmt::format(
"Device {} does not support synchronization in managed memory.",
device_));
}
}
void Device::make_current() {
// We need to set/get the current CUDA device very frequently, so cache it
// to reduce actual calls into the CUDA API. This function assumes a single
// host thread.
static int current = 0;
if (current != device_) {
CHECK_CUDA_ERROR(cudaSetDevice(device_));
current = device_;
}
}
DeviceStream& Device::get_stream(Stream s) {
auto it = streams_.find(s.index);
if (it == streams_.end()) {
it = streams_.try_emplace(s.index, *this).first;
}
return it->second;
}
CommandEncoder::CommandEncoder(DeviceStream& s)
: device_(s.device()), stream_(s) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));
}
void CommandEncoder::end_encoding() {
if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {});
}
// If there is no kernel running, run the completion handlers immediately.
if (!has_gpu_work_) {
worker_.consume_in_this_thread();
return;
}
has_gpu_work_ = false;
// Put completion handlers in a batch.
worker_.end_batch();
// Signaling kernel completion is expensive, so delay it until enough
// batches have accumulated.
// TODO: This number is arbitrarily picked; profile for a better strategy.
if (worker_.uncommited_batches() > 8) {
commit();
}
}
void CommandEncoder::commit() {
worker_.commit(stream_.last_cuda_stream());
}
Device& device(mlx::core::Device device) {
static std::unordered_map<int, Device> devices;
auto it = devices.find(device.index);
if (it == devices.end()) {
it = devices.try_emplace(device.index, device.index).first;
}
return it->second;
}
DeviceStream& get_stream(Stream s) {
return device(s.device).get_stream(s);
}
CommandEncoder& get_command_encoder(Stream s) {
return get_stream(s).get_encoder();
}
} // namespace cu
} // namespace mlx::core
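A sketch of the encoder flow these classes assume; the kernel and launch configuration below are hypothetical, not part of this diff:

#include "mlx/backend/cuda/device.h"

__global__ void my_kernel(float* data, int n) { // hypothetical kernel
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] *= 2.0f;
  }
}

void run(mlx::core::Stream s, float* data, int n) {
  auto& enc = mlx::core::cu::get_command_encoder(s);
  // launch_kernel makes the device current, runs the lambda on a scheduled
  // cuda stream, and checks cudaGetLastError afterwards.
  enc.launch_kernel([&](cudaStream_t stream) {
    my_kernel<<<(n + 255) / 256, 256, 0, stream>>>(data, n);
  });
  // Runs after the kernel finishes, once the worker batch is committed.
  enc.add_completed_handler([] { /* release resources */ });
  enc.end_encoding();
}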

mlx/backend/cuda/device.h (new file)

@@ -0,0 +1,131 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h"
#include <thrust/execution_policy.h>
#include <unordered_map>
namespace mlx::core::cu {
class Device;
class CommandEncoder;
class DeviceStream {
public:
explicit DeviceStream(Device& device);
DeviceStream(const DeviceStream&) = delete;
DeviceStream& operator=(const DeviceStream&) = delete;
// Wait until kernels in the stream complete.
void synchronize();
// Return a cuda stream for launching kernels.
cudaStream_t schedule_cuda_stream();
// Return the last cuda stream used.
cudaStream_t last_cuda_stream();
CommandEncoder& get_encoder();
Device& device() {
return device_;
}
private:
Device& device_;
CudaStream stream_;
std::unique_ptr<CommandEncoder> encoder_;
};
class Device {
public:
explicit Device(int device);
Device(const Device&) = delete;
Device& operator=(const Device&) = delete;
// Make this device the current cuda device, required by some cuda calls.
void make_current();
DeviceStream& get_stream(Stream s);
int cuda_device() const {
return device_;
}
private:
int device_;
std::unordered_map<int, DeviceStream> streams_;
};
class CommandEncoder {
public:
explicit CommandEncoder(DeviceStream& stream);
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
void set_input_array(const array& arr) {}
void set_output_array(const array& arr) {}
void add_temporary(const array& arr) {
temporaries_.push_back(arr.data_shared_ptr());
}
void add_completed_handler(std::function<void()> task);
void end_encoding();
void commit();
// Schedule a cuda stream for |fun| to launch kernels on, and check for
// errors afterwards.
template <typename F>
void launch_kernel(F&& fun) {
launch_kernel(stream_.schedule_cuda_stream(), std::forward<F>(fun));
}
template <typename F>
void launch_kernel(cudaStream_t stream, F&& fun) {
device_.make_current();
fun(stream);
check_cuda_error("kernel launch", cudaGetLastError());
has_gpu_work_ = true;
}
Device& device() {
return device_;
}
DeviceStream& stream() {
return stream_;
}
bool has_gpu_work() const {
return has_gpu_work_;
}
private:
Device& device_;
DeviceStream& stream_;
Worker worker_;
bool has_gpu_work_{false};
std::vector<std::shared_ptr<array::Data>> temporaries_;
};
Device& device(mlx::core::Device device);
DeviceStream& get_stream(Stream s);
CommandEncoder& get_command_encoder(Stream s);
// Return an execution policy that does not sync to wait for results.
// Note that not all thrust APIs support an async policy; confirm before use.
inline auto thrust_policy(cudaStream_t stream) {
// TODO: Connect thrust's custom allocator with mlx's allocator.
return thrust::cuda::par_nosync.on(stream);
}
} // namespace mlx::core::cu
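For example, thrust_policy can be combined with any thrust algorithm that accepts an execution policy. A sketch, assuming a raw device pointer (the function name is illustrative):

#include <thrust/sequence.h>

#include "mlx/backend/cuda/device.h"

void fill_iota(cudaStream_t stream, int* ptr, size_t n) {
  // Launches asynchronously on |stream|; the caller synchronizes later.
  thrust::sequence(mlx::core::cu::thrust_policy(stream), ptr, ptr + n);
}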


@@ -0,0 +1,35 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace mlx::core {
// Maps CPU types to CUDA types.
template <typename T>
struct CTypeToCudaType {
using type = T;
};
template <>
struct CTypeToCudaType<float16_t> {
using type = __half;
};
template <>
struct CTypeToCudaType<bfloat16_t> {
using type = __nv_bfloat16;
};
template <>
struct CTypeToCudaType<complex64_t> {
using type = cuComplex;
};
template <typename T>
using cuda_type_t = typename CTypeToCudaType<T>::type;
} // namespace mlx::core
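A quick illustration of the mapping; a sketch that assumes this header and the MLX dtype definitions are included:

#include <type_traits>

// float16_t on the CPU side corresponds to __half in device code, while
// native types map to themselves.
static_assert(
    std::is_same_v<mlx::core::cuda_type_t<mlx::core::float16_t>, __half>);
static_assert(std::is_same_v<mlx::core::cuda_type_t<float>, float>);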

mlx/backend/cuda/eval.cpp (new file)

@@ -0,0 +1,68 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/gpu/eval.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/available.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core::gpu {
bool is_available() {
return true;
}
void new_stream(Stream s) {
// Force initialization of cuda, so the cuda runtime gets destroyed last.
cudaFree(nullptr);
// Ensure the static stream objects get created.
cu::get_command_encoder(s);
// The main thread is safe to free buffers.
cu::allocator().register_this_thread();
}
void eval(array& arr) {
nvtx3::scoped_range r("gpu::eval");
auto outputs = arr.outputs();
{
// If the array is a tracer, hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
if (encoder.has_gpu_work()) {
// Keep used buffers alive until kernel finishes running.
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input.
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
}
encoder.end_encoding();
}
void finalize(Stream s) {
nvtx3::scoped_range r("gpu::finalize");
cu::get_command_encoder(s).commit();
}
void synchronize(Stream s) {
nvtx3::scoped_range r("gpu::synchronize");
cu::get_stream(s).synchronize();
}
} // namespace mlx::core::gpu

mlx/backend/cuda/event.cu (new file)

@@ -0,0 +1,265 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/event.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
///////////////////////////////////////////////////////////////////////////////
// CudaEvent implementations
///////////////////////////////////////////////////////////////////////////////
// Cuda event managed with RAII.
class CudaEventHandle {
public:
CudaEventHandle() {
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
&event_, cudaEventDisableTiming | cudaEventBlockingSync));
}
~CudaEventHandle() {
CHECK_CUDA_ERROR(cudaEventDestroy(event_));
}
CudaEventHandle(const CudaEventHandle&) = delete;
CudaEventHandle& operator=(const CudaEventHandle&) = delete;
operator cudaEvent_t() const {
return event_;
}
private:
cudaEvent_t event_;
};
CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
void CudaEvent::wait() {
nvtx3::scoped_range r("cu::CudaEvent::wait");
if (!recorded_) {
throw std::runtime_error("Should not wait on a CudaEvent before record.");
}
cudaEventSynchronize(*event_);
}
void CudaEvent::wait(cudaStream_t stream) {
if (!recorded_) {
throw std::runtime_error("Should not wait on a CudaEvent before record.");
}
cudaStreamWaitEvent(stream, *event_);
}
void CudaEvent::wait(Stream s) {
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this]() mutable { wait(); });
} else {
wait(cu::get_stream(s).last_cuda_stream());
}
}
void CudaEvent::record(cudaStream_t stream) {
cudaEventRecord(*event_, stream);
recorded_ = true;
}
void CudaEvent::record(Stream s) {
if (s.device == mlx::core::Device::cpu) {
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
} else {
record(cu::get_stream(s).last_cuda_stream());
}
}
bool CudaEvent::completed() const {
return cudaEventQuery(*event_) == cudaSuccess;
}
///////////////////////////////////////////////////////////////////////////////
// SharedEvent implementations
///////////////////////////////////////////////////////////////////////////////
namespace {
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
uint64_t current;
while ((current = ac->load()) < value) {
ac->wait(current);
}
}
__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
ac->store(value);
ac->notify_all();
}
__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_wait(ac, value);
}
__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_signal(ac, value);
}
} // namespace
SharedEvent::SharedEvent() {
// Allocate cuda::atomic on managed memory.
allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
new (ac) Atomic(0);
ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
ptr->~Atomic();
allocator::free(buffer);
});
}
void SharedEvent::wait(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait");
event_wait(ac_.get(), value);
}
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
}
void SharedEvent::wait(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
} else {
auto& encoder = get_command_encoder(s);
encoder.launch_kernel(
encoder.stream().last_cuda_stream(),
[this, value](cudaStream_t stream) { wait(stream, value); });
encoder.add_completed_handler([ac = ac_]() {});
encoder.end_encoding();
}
}
void SharedEvent::signal(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal");
event_signal(ac_.get(), value);
}
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
}
void SharedEvent::signal(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this, value]() mutable { signal(value); });
} else {
auto& encoder = get_command_encoder(s);
encoder.launch_kernel(
encoder.stream().last_cuda_stream(),
[this, value](cudaStream_t stream) { signal(stream, value); });
encoder.add_completed_handler([ac = ac_]() {});
encoder.end_encoding();
}
}
bool SharedEvent::is_signaled(uint64_t value) const {
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
return ac_->load() >= value;
}
uint64_t SharedEvent::value() const {
nvtx3::scoped_range r("cu::SharedEvent::value");
return ac_->load();
}
} // namespace cu
///////////////////////////////////////////////////////////////////////////////
// Event implementations
///////////////////////////////////////////////////////////////////////////////
namespace {
struct EventImpl {
// CudaEvent is preferred when possible because it is fast, however we have
// to fall back to SharedEvent in the following cases:
// 1. the event is used to wait/signal a cpu stream;
// 2. a signal value other than 1 has been specified.
std::unique_ptr<cu::CudaEvent> cuda;
std::unique_ptr<cu::SharedEvent> shared;
bool is_created() const {
return cuda || shared;
}
void ensure_created(Stream s, uint64_t signal_value) {
if (is_created()) {
return;
}
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
nvtx3::mark("Using slow SharedEvent");
shared = std::make_unique<cu::SharedEvent>();
} else {
cuda = std::make_unique<cu::CudaEvent>();
}
}
};
} // namespace
Event::Event(Stream s) : stream_(s) {
event_ = std::shared_ptr<void>(
new EventImpl(), [](void* ptr) { delete static_cast<EventImpl*>(ptr); });
}
void Event::wait() {
auto* event = static_cast<EventImpl*>(event_.get());
assert(event->is_created());
if (event->cuda) {
assert(value() == 1);
event->cuda->wait();
} else {
event->shared->wait(value());
}
}
void Event::wait(Stream s) {
auto* event = static_cast<EventImpl*>(event_.get());
assert(event->is_created());
if (event->cuda) {
assert(value() == 1);
event->cuda->wait(s);
} else {
event->shared->wait(s, value());
}
}
void Event::signal(Stream s) {
auto* event = static_cast<EventImpl*>(event_.get());
event->ensure_created(s, value());
if (event->cuda) {
assert(value() == 1);
event->cuda->record(s);
} else {
event->shared->signal(s, value());
}
}
bool Event::is_signaled() const {
auto* event = static_cast<EventImpl*>(event_.get());
if (!event->is_created()) {
return false;
}
if (event->cuda) {
assert(value() == 1);
return event->cuda->recorded() && event->cuda->completed();
} else {
return event->shared->is_signaled(value());
}
}
} // namespace mlx::core
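The SharedEvent path above works because the cuda::atomic lives in managed memory, so host and device operate on the same location (this is why Device checks concurrentManagedAccess). A reduced sketch of that handshake, with illustrative names, condensing the event_wait/event_signal helpers in this file:

#include <cuda/atomic>

__global__ void gpu_signal(cuda::atomic<uint64_t>* ac, uint64_t value) {
  ac->store(value);
  ac->notify_all(); // wakes host threads blocked in ac->wait()
}

void host_wait(cuda::atomic<uint64_t>* ac, uint64_t value) {
  uint64_t current;
  while ((current = ac->load()) < value) {
    ac->wait(current); // blocks until the stored value changes
  }
}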

mlx/backend/cuda/event.h (new file)

@@ -0,0 +1,66 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/stream.h"
#include <cuda_runtime.h>
#include <cuda/atomic>
#include <memory>
namespace mlx::core::cu {
class CudaEventHandle;
// Wrapper of a native cuda event. It can synchronize between GPU streams, or
// let a CPU stream wait on a GPU stream, but it can not wait on a CPU stream.
class CudaEvent {
public:
CudaEvent();
void wait();
void wait(cudaStream_t stream);
void wait(Stream s);
void record(cudaStream_t stream);
void record(Stream s);
// Return whether the recorded kernels have completed. Note that this method
// returns true if record() has not been called.
bool completed() const;
bool recorded() const {
return recorded_;
}
private:
bool recorded_{false};
std::shared_ptr<CudaEventHandle> event_;
};
// Event that can synchronize between CPU and GPU. It is much slower than
// CudaEvent so the latter should always be preferred when possible.
class SharedEvent {
public:
using Atomic = cuda::atomic<uint64_t>;
SharedEvent();
void wait(uint64_t value);
void wait(cudaStream_t stream, uint64_t value);
void wait(Stream s, uint64_t value);
void signal(uint64_t value);
void signal(cudaStream_t stream, uint64_t value);
void signal(Stream s, uint64_t value);
bool is_signaled(uint64_t value) const;
uint64_t value() const;
const std::shared_ptr<Atomic>& atomic() const {
return ac_;
}
private:
std::shared_ptr<Atomic> ac_;
};
} // namespace mlx::core::cu
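
As a reference for the wrapper above, this is a sketch of the plain CUDA runtime pattern that CudaEvent encapsulates; the stream names are hypothetical and no MLX types are used:

#include <cuda_runtime.h>

int main() {
  cudaStream_t producer, consumer;
  cudaStreamCreate(&producer);
  cudaStreamCreate(&consumer);

  cudaEvent_t ev;
  cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);

  // ... enqueue kernels on `producer` here ...
  cudaEventRecord(ev, producer);         // cf. CudaEvent::record(stream)
  cudaStreamWaitEvent(consumer, ev, 0);  // cf. CudaEvent::wait(stream)
  cudaEventSynchronize(ev);              // cf. CudaEvent::wait() on the host

  cudaEventDestroy(ev);
  cudaStreamDestroy(producer);
  cudaStreamDestroy(consumer);
  return 0;
}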

mlx/backend/cuda/fence.cu Normal file
View File

@ -0,0 +1,70 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/fence.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace {
__host__ __device__ void busy_wait(cuda::atomic<uint64_t>* ac, uint64_t value) {
while (true) {
// In theory the atomic_thread_fence is not needed, but on CUDA 11, without
// it the load() may never observe the new value.
cuda::atomic_thread_fence(cuda::memory_order_seq_cst);
uint64_t current = ac->load();
if (current >= value) {
break;
}
}
}
__global__ void busy_wait_kernel(cuda::atomic<uint64_t>* ac, uint64_t value) {
busy_wait(ac, value);
}
} // namespace
struct FenceImpl {
uint32_t count;
cu::SharedEvent event;
};
Fence::Fence(Stream s) {
fence_ = std::shared_ptr<void>(
new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
}
void Fence::wait(Stream s, const array&) {
auto* fence = static_cast<FenceImpl*>(fence_.get());
// We can't use SharedEvent::wait because it could hang on CUDA 11; see also:
// https://github.com/ml-explore/mlx/issues/2137
const auto& ac = fence->event.atomic();
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [ac, count = fence->count]() {
nvtx3::scoped_range r("Fence::wait()");
busy_wait(ac.get(), count);
});
} else {
nvtx3::scoped_range r("Fence::wait(s)");
auto& encoder = cu::get_command_encoder(s);
encoder.launch_kernel(
encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) {
busy_wait_kernel<<<1, 1, 0, stream>>>(ac.get(), fence->count);
});
encoder.add_completed_handler([ac]() {});
encoder.end_encoding();
}
}
void Fence::update(Stream s, const array&) {
auto* fence = static_cast<FenceImpl*>(fence_.get());
fence->count++;
fence->event.signal(s, fence->count);
}
} // namespace mlx::core
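
The fence above is a monotonically increasing counter: update() bumps the target and signals it, and wait() blocks until the counter reaches the requested value. A host-only analogue of that protocol, sketched with std::atomic purely for illustration:

#include <atomic>
#include <cstdint>

struct CountingFence {
  std::atomic<uint64_t> state{0};
  uint64_t target{0};  // analogous to FenceImpl::count
  void update() {      // cf. Fence::update(): bump and publish the new target
    state.store(++target, std::memory_order_release);
  }
  void wait(uint64_t t) {  // cf. Fence::wait(): spin until the value is reached
    while (state.load(std::memory_order_acquire) < t) {
    }
  }
};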

mlx/backend/cuda/kernels/arange.cuh Normal file
View File

@ -0,0 +1,15 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::cu {
template <typename T>
struct Arange {
const T start;
const T step;
__device__ T operator()(uint32_t i) const {
return start + i * step;
}
};
} // namespace mlx::core::cu
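
An index functor like this is typically consumed by thrust::transform over a counting iterator, which is how the CUDA Arange primitive below uses it. A standalone sketch of the pattern with plain Thrust containers and hypothetical sizes:

#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>

struct ArangeOp {
  float start, step;
  __host__ __device__ float operator()(unsigned int i) const {
    return start + i * step;
  }
};

int main() {
  thrust::device_vector<float> out(8);
  thrust::transform(
      thrust::counting_iterator<unsigned int>(0),
      thrust::counting_iterator<unsigned int>(8),
      out.begin(),
      ArangeOp{1.0f, 0.5f});  // 1.0, 1.5, 2.0, ..., 4.5
  return 0;
}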

mlx/backend/cuda/kernels/fp16_math.cuh Normal file
View File

@ -0,0 +1,107 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda/std/limits>
#include <cuda/std/type_traits>
namespace mlx::core::cu {
///////////////////////////////////////////////////////////////////////////////
// Missing C++ operator overloads for CUDA < 12.0 and architectures < 800.
///////////////////////////////////////////////////////////////////////////////
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
#define MLX_DEFINE_BF16_OP(OP) \
__forceinline__ __device__ __nv_bfloat16 operator OP( \
__nv_bfloat16 x, __nv_bfloat16 y) { \
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
}
#define MLX_DEFINE_BF16_CMP(OP) \
__forceinline__ __device__ bool operator OP( \
__nv_bfloat16 x, __nv_bfloat16 y) { \
return __bfloat162float(x) OP __bfloat162float(y); \
}
MLX_DEFINE_BF16_OP(+)
MLX_DEFINE_BF16_OP(-)
MLX_DEFINE_BF16_OP(*)
MLX_DEFINE_BF16_OP(/)
MLX_DEFINE_BF16_CMP(>)
MLX_DEFINE_BF16_CMP(<)
MLX_DEFINE_BF16_CMP(>=)
MLX_DEFINE_BF16_CMP(<=)
#undef MLX_DEFINE_BF16_OP
#undef MLX_DEFINE_BF16_CMP
#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
///////////////////////////////////////////////////////////////////////////////
// Additional C++ operator overloads between half types and native types.
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U>
constexpr bool is_integral_except =
cuda::std::is_integral_v<T> && !cuda::std::is_same_v<T, U>;
template <typename T, typename U>
constexpr bool is_arithmetic_except =
cuda::std::is_arithmetic_v<T> && !cuda::std::is_same_v<T, U>;
#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \
template < \
typename T, \
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
__forceinline__ __device__ HALF operator OP(HALF x, T y) { \
return FLOAT2HALF(HALF2FLOAT(x) OP static_cast<float>(y)); \
} \
template < \
typename T, \
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
__forceinline__ __device__ HALF operator OP(T x, HALF y) { \
return FLOAT2HALF(static_cast<float>(x) OP HALF2FLOAT(y)); \
}
#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \
template < \
typename T, \
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
__forceinline__ __device__ bool operator OP(HALF x, T y) { \
return HALF2FLOAT(x) OP static_cast<float>(y); \
} \
template < \
typename T, \
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
__forceinline__ __device__ bool operator OP(T x, HALF y) { \
return static_cast<float>(x) OP HALF2FLOAT(y); \
}
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +)
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -)
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *)
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /)
MLX_DEFINE_HALF_CMP(__half, __half2float, <)
MLX_DEFINE_HALF_CMP(__half, __half2float, >)
MLX_DEFINE_HALF_CMP(__half, __half2float, <=)
MLX_DEFINE_HALF_CMP(__half, __half2float, >=)
MLX_DEFINE_HALF_CMP(__half, __half2float, ==)
MLX_DEFINE_HALF_CMP(__half, __half2float, !=)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=)
#undef MLX_DEFINE_HALF_OP
#undef MLX_DEFINE_HALF_CMP
} // namespace mlx::core::cu
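
What these overloads buy inside kernels is mixing half types with native scalars without explicit casts. A small hypothetical kernel, assuming the header above is on the include path:

#include <cuda_fp16.h>
// plus the fp16_math header above

__global__ void relu_double(__half* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n && x[i] > 0) {  // __half vs. int via MLX_DEFINE_HALF_CMP(>)
    x[i] = x[i] * 2;        // __half times int via MLX_DEFINE_HALF_OP(*)
  }
}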

mlx/backend/cuda/primitives.cu Normal file
View File

@ -0,0 +1,163 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/dtype_utils.cuh"
#include "mlx/backend/cuda/kernels/arange.cuh"
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/distributed/primitives.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <cassert>
namespace mlx::core {
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Arange::eval_gpu");
assert(inputs.size() == 0);
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
encoder.set_output_array(out);
encoder.launch_kernel([&, this](cudaStream_t stream) {
MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, {
using OutType = cuda_type_t<CTYPE>;
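// Compute the step in the output type so rounding matches accumulation.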
CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
thrust::transform(
cu::thrust_policy(stream),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(out.data_size()),
thrust::device_pointer_cast(out.data<OutType>()),
cu::Arange<OutType>{
static_cast<OutType>(start_), static_cast<OutType>(step)});
});
});
}
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
#define NO_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
NO_GPU(Abs)
NO_GPU(Add)
NO_GPU(AddMM)
NO_GPU(ArcCos)
NO_GPU(ArcCosh)
NO_GPU(ArcSin)
NO_GPU(ArcSinh)
NO_GPU(ArcTan)
NO_GPU(ArcTan2)
NO_GPU(ArcTanh)
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(ArgSort)
NO_GPU(BitwiseBinary)
NO_GPU(BitwiseInvert)
NO_GPU(BlockMaskedMM)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
NO_GPU(Conjugate)
NO_GPU(Convolution)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(Exp)
NO_GPU(Expm1)
NO_GPU(FFT)
NO_GPU(Floor)
NO_GPU(Gather)
NO_GPU(GatherAxis)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Greater)
NO_GPU(GreaterEqual)
NO_GPU(Hadamard)
NO_GPU(Imag)
NO_GPU(Less)
NO_GPU(LessEqual)
NO_GPU(Load)
NO_GPU(Log)
NO_GPU(Log1p)
NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU(LogSumExp)
NO_GPU_MULTI(LUF)
NO_GPU(Matmul)
NO_GPU(Maximum)
NO_GPU(Minimum)
NO_GPU(Multiply)
NO_GPU(Negative)
NO_GPU(NotEqual)
NO_GPU(Partition)
NO_GPU(Power)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(RandomBits)
NO_GPU(Real)
NO_GPU(Reduce)
NO_GPU(Round)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(ScatterAxis)
NO_GPU(Select)
NO_GPU(Sigmoid)
NO_GPU(Sign)
NO_GPU(Sin)
NO_GPU(Sinh)
NO_GPU(SliceUpdate)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU(Square)
NO_GPU(Sqrt)
NO_GPU(Subtract)
NO_GPU_MULTI(SVD)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU_MULTI(LayerNorm)
NO_GPU_MULTI(LayerNormVJP)
NO_GPU_MULTI(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
} // namespace fast
namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed
} // namespace mlx::core
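
For reference, each NO_GPU stub above expands to an eval_gpu override that throws at run time; for example, NO_GPU(Abs) is equivalent to:

void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
  throw std::runtime_error("Abs has no CUDA implementation.");
}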

mlx/backend/cuda/slicing.cpp Normal file
View File

@ -0,0 +1,15 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/gpu/slicing.h"
namespace mlx::core {
void concatenate_gpu(
const std::vector<array>& inputs,
array& out,
int axis,
const Stream& s) {
throw std::runtime_error("concatenate_gpu not implemented in CUDA backend.");
}
} // namespace mlx::core

mlx/backend/cuda/utils.cpp Normal file
View File

@ -0,0 +1,26 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/device.h"
#include <fmt/format.h>
namespace mlx::core {
CudaStream::CudaStream(cu::Device& device) {
device.make_current();
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
}
CudaStream::~CudaStream() {
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
}
void check_cuda_error(const char* name, cudaError_t err) {
if (err != cudaSuccess) {
throw std::runtime_error(
fmt::format("{} failed: {}", name, cudaGetErrorString(err)));
}
}
} // namespace mlx::core

mlx/backend/cuda/utils.h Normal file
View File

@ -0,0 +1,36 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuda_runtime.h>
namespace mlx::core {
namespace cu {
class Device;
}
// CUDA stream managed with RAII.
class CudaStream {
public:
explicit CudaStream(cu::Device& device);
~CudaStream();
CudaStream(const CudaStream&) = delete;
CudaStream& operator=(const CudaStream&) = delete;
operator cudaStream_t() const {
return stream_;
}
private:
cudaStream_t stream_;
};
// Throws an exception if the CUDA API call does not succeed.
void check_cuda_error(const char* name, cudaError_t err);
// The macro version that prints the command that failed.
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
} // namespace mlx::core
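
A usage sketch for the RAII stream and the error macro; the function and its arguments are hypothetical, and cu::Device comes from mlx/backend/cuda/device.h:

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/utils.h"

void zero_async(mlx::core::cu::Device& dev, void* ptr, size_t nbytes) {
  mlx::core::CudaStream stream(dev);  // destroyed automatically on scope exit
  CHECK_CUDA_ERROR(cudaMemsetAsync(ptr, 0, nbytes, stream));
  CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
}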

mlx/backend/cuda/worker.cpp Normal file
View File

@ -0,0 +1,90 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/worker.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
Worker::Worker()
: signal_stream_(device(mlx::core::Device::gpu)),
worker_(&Worker::thread_fn, this) {}
Worker::~Worker() {
{
std::lock_guard lock(worker_mutex_);
stop_ = true;
}
worker_event_.signal(batch_ + 1);
worker_.join();
}
void Worker::add_task(std::function<void()> task) {
pending_tasks_.push_back(std::move(task));
}
void Worker::consume_in_this_thread() {
for (auto& task : pending_tasks_) {
task();
}
pending_tasks_.clear();
}
void Worker::end_batch() {
batch_++;
{
std::lock_guard lock(worker_mutex_);
worker_tasks_[batch_] = std::move(pending_tasks_);
}
uncommited_batches_++;
}
void Worker::commit() {
if (uncommited_batches_ == 0) {
return;
}
uncommited_batches_ = 0;
worker_event_.signal(batch_);
}
void Worker::commit(cudaStream_t stream) {
if (uncommited_batches_ == 0) {
return;
}
uncommited_batches_ = 0;
// Signal the |worker_event_| in |signal_stream_| after the kernels in
// |stream| finish running.
signal_event_.record(stream);
signal_event_.wait(signal_stream_);
worker_event_.signal(signal_stream_, batch_);
}
void Worker::thread_fn() {
// It is safe for the worker thread to free buffers.
allocator().register_this_thread();
while (!stop_) {
uint64_t batch = worker_event_.value();
Tasks tasks;
{
std::lock_guard lock(worker_mutex_);
// Move tasks from all signaled batches.
auto end = worker_tasks_.upper_bound(batch);
for (auto it = worker_tasks_.begin(); it != end; ++it) {
if (tasks.empty()) {
tasks = std::move(it->second);
} else {
std::move(
it->second.begin(), it->second.end(), std::back_inserter(tasks));
}
}
worker_tasks_.erase(worker_tasks_.begin(), end);
}
for (auto& task : tasks) {
task();
}
worker_event_.wait(batch + 1);
}
}
} // namespace mlx::core::cu

mlx/backend/cuda/worker.h Normal file
View File

@ -0,0 +1,68 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
#include <functional>
#include <map>
#include <mutex>
#include <thread>
namespace mlx::core::cu {
// Run tasks in a worker thread, synchronized with a CUDA stream.
class Worker {
public:
Worker();
~Worker();
Worker(const Worker&) = delete;
Worker& operator=(const Worker&) = delete;
// Add a pending |task| that will run when consumed or committed.
void add_task(std::function<void()> task);
// Run pending tasks immediately in current thread.
void consume_in_this_thread();
// Put pending tasks in a batch.
void end_batch();
// Inform worker thread to run current batches now.
void commit();
// Inform worker thread to run current batches after kernels in |stream|
// finish running.
void commit(cudaStream_t stream);
// Return how many batches have been added but not committed yet.
size_t uncommited_batches() const {
return uncommited_batches_;
}
private:
void thread_fn();
uint64_t batch_{0};
size_t uncommited_batches_{0};
// CUDA stream and event for signaling kernel completion.
CudaStream signal_stream_;
CudaEvent signal_event_;
// Worker thread.
SharedEvent worker_event_;
std::thread worker_;
std::mutex worker_mutex_;
bool stop_{false};
// Tasks are put in |pending_tasks_| first, and then moved to
// |worker_tasks_| when end_batch() is called.
using Tasks = std::vector<std::function<void()>>;
Tasks pending_tasks_;
std::map<uint64_t, Tasks> worker_tasks_;
};
} // namespace mlx::core::cu
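
A sketch of the intended add_task / end_batch / commit flow; the call site is hypothetical, and `stream` is assumed to already have kernels enqueued:

void release_after_kernels(mlx::core::cu::Worker& worker, cudaStream_t stream) {
  worker.add_task([] { /* e.g. free temporary host buffers */ });
  worker.end_batch();     // seal the pending tasks into a batch
  worker.commit(stream);  // run that batch once `stream`'s kernels finish
}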

mlx/backend/metal/conv.cpp
View File

@ -952,7 +952,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
in, in,
wt, wt,
out, out,
padding_, padding_lo_,
kernel_strides_, kernel_strides_,
kernel_dilation_, kernel_dilation_,
input_dilation_, input_dilation_,
@ -967,7 +967,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
in, in,
wt, wt,
out, out,
padding_, padding_lo_,
kernel_strides_, kernel_strides_,
kernel_dilation_, kernel_dilation_,
input_dilation_, input_dilation_,
@ -983,7 +983,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
in, in,
wt, wt,
out, out,
padding_, padding_lo_,
kernel_strides_, kernel_strides_,
kernel_dilation_, kernel_dilation_,
input_dilation_, input_dilation_,

mlx/backend/metal/fft.cpp
View File

@ -632,7 +632,7 @@ void fft_op(
func_consts.push_back(make_int(&rader_m, 3)); func_consts.push_back(make_int(&rader_m, 3));
// The overall number of FFTs we're going to compute for this input // The overall number of FFTs we're going to compute for this input
int size = out.dtype() == float32 ? out.size() : in.size(); size_t size = out.dtype() == float32 ? out.size() : in.size();
if (real && inverse && four_step_params.required) { if (real && inverse && four_step_params.required) {
size = out.size(); size = out.size();
} }
@ -659,8 +659,6 @@ void fft_op(
// We can perform 2 RFFTs at once so the batch size is halved. // We can perform 2 RFFTs at once so the batch size is halved.
batch_size = (batch_size + 2 - 1) / 2; batch_size = (batch_size + 2 - 1) / 2;
} }
int out_buffer_size = out.size();
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto in_type_str = in.dtype() == float32 ? "float" : "float2"; auto in_type_str = in.dtype() == float32 ? "float" : "float2";
auto out_type_str = out.dtype() == float32 ? "float" : "float2"; auto out_type_str = out.dtype() == float32 ? "float" : "float2";

mlx/backend/metal/kernels/fft/readwrite.h
View File

@ -98,7 +98,7 @@ struct ReadWriter {
} }
METAL_FUNC void load() const { METAL_FUNC void load() const {
int batch_idx = elem.x * grid.y * n; size_t batch_idx = size_t(elem.x * grid.y) * n;
short tg_idx = elem.y * grid.z + elem.z; short tg_idx = elem.y * grid.z + elem.z;
short max_index = grid.y * n - 2; short max_index = grid.y * n - 2;
@ -121,7 +121,7 @@ struct ReadWriter {
} }
METAL_FUNC void write() const { METAL_FUNC void write() const {
int batch_idx = elem.x * grid.y * n; size_t batch_idx = size_t(elem.x * grid.y) * n;
short tg_idx = elem.y * grid.z + elem.z; short tg_idx = elem.y * grid.z + elem.z;
short max_index = grid.y * n - 2; short max_index = grid.y * n - 2;
@ -144,7 +144,7 @@ struct ReadWriter {
// Padded IO for Bluestein's algorithm // Padded IO for Bluestein's algorithm
METAL_FUNC void load_padded(int length, const device float2* w_k) const { METAL_FUNC void load_padded(int length, const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length + elem.y * length; size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
int fft_idx = elem.z; int fft_idx = elem.z;
int m = grid.z; int m = grid.z;
@ -161,7 +161,7 @@ struct ReadWriter {
} }
METAL_FUNC void write_padded(int length, const device float2* w_k) const { METAL_FUNC void write_padded(int length, const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length + elem.y * length; size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
int fft_idx = elem.z; int fft_idx = elem.z;
int m = grid.z; int m = grid.z;
float2 inv_factor = {1.0f / n, -1.0f / n}; float2 inv_factor = {1.0f / n, -1.0f / n};
@ -261,7 +261,7 @@ METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
template <> template <>
METAL_FUNC void ReadWriter<float, float2>::load() const { METAL_FUNC void ReadWriter<float, float2>::load() const {
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2; size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2;
threadgroup float2* seq_buf = buf + elem.y * n; threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes // No out of bounds accesses on odd batch sizes
@ -283,7 +283,8 @@ template <>
METAL_FUNC void ReadWriter<float, float2>::write() const { METAL_FUNC void ReadWriter<float, float2>::write() const {
short n_over_2 = (n / 2) + 1; short n_over_2 = (n / 2) + 1;
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; size_t batch_idx =
size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n; threadgroup float2* seq_buf = buf + elem.y * n;
int grid_index = elem.x * grid.y + elem.y; int grid_index = elem.x * grid.y + elem.y;
@ -317,7 +318,7 @@ template <>
METAL_FUNC void ReadWriter<float, float2>::load_padded( METAL_FUNC void ReadWriter<float, float2>::load_padded(
int length, int length,
const device float2* w_k) const { const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
threadgroup float2* seq_buf = buf + elem.y * n; threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes // No out of bounds accesses on odd batch sizes
@ -345,8 +346,8 @@ METAL_FUNC void ReadWriter<float, float2>::write_padded(
int length, int length,
const device float2* w_k) const { const device float2* w_k) const {
int length_over_2 = (length / 2) + 1; int length_over_2 = (length / 2) + 1;
int batch_idx = size_t batch_idx =
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n + length - 1; threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
int grid_index = elem.x * grid.y + elem.y; int grid_index = elem.x * grid.y + elem.y;
@ -397,7 +398,8 @@ METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
template <> template <>
METAL_FUNC void ReadWriter<float2, float>::load() const { METAL_FUNC void ReadWriter<float2, float>::load() const {
short n_over_2 = (n / 2) + 1; short n_over_2 = (n / 2) + 1;
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; size_t batch_idx =
size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n; threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes // No out of bounds accesses on odd batch sizes
@ -458,8 +460,8 @@ METAL_FUNC void ReadWriter<float2, float>::load_padded(
int n_over_2 = (n / 2) + 1; int n_over_2 = (n / 2) + 1;
int length_over_2 = (length / 2) + 1; int length_over_2 = (length / 2) + 1;
int batch_idx = size_t batch_idx =
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n; threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes // No out of bounds accesses on odd batch sizes
@ -503,7 +505,7 @@ template <>
METAL_FUNC void ReadWriter<float2, float>::write_padded( METAL_FUNC void ReadWriter<float2, float>::write_padded(
int length, int length,
const device float2* w_k) const { const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
threadgroup float2* seq_buf = buf + elem.y * n + length - 1; threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
int grid_index = elem.x * grid.y + elem.y; int grid_index = elem.x * grid.y + elem.y;

mlx/ops.cpp
View File

@ -3974,6 +3974,7 @@ array conv_general(
to_stream(s), to_stream(s),
stride, stride,
padding_lo, padding_lo,
padding_hi,
kernel_dilation, kernel_dilation,
input_dilation, input_dilation,
groups, groups,

mlx/primitives.cpp
View File

@ -1055,7 +1055,8 @@ array conv_weight_backward_patches(
const array& wt, const array& wt,
const array& cotan, const array& cotan,
const std::vector<int>& kernel_strides, const std::vector<int>& kernel_strides,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
StreamOrDevice s) { StreamOrDevice s) {
// Resolve Padded input shapes and strides // Resolve Padded input shapes and strides
Shape padding_starts(in.ndim(), 0); Shape padding_starts(in.ndim(), 0);
@ -1064,9 +1065,9 @@ array conv_weight_backward_patches(
// padded shape // padded shape
for (int i = 1; i < in.ndim() - 1; i++) { for (int i = 1; i < in.ndim() - 1; i++) {
in_padded_shape[i] += 2 * padding[i - 1]; in_padded_shape[i] += padding_lo[i - 1] + padding_hi[i - 1];
padding_ends[i] += padding[i - 1]; padding_ends[i] += padding_lo[i - 1];
padding_starts[i] += padding[i - 1]; padding_starts[i] += padding_lo[i - 1];
} }
// padded strides (contiguous) // padded strides (contiguous)
@ -1078,9 +1079,16 @@ array conv_weight_backward_patches(
// Pad input // Pad input
std::vector<int> padded_axes(in.ndim() - 2, 0); std::vector<int> padded_axes(in.ndim() - 2, 0);
std::iota(padded_axes.begin(), padded_axes.end(), 1); std::iota(padded_axes.begin(), padded_axes.end(), 1);
Shape padding_(padding.begin(), padding.end()); Shape padding_lo_(padding_lo.begin(), padding_lo.end());
auto in_padded = pad( Shape padding_hi_(padding_hi.begin(), padding_hi.end());
in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s); auto in_padded =
pad(in,
padded_axes,
padding_lo_,
padding_hi_,
array(0, in.dtype()),
"constant",
s);
// Resolve strided patches // Resolve strided patches
@ -1147,16 +1155,16 @@ std::vector<array> Convolution::vjp(
for (int a : argnums) { for (int a : argnums) {
// Grads for input // Grads for input
if (a == 0) { if (a == 0) {
std::vector<int> padding_lo = padding_; std::vector<int> padding_lo = padding_lo_;
std::vector<int> padding_hi = padding_; std::vector<int> padding_hi = padding_hi_;
for (int i = 0; i < padding_lo.size(); ++i) { for (int i = 0; i < padding_lo.size(); ++i) {
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
padding_lo[i] = wt_size - padding_[i] - 1; padding_lo[i] = wt_size - padding_lo_[i] - 1;
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
padding_hi[i] = in_size - out_size + padding_[i]; padding_hi[i] = in_size - out_size + padding_hi_[i];
} }
// Check for negative padding // Check for negative padding
@ -1226,18 +1234,12 @@ std::vector<array> Convolution::vjp(
if (no_dilation && !flip_ && groups_ == 1) { if (no_dilation && !flip_ && groups_ == 1) {
auto grad = conv_weight_backward_patches( auto grad = conv_weight_backward_patches(
in, wt, cotan, kernel_strides_, padding_, stream()); in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream());
grads.push_back(grad); grads.push_back(grad);
} else { } else {
std::vector<int> padding_lo = padding_; std::vector<int> padding_lo = padding_lo_;
std::vector<int> padding_hi = padding_; std::vector<int> padding_hi = padding_hi_;
for (int i = 0; i < padding_hi.size(); ++i) {
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
}
auto cotan_trans = swapaxes(cotan, 0, -1, stream()); auto cotan_trans = swapaxes(cotan, 0, -1, stream());
auto in_trans = group_transpose(in, -1, 0, -1); auto in_trans = group_transpose(in, -1, 0, -1);
@ -1283,7 +1285,8 @@ std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
in, in,
w, w,
kernel_strides_, kernel_strides_,
padding_, padding_lo_,
padding_hi_,
kernel_dilation_, kernel_dilation_,
input_dilation_, input_dilation_,
groups, groups,
@ -1332,7 +1335,8 @@ std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
bool Convolution::is_equivalent(const Primitive& other) const { bool Convolution::is_equivalent(const Primitive& other) const {
const Convolution& c_other = static_cast<const Convolution&>(other); const Convolution& c_other = static_cast<const Convolution&>(other);
return padding_ == c_other.padding_ && return padding_lo_ == c_other.padding_lo_ &&
padding_hi_ == c_other.padding_hi_ &&
kernel_strides_ == c_other.kernel_strides_ && kernel_strides_ == c_other.kernel_strides_ &&
kernel_dilation_ == c_other.kernel_dilation_ && kernel_dilation_ == c_other.kernel_dilation_ &&
input_dilation_ == c_other.input_dilation_ && input_dilation_ == c_other.input_dilation_ &&

mlx/primitives.h
View File

@ -689,13 +689,15 @@ class Convolution : public UnaryPrimitive {
explicit Convolution( explicit Convolution(
Stream stream, Stream stream,
const std::vector<int>& kernel_strides, const std::vector<int>& kernel_strides,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& kernel_dilation, const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation, const std::vector<int>& input_dilation,
const int groups = 1, const int groups = 1,
const bool flip = false) const bool flip = false)
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
padding_(padding), padding_lo_(padding_lo),
padding_hi_(padding_hi),
kernel_strides_(kernel_strides), kernel_strides_(kernel_strides),
kernel_dilation_(kernel_dilation), kernel_dilation_(kernel_dilation),
input_dilation_(input_dilation), input_dilation_(input_dilation),
@ -716,7 +718,8 @@ class Convolution : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
auto state() const { auto state() const {
return std::make_tuple( return std::make_tuple(
padding_, padding_lo_,
padding_hi_,
kernel_strides_, kernel_strides_,
kernel_dilation_, kernel_dilation_,
input_dilation_, input_dilation_,
@ -725,7 +728,8 @@ class Convolution : public UnaryPrimitive {
} }
private: private:
std::vector<int> padding_; std::vector<int> padding_lo_;
std::vector<int> padding_hi_;
std::vector<int> kernel_strides_; std::vector<int> kernel_strides_;
std::vector<int> kernel_dilation_; std::vector<int> kernel_dilation_;
std::vector<int> input_dilation_; std::vector<int> input_dilation_;
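
With padding split into low and high components, the output length along each spatial axis follows the usual convolution arithmetic. A small helper, sketched here for illustration (not part of the diff):

// Output length of one spatial axis for a strided, dilated convolution with
// asymmetric padding. E.g. in=10, k=2, stride=1, pad_lo=0, pad_hi=1, dil=1
// gives (10 + 0 + 1 - 1 - 1) / 1 + 1 = 10, matching the 2-D test below.
int conv_out_size(int in, int k, int stride, int pad_lo, int pad_hi, int dil) {
  return (in + pad_lo + pad_hi - dil * (k - 1) - 1) / stride + 1;
}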

mlx/version.h
View File

@ -4,7 +4,7 @@
#define MLX_VERSION_MAJOR 0 #define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 25 #define MLX_VERSION_MINOR 25
#define MLX_VERSION_PATCH 1 #define MLX_VERSION_PATCH 2
#define MLX_VERSION_NUMERIC \ #define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

python/tests/test_conv.py
View File

@ -1088,6 +1088,48 @@ class TestConv(mlx_tests.MLXTestCase):
atol=2e-5 if dtype == np.float32 else 5e-4, atol=2e-5 if dtype == np.float32 else 5e-4,
) )
@unittest.skipIf(not has_torch, "requires Torch")
def test_asymmetric_padding(self):
inputs = np.random.normal(size=(2, 8, 8, 8, 3)).astype(np.float32)
kernel = np.random.normal(size=(2, 3, 3, 3, 3)).astype(np.float32)
strides = (2, 2, 2)
pt_out = torch.conv3d(
torch.permute(torch.tensor(inputs), (0, 4, 1, 2, 3)),
torch.permute(torch.tensor(kernel), (0, 4, 1, 2, 3)),
stride=strides,
padding=2,
)
pt_out = torch.permute(pt_out, (0, 2, 3, 4, 1))[:, 1:, 1:, 1:, :].numpy()
mx_out = mx.conv_general(
mx.array(inputs),
mx.array(kernel),
stride=strides,
padding=([0, 0, 0], [1, 1, 1]),
)
self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3))
inputs = np.random.normal(size=(2, 10, 10, 3)).astype(np.float32)
kernel = np.random.normal(size=(2, 2, 2, 3)).astype(np.float32)
pt_out = torch.conv2d(
torch.permute(torch.tensor(inputs), (0, 3, 1, 2)),
torch.permute(torch.tensor(kernel), (0, 3, 1, 2)),
stride=1,
padding=(1, 0),
)
pt_out = torch.permute(pt_out, (0, 2, 3, 1))[:, 1:].numpy()
mx_out = mx.conv_general(
mx.array(inputs),
mx.array(kernel),
stride=1,
padding=([0, 0], [1, 0]),
)
self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

tests/CMakeLists.txt
View File

@ -9,7 +9,7 @@ FetchContent_MakeAvailable(doctest)
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp) add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
if(MLX_BUILD_METAL) if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
set(METAL_TEST_SOURCES gpu_tests.cpp) set(METAL_TEST_SOURCES gpu_tests.cpp)
endif() endif()