Conv grad with groups + bugfix (#1449)

* fix bug in flipped conv with groups; start of grad support for groups

* fix

* fix

* fix + test
This commit is contained in:
Awni Hannun
2024-10-06 07:08:53 -07:00
committed by GitHub
parent fef3c4ec1d
commit e4534dac17
6 changed files with 197 additions and 176 deletions

View File

@@ -72,7 +72,7 @@ void explicit_gemm_conv_ND_gpu(
wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_reshaped};
std::vector<array> copies = {in_unfolded};
return steel_matmul(
s,
d,
@@ -155,22 +155,27 @@ void explicit_gemm_conv_group_ND_gpu(
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_view, wt_transpose};
return steel_matmul_conv_groups(
std::vector<array> copies = {in_unfolded, wt_transpose};
return steel_matmul_regular(
s,
d,
/*a = */ in_unfolded,
/*b = */ wt_transpose,
/*c = */ out,
/*M = */ implicit_M,
/*N = */ implicit_N,
/*K = */ implicit_K,
/*a_cols = */ implicit_K * groups,
/*b_cols = */ implicit_K,
/*out_cols = */ implicit_N * groups,
/*a_transposed = */ false,
/*b_transposed = */ true,
/* groups = */ groups,
/* a = */ in_unfolded,
/* b = */ wt_transpose,
/* c = */ out,
/* M = */ implicit_M,
/* N = */ implicit_N,
/* K = */ implicit_K,
/* batch_size_out = */ groups,
/* a_cols = */ implicit_K * groups,
/* b_cols = */ implicit_K,
/* out_cols = */ implicit_N * groups,
/* a_transposed = */ false,
/* b_transposed = */ true,
/* batch_shape = */ {1},
/* batch_strides = */ {0},
/* A_batch_strides = */ size_t(implicit_K),
/* B_batch_strides = */ size_t(implicit_N) * implicit_K,
/* matrix_stride_out = */ size_t(implicit_N),
/*copies = */ copies);
}

View File

@@ -113,6 +113,7 @@ template <typename T, int N>
for (int i = N - 1; i >= 0; --i) {
int os_ = (oS % params->oS[i]);
int ws_ = (wS % params->wS[i]);
out += ws_ * kernel_stride;
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
@@ -126,7 +127,6 @@ template <typename T, int N>
oS /= params->oS[i];
wS /= params->wS[i];
out += ws_ * kernel_stride;
kernel_stride *= params->wS[i];
}

View File

@@ -88,7 +88,7 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
// Steel matmul fallback
///////////////////////////////////////////////////////////////////////////////
void steel_matmul_conv_groups(
void steel_matmul_regular(
const Stream& s,
metal::Device& d,
const array& a,
@@ -97,23 +97,25 @@ void steel_matmul_conv_groups(
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
int ldd,
bool transpose_a,
bool transpose_b,
int groups,
std::vector<int> batch_shape,
std::vector<size_t> batch_strides,
size_t A_batch_stride,
size_t B_batch_stride,
size_t matrix_stride_out,
std::vector<array>& copies) {
using namespace mlx::steel;
/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int wm = 2, wn = 2;
if ((size_t)M * N >= 1ul << 20) {
if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
@@ -133,7 +135,7 @@ void steel_matmul_conv_groups(
std::string base_name = kname.str();
const bool has_batch = false;
const bool has_batch = (batch_shape.size() > 1);
const bool use_out_source = false;
const bool do_axpby = false;
const bool align_M = (M % bm) == 0;
@@ -197,12 +199,12 @@ void steel_matmul_conv_groups(
/* const int ldd = */ ldd,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const size_t batch_stride_a = */ size_t(K),
/* const size_t batch_stride_b = */ size_t(N) * K,
/* const size_t batch_stride_d = */ size_t(N),
/* const size_t batch_stride_a = */ A_batch_stride,
/* const size_t batch_stride_b = */ B_batch_stride,
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ 1};
/* const int batch_ndim = */ int(batch_shape.size())};
// Prepare launch grid params
int tile = 1 << swizzle_log;
@@ -210,15 +212,13 @@ void steel_matmul_conv_groups(
tn = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, groups);
std::vector<int> batch_shape = {1};
std::vector<size_t> batch_strides = {0};
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
set_vector_bytes(compute_encoder, batch_shape, 6);
@@ -393,133 +393,31 @@ void steel_matmul(
/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int wm = 2, wn = 2;
if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
// Prepare kernel name
std::ostringstream kname;
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn;
std::string base_name = kname.str();
const bool has_batch = (batch_shape.size() > 1);
const bool use_out_source = false;
const bool do_axpby = false;
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
const bool do_gather = false;
metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10},
{&use_out_source, MTL::DataType::DataTypeBool, 100},
{&do_axpby, MTL::DataType::DataTypeBool, 110},
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
};
// clang-format off
kname << "_has_batch_" << (has_batch ? 't' : 'n')
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_fused_kernel(
d,
base_name,
hash_name,
func_consts,
out,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn);
compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
// Prepare steel matmul params
GEMMParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ ldb,
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const size_t batch_stride_a = */ A_batch_stride.back(),
/* const size_t batch_stride_b = */ B_batch_stride.back(),
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
// Prepare launch grid params
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
std::vector<size_t> batch_strides = A_batch_stride;
batch_strides.insert(
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
steel_matmul_regular(
s,
d,
a,
b,
out,
M,
N,
K,
batch_size_out,
lda,
ldb,
N,
transpose_a,
transpose_b,
std::move(batch_shape),
std::move(batch_strides),
A_batch_stride.back(),
B_batch_stride.back(),
matrix_stride_out,
copies);
}
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@@ -4,7 +4,7 @@
namespace mlx::core {
void steel_matmul_conv_groups(
void steel_matmul_regular(
const Stream& s,
metal::Device& d,
const array& a,
@@ -13,12 +13,17 @@ void steel_matmul_conv_groups(
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
int ldd,
bool transpose_a,
bool transpose_b,
int groups,
std::vector<int> batch_shape,
std::vector<size_t> batch_strides,
size_t A_batch_stride,
size_t B_batch_stride,
size_t matrix_stride_out,
std::vector<array>& copies);
void steel_matmul(