Add load_safe to the general conv loaders (#2258)

This commit is contained in:
Angelos Katharopoulos 2025-06-10 20:58:16 -07:00 committed by GitHub
parent 095163b8d1
commit 8590c0941e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 302 additions and 22 deletions

View File

@ -0,0 +1,107 @@
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
dtype = "float32"
shapes = (
(4, 32, 32, 21, 3, 3, 128),
(4, 32, 32, 21, 3, 3, 37),
(4, 32, 32, 370, 3, 3, 370),
(4, 32, 32, 370, 7, 7, 128),
(2, 320, 640, 21, 7, 7, 21),
)
for N, H, W, C, kh, kw, O in shapes:
time_mlx, time_torch = bench_shape(
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@ -391,6 +391,7 @@ void implicit_gemm_conv_2D_general_gpu(
// Get channel iteration info
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
int gemm_k_iters = channel_k_iters;
bool align_C = conv_params.C % bk == 0;
// Fix host side helper params
int sign = (conv_params.flip ? -1 : 1);
@ -419,14 +420,33 @@ void implicit_gemm_conv_2D_general_gpu(
/* const int swizzle_log = */ swizzle_log};
// Determine kernel
std::ostringstream kname;
kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
std::string kname;
kname.reserve(64);
concatenate(
kname,
"implicit_gemm_conv_2d_general_",
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
std::string hash_name;
hash_name.reserve(64);
concatenate(hash_name, kname, "_alC_", align_C);
metal::MTLFCList func_consts = {
{&align_C, MTL::DataType::DataTypeBool, 200},
};
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel =
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
auto kernel = get_steel_conv_general_kernel(
d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn);
compute_encoder.set_compute_pipeline_state(kernel);
// Deduce grid launch dimensions
@ -728,8 +748,10 @@ void dispatch_conv_2D_gpu(
// Direct to winograd conv
bool inp_large =
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096;
bool channels_large = (conv_params.C + conv_params.O) >= 256;
bool out_large =
(conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256;
if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
@ -743,7 +765,7 @@ void dispatch_conv_2D_gpu(
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
else if ((conv_params.C % 16 == 0 && conv_params.O % 16 == 0) || out_large) {
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
}

View File

@ -727,6 +727,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
int bm,
int bn,
@ -749,7 +751,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
wn);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_fft_kernel(

View File

@ -205,6 +205,8 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
int bm,
int bn,

View File

@ -2,6 +2,8 @@
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
constant bool align_C [[function_constant(200)]];
template <
typename T,
int BM,
@ -118,6 +120,7 @@ implicit_gemm_conv_2d_general(
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
if (align_C) {
int gemm_k_iterations =
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
@ -136,6 +139,40 @@ implicit_gemm_conv_2d_general(
loader_a.next();
loader_b.next();
}
}
else {
for (int k = 1; k < gemm_params->gemm_k_iterations; k++) {
for (int j = 0; j < base_wh_size * base_ww_size; j++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
const short remaining_k = params->C % BK;
for (int j = 0; j < base_wh_size * base_ww_size; j++) {
// Load elements into threadgroup
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_a.load_safe(remaining_k);
loader_b.load_safe(remaining_k);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
threadgroup_barrier(mem_flags::mem_none);

View File

@ -137,6 +137,52 @@ struct Conv2DInputBlockLoaderGeneral {
}
}
METAL_FUNC void load_safe(const short remaining_k) const {
STEEL_PRAGMA_UNROLL
for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
// Find bounds
int n = read_n[i];
int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w;
int ih_dil = read_ih[i] + h_flip * params->kdil[0];
int iw_dil = read_iw[i] + w_flip * params->kdil[1];
int ih = ih_dil / params->idil[0];
int iw = iw_dil / params->idil[1];
size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];
// Read from input if in bounds
if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
(iw_dil >= 0 && iw < params->iS[1])) {
if (bj + vec_size <= remaining_k) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = (src[i])[offset + j];
}
} else {
for (short j = 0; j < vec_size; ++j) {
if (bj + j < remaining_k) {
dst[is * dst_ld + j] = (src[i])[offset + j];
} else {
dst[is * dst_ld + j] = T(0);
}
}
}
}
// Zero pad otherwise
else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
weight_w += jump_params->f_wgt_jump_w;
@ -262,6 +308,55 @@ struct Conv2DWeightBlockLoaderGeneral {
}
}
METAL_FUNC void load_safe(const short remaining_k) const {
const device T* curr_src = src + weight_h * params->wt_strides[1] +
weight_w * params->wt_strides[2];
if ((start_row + BN <= params->O)) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BN; i += TROWS) {
if (bj + vec_size <= remaining_k) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
} else {
for (short j = 0; j < vec_size; j++) {
if (bj + j < remaining_k) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
} else {
dst[i * dst_ld + j] = T(0);
}
}
}
}
} else {
for (short i = 0; i < BN; i += TROWS) {
if ((start_row + i) < params->O) {
if (bj + vec_size <= remaining_k) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
} else {
for (short j = 0; j < vec_size; j++) {
if (bj + j < remaining_k) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
} else {
dst[i * dst_ld + j] = T(0);
}
}
}
} else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
weight_w += jump_params->f_wgt_jump_w;

View File

@ -244,13 +244,15 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
int,
int,
int,
int,
int) {
return d.get_kernel(kernel_name);
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_fft_kernel(

View File

@ -1173,6 +1173,19 @@ class TestConv(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))
def test_conv2d_unaligned_channels(self):
x = mx.random.uniform(shape=(2, 16, 16, 21))
w = mx.random.uniform(shape=(32, 3, 3, 21))
y = mx.conv2d(x, w, stream=mx.cpu)
y_hat = mx.conv2d(x, w)
self.assertTrue(mx.allclose(y, y_hat))
x = mx.random.uniform(shape=(2, 16, 16, 21))
w = mx.random.uniform(shape=(21, 3, 3, 21))
y = mx.conv2d(x, w, stream=mx.cpu)
y_hat = mx.conv2d(x, w)
self.assertTrue(mx.allclose(y, y_hat))
if __name__ == "__main__":
unittest.main()