mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Add groups to Conv1d (#948)
* Add conv1d grouped convs on CPU * Add GPU support * Parallelize inside metal kernel * clenaup * Update mlx/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * New unfold kernel + remove unused code * Remove copy and refactor * Update vjp and reuse steel gemm * Fixed groups on cpu * Fix metal validation --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
86f495985b
commit
c4a471c99d
123
benchmarks/python/conv1d_bench.py
Normal file
123
benchmarks/python/conv1d_bench.py
Normal file
@ -0,0 +1,123 @@
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||
device_name = device_name.decode("utf-8").strip("\n")
|
||||
|
||||
N_warmup = 10
|
||||
N_iter_bench = 100
|
||||
N_iter_func = 5
|
||||
|
||||
|
||||
def bench(f, a, b):
|
||||
for i in range(N_warmup):
|
||||
f(a, b)
|
||||
torch.mps.synchronize()
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(a, b)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_1D(strides=1, padding=0, groups=1):
|
||||
def mx_conv_1D(a, b):
|
||||
ys = []
|
||||
for _ in range(N_iter_func):
|
||||
y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
|
||||
return mx_conv_1D
|
||||
|
||||
|
||||
def make_pt_conv_1D(strides=1, padding=0, groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_1D(a, b):
|
||||
ys = []
|
||||
for _ in range(N_iter_func):
|
||||
y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
torch.mps.synchronize()
|
||||
return ys
|
||||
|
||||
return pt_conv_1D
|
||||
|
||||
|
||||
def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
|
||||
scale = 1.0 / math.sqrt(wH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
|
||||
b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")
|
||||
|
||||
torch.mps.synchronize()
|
||||
|
||||
f_mx = make_mx_conv_1D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_1D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||
out_pt = torch.conv1d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
|
||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
||||
|
||||
dtypes = ("float32",)
|
||||
shapes = (
|
||||
(4, 32, 32, 5, 32, 1, 2, 1),
|
||||
(4, 32, 32, 5, 32, 1, 2, 2),
|
||||
(4, 32, 32, 5, 32, 1, 2, 4),
|
||||
(4, 32, 32, 5, 32, 1, 2, 8),
|
||||
(4, 32, 32, 5, 32, 1, 2, 8),
|
||||
(4, 32, 32, 5, 32, 1, 2, 16),
|
||||
(4, 32, 32, 5, 32, 1, 2, 32),
|
||||
(4, 32, 256, 5, 512, 1, 2, 2),
|
||||
(4, 32, 256, 5, 512, 1, 2, 128),
|
||||
(4, 32, 256, 5, 512, 1, 2, 256),
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
|
||||
for N, iH, C, wH, O, strides, padding, groups in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, iH, C, wH, O, strides, padding, np_dtype, groups
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
@ -114,6 +114,15 @@ class array {
|
||||
return array_desc_->strides;
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the stride of the corresponding dimension.
|
||||
*
|
||||
* This function supports negative indexing and provides
|
||||
* bounds checking. */
|
||||
size_t strides(int dim) const {
|
||||
return strides().at(dim < 0 ? dim + ndim() : dim);
|
||||
};
|
||||
|
||||
/** Get the arrays data type. */
|
||||
Dtype dtype() const {
|
||||
return array_desc_->dtype;
|
||||
|
@ -38,11 +38,15 @@ void slow_conv_1D(
|
||||
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
|
||||
const int C = in.shape(2); // Input channels
|
||||
const int oH = out.shape(1); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(2); // In channels
|
||||
const int wH = wt.shape(1); // Weight spatial dim
|
||||
|
||||
const int groups = C / wt.shape(2);
|
||||
const int C_per_group = wt.shape(2);
|
||||
const int O_per_group = O / groups;
|
||||
|
||||
const size_t in_stride_N = in.strides()[0];
|
||||
const size_t in_stride_H = in.strides()[1];
|
||||
const size_t in_stride_C = in.strides()[2];
|
||||
@ -57,7 +61,8 @@ void slow_conv_1D(
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int oh = 0; oh < oH; ++oh) {
|
||||
for (int o = 0; o < O; ++o) {
|
||||
for (int g = 0; g < groups; ++g) {
|
||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
|
||||
float r = 0.;
|
||||
|
||||
@ -70,10 +75,10 @@ void slow_conv_1D(
|
||||
auto ih_div = std::div(ih, in_dilation[0]);
|
||||
|
||||
if (ih >= 0 && ih < iH && ih_div.rem == 0) {
|
||||
for (int c = 0; c < C; ++c) {
|
||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
|
||||
r += static_cast<float>(
|
||||
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
|
||||
static_cast<float>(wt_ptr[c * wt_stride_C]);
|
||||
static_cast<float>(wt_ptr[(c % C_per_group) * wt_stride_C]);
|
||||
} // c
|
||||
|
||||
} // ih check
|
||||
@ -81,11 +86,11 @@ void slow_conv_1D(
|
||||
|
||||
out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
|
||||
} // o
|
||||
} // g
|
||||
} // oh
|
||||
|
||||
in_ptr += in_stride_N;
|
||||
out_ptr += out_stride_N;
|
||||
|
||||
} // n
|
||||
}
|
||||
|
||||
@ -366,11 +371,15 @@ void explicit_gemm_conv_1D_cpu(
|
||||
const std::vector<int>& wt_dilation) {
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iH = in.shape(1); // Input spatial dim
|
||||
const int C = in.shape(2); // Input channels
|
||||
const int oH = out.shape(1); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(2); // In channels
|
||||
const int wH = wt.shape(1); // Weight spatial dim
|
||||
|
||||
const int groups = C / wt.shape(2);
|
||||
const int C_per_group = wt.shape(2);
|
||||
const int O_per_group = O / groups;
|
||||
|
||||
auto conv_dtype = float32;
|
||||
|
||||
// Pad input
|
||||
@ -402,6 +411,11 @@ void explicit_gemm_conv_1D_cpu(
|
||||
in_padded.strides()[1],
|
||||
in_padded.strides()[2]};
|
||||
auto flags = in_padded.flags();
|
||||
if (groups > 1) {
|
||||
// Transpose the last two dimensions for grouped convolutions
|
||||
std::swap(strided_shape[2], strided_shape[3]);
|
||||
std::swap(strided_strides[2], strided_strides[3]);
|
||||
}
|
||||
|
||||
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
|
||||
in_strided_view.copy_shared_buffer(
|
||||
@ -416,7 +430,19 @@ void explicit_gemm_conv_1D_cpu(
|
||||
auto gemm_wt = wt;
|
||||
auto gemm_out = out;
|
||||
|
||||
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
|
||||
if (groups > 1) {
|
||||
// Transpose the last two dimensions for grouped convolutions
|
||||
array wt_transpose(
|
||||
{wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {});
|
||||
wt_transpose.copy_shared_buffer(
|
||||
wt,
|
||||
{wt.strides(0), wt.strides(2), wt.strides(1)},
|
||||
wt.flags(),
|
||||
wt.size(),
|
||||
0);
|
||||
gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
|
||||
copy(wt_transpose, gemm_wt, CopyType::General);
|
||||
} else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
|
||||
auto ctype =
|
||||
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
||||
@ -428,21 +454,22 @@ void explicit_gemm_conv_1D_cpu(
|
||||
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||
}
|
||||
|
||||
for (int g = 0; g < groups; ++g) {
|
||||
// Perform gemm
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
CblasNoTrans, // no trans A
|
||||
CblasTrans, // transB
|
||||
strided_reshape[0], // M
|
||||
O, // N
|
||||
strided_reshape[1], // K
|
||||
O_per_group, // N
|
||||
C_per_group * wH, // K
|
||||
1.0f, // alpha
|
||||
in_strided.data<float>(),
|
||||
strided_reshape[1], // lda
|
||||
gemm_wt.data<float>(),
|
||||
strided_reshape[1], // ldb
|
||||
in_strided.data<float>() + g * C_per_group * wH, // A
|
||||
wH * C, // lda
|
||||
gemm_wt.data<float>() + g * O_per_group * C_per_group * wH, // B
|
||||
wH * C_per_group, // ldb
|
||||
0.0f, // beta
|
||||
gemm_out.data<float>(),
|
||||
gemm_out.data<float>() + g * O_per_group, // C
|
||||
O // ldc
|
||||
);
|
||||
|
||||
@ -450,6 +477,7 @@ void explicit_gemm_conv_1D_cpu(
|
||||
if (out.dtype() != float32) {
|
||||
copy(gemm_out, out, CopyType::Vector);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void explicit_gemm_conv_2D_cpu(
|
||||
|
@ -89,6 +89,90 @@ void explicit_gemm_conv_ND_gpu(
|
||||
/*copies = */ copies);
|
||||
}
|
||||
|
||||
template <int N>
|
||||
void explicit_gemm_conv_group_ND_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<N>& conv_params) {
|
||||
const int groups = conv_params.groups;
|
||||
const int C_per_group = conv_params.C / conv_params.groups;
|
||||
const int O_per_group = conv_params.O / conv_params.groups;
|
||||
// Get gemm shapes
|
||||
const int implicit_M = out.size() / conv_params.O;
|
||||
const int implicit_K = wt.size() / conv_params.O;
|
||||
const int implicit_N = O_per_group;
|
||||
|
||||
int kernel_size = 1;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
kernel_size *= conv_params.wS[i];
|
||||
}
|
||||
|
||||
// Prepare unfolding array
|
||||
std::vector<int> unfolded_shape{implicit_M, implicit_K * groups};
|
||||
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
|
||||
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
|
||||
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
|
||||
<< N;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(in_unfolded, 1);
|
||||
|
||||
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
|
||||
|
||||
// Launch unfolding kernel
|
||||
int tgp_x = std::min(conv_params.C, 64);
|
||||
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
|
||||
int tgp_y = 256 / tgp_x;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
|
||||
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Transpose kernel weights so that we can slice them by contiguous chunks
|
||||
// of channel groups.
|
||||
array wt_view(
|
||||
{wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
|
||||
wt_view.copy_shared_buffer(
|
||||
wt,
|
||||
{wt.strides(0), 1, static_cast<size_t>(C_per_group)},
|
||||
wt.flags(),
|
||||
wt.size());
|
||||
|
||||
// Materialize
|
||||
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
|
||||
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
|
||||
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_unfolded, wt_view, wt_transpose};
|
||||
return steel_matmul_conv_groups(
|
||||
s,
|
||||
d,
|
||||
/*a = */ in_unfolded,
|
||||
/*b = */ wt_transpose,
|
||||
/*c = */ out,
|
||||
/*M = */ implicit_M,
|
||||
/*N = */ implicit_N,
|
||||
/*K = */ implicit_K,
|
||||
/*a_cols = */ implicit_K * groups,
|
||||
/*b_cols = */ implicit_K,
|
||||
/*out_cols = */ implicit_N * groups,
|
||||
/*a_transposed = */ false,
|
||||
/*b_transposed = */ true,
|
||||
/* groups = */ groups,
|
||||
/*copies = */ copies);
|
||||
}
|
||||
|
||||
void conv_1D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
@ -99,6 +183,7 @@ void conv_1D_gpu(
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
int groups,
|
||||
bool flip) {
|
||||
// Make conv params
|
||||
MLXConvParams<1> conv_params{
|
||||
@ -118,11 +203,15 @@ void conv_1D_gpu(
|
||||
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
|
||||
/* const size_t out_strides[NDIM + 2] = */
|
||||
{out.strides()[0], out.strides()[1], out.strides()[2]},
|
||||
/* const int groups = */ 1,
|
||||
/* const int groups = */ groups,
|
||||
/* const bool flip = */ flip};
|
||||
|
||||
// Direct to explicit gemm conv
|
||||
if (groups > 1) {
|
||||
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
} else {
|
||||
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
}
|
||||
|
||||
void slow_conv_2D_gpu(
|
||||
@ -721,6 +810,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
groups_,
|
||||
flip_);
|
||||
}
|
||||
// Throw error
|
||||
|
@ -33,7 +33,7 @@ template <typename T, int N>
|
||||
// Set out
|
||||
out += gid.z * filter_size + gid.y * (params->C);
|
||||
|
||||
// Corrdinates in input
|
||||
// Coordinates in input
|
||||
int is[N] = {0};
|
||||
|
||||
// gid.z: N oS (Batch and row in unfolded output)
|
||||
@ -75,12 +75,81 @@ template <typename T, int N>
|
||||
} else {
|
||||
out[gid.x] = T(0);
|
||||
}
|
||||
}
|
||||
|
||||
// This kernel unfolds the input array of size (N, *spatial_dims, C)
|
||||
// into an array of size (N x *spatial_dims, C x *kernel_dims).
|
||||
template <typename T, int N>
|
||||
[[kernel]] void naive_unfold_transpose_Nd(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
const constant MLXConvParams<N>* params [[buffer(2)]],
|
||||
uint3 gid [[thread_position_in_grid]]) {
|
||||
|
||||
int filter_size = params->C;
|
||||
for(short i = 0; i < N; i++) filter_size *= params->wS[i];
|
||||
|
||||
int out_pixels = 1;
|
||||
for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
|
||||
|
||||
// Set out
|
||||
out += gid.z * filter_size + gid.x * (filter_size / params->C);
|
||||
|
||||
// Coordinates in input
|
||||
int is[N] = {0};
|
||||
|
||||
// gid.z: N oS (Batch and row in unfolded output)
|
||||
// gid.y: wS (Filter location to unfold input)
|
||||
// gid.x: C (channel)
|
||||
|
||||
int n = (gid.z) / out_pixels;
|
||||
int oS = (gid.z) % out_pixels;
|
||||
int wS = gid.y;
|
||||
|
||||
bool valid = n < params->N;
|
||||
|
||||
// Unroll dimensions
|
||||
for (int i = N - 1; i >= 0; --i) {
|
||||
int os_ = (oS % params->oS[i]);
|
||||
int ws_ = (wS % params->wS[i]);
|
||||
|
||||
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
|
||||
|
||||
int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
|
||||
int is_max = 1 + params->idil[i] * (params->iS[i] - 1);
|
||||
|
||||
valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0);
|
||||
|
||||
is[i] = is_ / params->idil[i];
|
||||
|
||||
oS /= params->oS[i];
|
||||
wS /= params->wS[i];
|
||||
|
||||
out += ws_ * params->str[i];
|
||||
}
|
||||
|
||||
if(valid) {
|
||||
size_t in_offset = n * params->in_strides[0];
|
||||
|
||||
for(int i = 0; i < N; ++i) {
|
||||
in_offset += is[i] * params->in_strides[i + 1];
|
||||
}
|
||||
|
||||
out[0] = in[in_offset + gid.x];
|
||||
} else {
|
||||
out[0] = T(0);
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_naive_unfold_nd(name, itype, n) \
|
||||
template [[host_name("naive_unfold_nd_" #name "_" #n)]] \
|
||||
[[kernel]] void naive_unfold_Nd( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device itype* out [[buffer(1)]], \
|
||||
const constant MLXConvParams<n>* params [[buffer(2)]], \
|
||||
uint3 gid [[thread_position_in_grid]]); \
|
||||
template [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] \
|
||||
[[kernel]] void naive_unfold_transpose_Nd( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
device itype* out [[buffer(1)]], \
|
||||
const constant MLXConvParams<n>* params [[buffer(2)]], \
|
||||
|
@ -260,6 +260,110 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
|
||||
// Steel matmul fallback
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void steel_matmul_conv_groups(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldd,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int groups,
|
||||
std::vector<array>& copies) {
|
||||
using namespace mlx::steel;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Regular kernel dispatch
|
||||
|
||||
// Determine dispatch kernel
|
||||
int bm = 32, bn = 32, bk = 16;
|
||||
int wm = 2, wn = 2;
|
||||
|
||||
if ((size_t)M * N >= 1ul << 20) {
|
||||
if (!transpose_a && transpose_b) {
|
||||
bm = 64;
|
||||
bn = (out.dtype() == float32) ? 64 : 32;
|
||||
bk = (out.dtype() == float32) ? 16 : 32;
|
||||
} else {
|
||||
bm = 64;
|
||||
bn = 64;
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
kname << "steel_gemm_" << (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Use problem size to determine threadblock swizzle
|
||||
int tn = (N + bn - 1) / bn;
|
||||
int tm = (M + bm - 1) / bm;
|
||||
|
||||
// TODO: Explore device-based tuning for swizzle
|
||||
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
|
||||
|
||||
// Prepare steel matmul params
|
||||
GEMMParams params{
|
||||
/* const int M = */ M,
|
||||
/* const int N = */ N,
|
||||
/* const int K = */ K,
|
||||
/* const int lda = */ lda,
|
||||
/* const int ldb = */ ldb,
|
||||
/* const int ldd = */ ldd,
|
||||
/* const int tiles_n = */ tn,
|
||||
/* const int tiles_m = */ tm,
|
||||
/* const int batch_stride_a = */ K,
|
||||
/* const int batch_stride_b = */ N * K,
|
||||
/* const int batch_stride_d = */ N,
|
||||
/* const int swizzle_log = */ swizzle_log,
|
||||
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
||||
/* const int batch_ndim = */ 1};
|
||||
|
||||
// Prepare launch grid params
|
||||
int tile = 1 << swizzle_log;
|
||||
tm = (tm + tile - 1) / tile;
|
||||
tn = tn * tile;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(tn, tm, groups);
|
||||
|
||||
std::vector<int> batch_shape = {1};
|
||||
std::vector<size_t> batch_strides = {0};
|
||||
|
||||
// Launch kernel
|
||||
compute_encoder.set_input_array(a, 0);
|
||||
compute_encoder.set_input_array(b, 1);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 4);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
|
||||
compute_encoder->setBytes(
|
||||
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
// Clear copies
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
}
|
||||
|
||||
void steel_matmul(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
|
@ -12,6 +12,23 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void steel_matmul_conv_groups(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldd,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int groups,
|
||||
std::vector<array>& copies);
|
||||
|
||||
void steel_matmul(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
|
38
mlx/ops.cpp
38
mlx/ops.cpp
@ -320,7 +320,7 @@ array reshape(
|
||||
"[reshape] Cannot infer the shape of an empty array");
|
||||
}
|
||||
|
||||
// Check the the reshaping is valid
|
||||
// Check that the reshaping is valid
|
||||
if (a.size() != size) {
|
||||
std::ostringstream msg;
|
||||
msg << "[reshape] Cannot reshape array of size " << a.size()
|
||||
@ -2947,7 +2947,8 @@ inline std::vector<int> conv_out_shape(
|
||||
return out_shape;
|
||||
}
|
||||
|
||||
inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
|
||||
inline void
|
||||
run_conv_checks(const array& in, const array& wt, int n_dim, int groups) {
|
||||
if (!issubdtype(in.dtype(), floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Invalid input array with type " << in.dtype() << "."
|
||||
@ -2972,11 +2973,35 @@ inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (in.shape(n_dim + 1) != wt.shape(n_dim + 1)) {
|
||||
if (in.shape(n_dim + 1) % groups != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] The input channels must be divisible by the number"
|
||||
<< " of groups. Got input with shape " << in.shape() << " and "
|
||||
<< groups << " groups.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (groups > 1 && wt.shape(0) % groups != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] If groups > 1, the output channels must be divisible by the number"
|
||||
<< " of groups. Got " << wt.shape(0) << " output channels and "
|
||||
<< groups << " groups.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (in.shape(n_dim + 1) != (groups * wt.shape(n_dim + 1))) {
|
||||
std::ostringstream msg;
|
||||
if (groups == 1) {
|
||||
msg << "[conv] Expect the input channels in the input"
|
||||
<< " and weight array to match but got shapes -"
|
||||
<< " input: " << in.shape() << " and weight: " << wt.shape();
|
||||
|
||||
} else {
|
||||
msg << "Given groups=" << groups << " and weights of shape " << wt.shape()
|
||||
<< ", expected to have " << (groups * wt.shape(n_dim + 1))
|
||||
<< " input channels but got " << in.shape(n_dim + 1)
|
||||
<< " input channels instead.";
|
||||
}
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
@ -3039,8 +3064,9 @@ array conv_general(
|
||||
bool flip /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
// Run checks
|
||||
if (groups != 1) {
|
||||
throw std::invalid_argument("[conv] Cannot handle groups != 1 yet");
|
||||
if (groups != 1 && in.ndim() != 3) {
|
||||
throw std::invalid_argument(
|
||||
"[conv] Can only handle groups != 1 in 1D convolutions.");
|
||||
}
|
||||
|
||||
int spatial_dims = in.ndim() - 2;
|
||||
@ -3052,7 +3078,7 @@ array conv_general(
|
||||
}
|
||||
|
||||
// Run checks
|
||||
run_conv_checks(in, wt, spatial_dims);
|
||||
run_conv_checks(in, wt, spatial_dims, groups);
|
||||
|
||||
// Type promotion
|
||||
auto out_type = promote_types(in.dtype(), wt.dtype());
|
||||
|
@ -831,6 +831,11 @@ std::vector<array> Convolution::vjp(
|
||||
assert(primals.size() == 2);
|
||||
std::vector<array> grads;
|
||||
|
||||
if (groups_ != 1) {
|
||||
throw std::invalid_argument(
|
||||
"[Convolution] Backward pass not implemented for groups > 1.");
|
||||
}
|
||||
|
||||
// Collect info
|
||||
auto& in = primals[0];
|
||||
auto& wt = primals[1];
|
||||
|
@ -77,7 +77,9 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0, 1.0 / C, (O, kH, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
in_pt, wt_pt = map(
|
||||
@ -119,6 +121,12 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
):
|
||||
run_conv1D(N, C, O, iH, kH, stride, padding, dtype=dtype)
|
||||
|
||||
# Groups tests
|
||||
N, C, O = (4, 32, 64)
|
||||
iH, kH, stride, padding = (31, 5, 1, 2)
|
||||
for group in (1, 2, 4, 8, 16, 32):
|
||||
run_conv1D(N, C, O, iH, kH, stride=1, padding=1, groups=group, dtype=dtype)
|
||||
|
||||
# Strided inputs tests
|
||||
for tpose_in, tpose_wt in (
|
||||
((0, 2, 1), (0, 1, 2)),
|
||||
|
@ -3228,3 +3228,102 @@ TEST_CASE("test meshgrid") {
|
||||
CHECK(array_equal(out[0], expected_zero).item<bool>());
|
||||
CHECK(array_equal(out[1], expected_one).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test conv1d") {
|
||||
auto in = astype(
|
||||
array(
|
||||
{0.5488135,
|
||||
0.71518937,
|
||||
0.60276338,
|
||||
0.54488318,
|
||||
0.4236548,
|
||||
0.64589411},
|
||||
{1, 3, 2}),
|
||||
float16);
|
||||
|
||||
int kernel = 3;
|
||||
int stride = 1;
|
||||
int padding = 1;
|
||||
|
||||
{
|
||||
int groups = 1;
|
||||
auto wt = astype(
|
||||
array(
|
||||
{
|
||||
|
||||
0.43758721, 0.891773, 0.96366276, 0.38344152,
|
||||
0.79172504, 0.52889492,
|
||||
|
||||
0.56804456, 0.92559664, 0.07103606, 0.0871293,
|
||||
0.0202184, 0.83261985,
|
||||
|
||||
0.77815675, 0.87001215, 0.97861834, 0.79915856,
|
||||
0.46147936, 0.78052918,
|
||||
|
||||
0.11827443, 0.63992102, 0.14335329, 0.94466892,
|
||||
0.52184832, 0.41466194
|
||||
|
||||
},
|
||||
{4, 3, 2}),
|
||||
float16);
|
||||
|
||||
auto expected = array(
|
||||
{1.5685,
|
||||
0.5672,
|
||||
1.8121,
|
||||
1.2948,
|
||||
2.3448,
|
||||
1.6104,
|
||||
2.7743,
|
||||
1.6126,
|
||||
1.4056,
|
||||
0.9331,
|
||||
1.8739,
|
||||
1.0909},
|
||||
{1, 3, 4});
|
||||
|
||||
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
|
||||
CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
int groups = 2;
|
||||
auto wt = array(
|
||||
{0.43758721,
|
||||
0.891773,
|
||||
0.96366276,
|
||||
|
||||
0.38344152,
|
||||
0.79172504,
|
||||
0.52889492,
|
||||
|
||||
0.56804456,
|
||||
0.92559664,
|
||||
0.07103606,
|
||||
|
||||
0.0871293,
|
||||
0.0202184,
|
||||
0.83261985
|
||||
|
||||
},
|
||||
{4, 3, 1});
|
||||
|
||||
auto expected = array(
|
||||
{1.0703,
|
||||
0.7533,
|
||||
0.7007,
|
||||
0.4681,
|
||||
1.1859,
|
||||
0.9117,
|
||||
0.9565,
|
||||
0.6111,
|
||||
0.6416,
|
||||
0.5665,
|
||||
0.9074,
|
||||
0.0605},
|
||||
{1, 3, 4});
|
||||
|
||||
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
|
||||
CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user