Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-28 21:21:21 +08:00)
Conv grad with groups + bugfix (#1449)
* fix bug in flipped conv with groups, start of grad for groups
* fix
* fix
* fix + test
Parent: fef3c4ec1d
Commit: e4534dac17
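The change lands grouped-convolution support in the VJP, so gradients through conv1d / conv_general with groups > 1 no longer throw. As a quick reference, here is a minimal sketch of the gradient path exercised by the new test_conv_groups_grad test below (shapes copied from that test); it assumes an MLX build that includes this commit.

import mlx.core as mx

def fn(x, w):
    # Grouped 1D convolution; groups inferred from the channel split.
    num_groups = x.shape[-1] // w.shape[-1]
    return mx.conv1d(x, w, groups=num_groups)

x = mx.random.normal(shape=(1, 5, 4))   # (N, L, C_in)
w = mx.random.normal(shape=(2, 3, 2))   # (C_out, K, C_in // groups)
cotans = (mx.ones(shape=(1, 3, 2)),)    # cotangent for the conv output
_, grads = mx.vjp(fn, (x, w), cotans)   # gradients w.r.t. x and w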
@ -72,7 +72,7 @@ void explicit_gemm_conv_ND_gpu(
   wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());

   // Perform gemm
-  std::vector<array> copies = {in_unfolded, wt_reshaped};
+  std::vector<array> copies = {in_unfolded};
   return steel_matmul(
       s,
       d,
@ -155,22 +155,27 @@ void explicit_gemm_conv_group_ND_gpu(
   copy_gpu(wt_view, wt_transpose, CopyType::General, s);

   // Perform gemm
-  std::vector<array> copies = {in_unfolded, wt_view, wt_transpose};
-  return steel_matmul_conv_groups(
+  std::vector<array> copies = {in_unfolded, wt_transpose};
+  return steel_matmul_regular(
       s,
       d,
-      /*a = */ in_unfolded,
-      /*b = */ wt_transpose,
-      /*c = */ out,
-      /*M = */ implicit_M,
-      /*N = */ implicit_N,
-      /*K = */ implicit_K,
-      /*a_cols = */ implicit_K * groups,
-      /*b_cols = */ implicit_K,
-      /*out_cols = */ implicit_N * groups,
-      /*a_transposed = */ false,
-      /*b_transposed = */ true,
-      /* groups = */ groups,
+      /* a = */ in_unfolded,
+      /* b = */ wt_transpose,
+      /* c = */ out,
+      /* M = */ implicit_M,
+      /* N = */ implicit_N,
+      /* K = */ implicit_K,
+      /* batch_size_out = */ groups,
+      /* a_cols = */ implicit_K * groups,
+      /* b_cols = */ implicit_K,
+      /* out_cols = */ implicit_N * groups,
+      /* a_transposed = */ false,
+      /* b_transposed = */ true,
+      /* batch_shape = */ {1},
+      /* batch_strides = */ {0},
+      /* A_batch_strides = */ size_t(implicit_K),
+      /* B_batch_strides = */ size_t(implicit_N) * implicit_K,
+      /* matrix_stride_out = */ size_t(implicit_N),
       /*copies = */ copies);
 }

@ -113,6 +113,7 @@ template <typename T, int N>
     for (int i = N - 1; i >= 0; --i) {
       int os_ = (oS % params->oS[i]);
       int ws_ = (wS % params->wS[i]);
+      out += ws_ * kernel_stride;

       ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;

@ -126,7 +127,6 @@ template <typename T, int N>
       oS /= params->oS[i];
       wS /= params->wS[i];

-      out += ws_ * kernel_stride;
       kernel_stride *= params->wS[i];
     }

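The two kernel hunks above move the output-offset accumulation before the flip remapping of ws_, which is what the flipped, grouped path needed. A hedged way to sanity-check the resulting behavior from Python (an equivalence that should hold, not the author's test): a flipped grouped convolution should match an un-flipped one whose kernel axis has been reversed by hand.

import mlx.core as mx

x = mx.random.normal(shape=(1, 5, 2))   # (N, L, C_in)
w = mx.random.normal(shape=(2, 4, 1))   # (C_out, K, C_in // groups)

flipped = mx.conv_general(x, w, flip=True, groups=2)
# Reverse the kernel axis manually and run the un-flipped conv.
rev = mx.take(w, mx.array([3, 2, 1, 0]), axis=1)
manual = mx.conv_general(x, rev, groups=2)
print(mx.allclose(flipped, manual))     # expect: True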
@ -88,7 +88,7 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
 // Steel matmul fallback
 ///////////////////////////////////////////////////////////////////////////////

-void steel_matmul_conv_groups(
+void steel_matmul_regular(
     const Stream& s,
     metal::Device& d,
     const array& a,
@ -97,23 +97,25 @@ void steel_matmul_conv_groups(
     int M,
     int N,
     int K,
+    int batch_size_out,
     int lda,
     int ldb,
     int ldd,
     bool transpose_a,
     bool transpose_b,
-    int groups,
+    std::vector<int> batch_shape,
+    std::vector<size_t> batch_strides,
+    size_t A_batch_stride,
+    size_t B_batch_stride,
+    size_t matrix_stride_out,
     std::vector<array>& copies) {
   using namespace mlx::steel;

   /////////////////////////////////////////////////////////////////////////////
   // Regular kernel dispatch

   // Determine dispatch kernel
   int bm = 32, bn = 32, bk = 16;
   int wm = 2, wn = 2;

-  if ((size_t)M * N >= 1ul << 20) {
+  if ((size_t)batch_size_out * M * N >= 1ul << 20) {
     if (!transpose_a && transpose_b) {
       bm = 64;
       bn = (out.dtype() == float32) ? 64 : 32;
@ -133,7 +135,7 @@ void steel_matmul_conv_groups(

   std::string base_name = kname.str();

-  const bool has_batch = false;
+  const bool has_batch = (batch_shape.size() > 1);
   const bool use_out_source = false;
   const bool do_axpby = false;
   const bool align_M = (M % bm) == 0;
@ -197,12 +199,12 @@ void steel_matmul_conv_groups(
       /* const int ldd = */ ldd,
       /* const int tiles_n = */ tn,
       /* const int tiles_m = */ tm,
-      /* const size_t batch_stride_a = */ size_t(K),
-      /* const size_t batch_stride_b = */ size_t(N) * K,
-      /* const size_t batch_stride_d = */ size_t(N),
+      /* const size_t batch_stride_a = */ A_batch_stride,
+      /* const size_t batch_stride_b = */ B_batch_stride,
+      /* const size_t batch_stride_d = */ matrix_stride_out,
       /* const int swizzle_log = */ swizzle_log,
       /* const int gemm_k_iterations_aligned = */ (K / bk),
-      /* const int batch_ndim = */ 1};
+      /* const int batch_ndim = */ int(batch_shape.size())};

   // Prepare launch grid params
   int tile = 1 << swizzle_log;
@ -210,15 +212,13 @@ void steel_matmul_conv_groups(
   tn = tn * tile;

   MTL::Size group_dims = MTL::Size(32, wn, wm);
-  MTL::Size grid_dims = MTL::Size(tn, tm, groups);
-
-  std::vector<int> batch_shape = {1};
-  std::vector<size_t> batch_strides = {0};
+  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);

   // Launch kernel
   compute_encoder.set_input_array(a, 0);
   compute_encoder.set_input_array(b, 1);
   compute_encoder.set_output_array(out, 3);

   compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);

   set_vector_bytes(compute_encoder, batch_shape, 6);
@ -393,133 +393,31 @@ void steel_matmul(

   /////////////////////////////////////////////////////////////////////////////
   // Regular kernel dispatch

-  // Determine dispatch kernel
-  int bm = 32, bn = 32, bk = 16;
-  int wm = 2, wn = 2;
-
-  if ((size_t)batch_size_out * M * N >= 1ul << 20) {
-    if (!transpose_a && transpose_b) {
-      bm = 64;
-      bn = (out.dtype() == float32) ? 64 : 32;
-      bk = (out.dtype() == float32) ? 16 : 32;
-    } else {
-      bm = 64;
-      bn = 64;
-    }
-  }
-
-  // Prepare kernel name
-  std::ostringstream kname;
-  kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
-        << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
-        << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
-        << "_wm" << wm << "_wn" << wn;
-
-  std::string base_name = kname.str();
-
-  const bool has_batch = (batch_shape.size() > 1);
-  const bool use_out_source = false;
-  const bool do_axpby = false;
-  const bool align_M = (M % bm) == 0;
-  const bool align_N = (N % bn) == 0;
-  const bool align_K = (K % bk) == 0;
-  const bool do_gather = false;
-
-  metal::MTLFCList func_consts = {
-      {&has_batch, MTL::DataType::DataTypeBool, 10},
-      {&use_out_source, MTL::DataType::DataTypeBool, 100},
-      {&do_axpby, MTL::DataType::DataTypeBool, 110},
-      {&align_M, MTL::DataType::DataTypeBool, 200},
-      {&align_N, MTL::DataType::DataTypeBool, 201},
-      {&align_K, MTL::DataType::DataTypeBool, 202},
-      {&do_gather, MTL::DataType::DataTypeBool, 300},
-  };
-
-  // clang-format off
-  kname << "_has_batch_" << (has_batch ? 't' : 'n')
-        << "_use_out_source_" << (use_out_source ? 't' : 'n')
-        << "_do_axpby_" << (do_axpby ? 't' : 'n')
-        << "_align_M_" << (align_M ? 't' : 'n')
-        << "_align_N_" << (align_N ? 't' : 'n')
-        << "_align_K_" << (align_K ? 't' : 'n')
-        << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
-
-  std::string hash_name = kname.str();
-
-  // Encode and dispatch kernel
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = get_steel_gemm_fused_kernel(
-      d,
-      base_name,
-      hash_name,
-      func_consts,
-      out,
-      transpose_a,
-      transpose_b,
-      bm,
-      bn,
-      bk,
-      wm,
-      wn);
-
-  compute_encoder->setComputePipelineState(kernel);
-
-  // Use problem size to determine threadblock swizzle
-  int tn = (N + bn - 1) / bn;
-  int tm = (M + bm - 1) / bm;
-
-  // TODO: Explore device-based tuning for swizzle
-  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
-
-  // Prepare steel matmul params
-  GEMMParams params{
-      /* const int M = */ M,
-      /* const int N = */ N,
-      /* const int K = */ K,
-      /* const int lda = */ lda,
-      /* const int ldb = */ ldb,
-      /* const int ldd = */ N,
-      /* const int tiles_n = */ tn,
-      /* const int tiles_m = */ tm,
-      /* const size_t batch_stride_a = */ A_batch_stride.back(),
-      /* const size_t batch_stride_b = */ B_batch_stride.back(),
-      /* const size_t batch_stride_d = */ matrix_stride_out,
-      /* const int swizzle_log = */ swizzle_log,
-      /* const int gemm_k_iterations_aligned = */ (K / bk),
-      /* const int batch_ndim = */ int(batch_shape.size())};
-
-  // Prepare launch grid params
-  int tile = 1 << swizzle_log;
-  tm = (tm + tile - 1) / tile;
-  tn = tn * tile;
-
-  MTL::Size group_dims = MTL::Size(32, wn, wm);
-  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
-
-  std::vector<size_t> batch_strides = A_batch_stride;
-  batch_strides.insert(
-      batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
-
-  // Launch kernel
-  compute_encoder.set_input_array(a, 0);
-  compute_encoder.set_input_array(b, 1);
-  compute_encoder.set_output_array(out, 3);
-
-  compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
-
-  set_vector_bytes(compute_encoder, batch_shape, 6);
-  set_vector_bytes(compute_encoder, batch_strides, 7);
-
-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
-
-  // Clear copies
-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  steel_matmul_regular(
+      s,
+      d,
+      a,
+      b,
+      out,
+      M,
+      N,
+      K,
+      batch_size_out,
+      lda,
+      ldb,
+      N,
+      transpose_a,
+      transpose_b,
+      std::move(batch_shape),
+      std::move(batch_strides),
+      A_batch_stride.back(),
+      B_batch_stride.back(),
+      matrix_stride_out,
+      copies);
 }

 void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {

@ -4,7 +4,7 @@

 namespace mlx::core {

-void steel_matmul_conv_groups(
+void steel_matmul_regular(
     const Stream& s,
     metal::Device& d,
     const array& a,
@ -13,12 +13,17 @@ void steel_matmul_conv_groups(
     int M,
     int N,
     int K,
+    int batch_size_out,
     int lda,
     int ldb,
     int ldd,
     bool transpose_a,
     bool transpose_b,
-    int groups,
+    std::vector<int> batch_shape,
+    std::vector<size_t> batch_strides,
+    size_t A_batch_stride,
+    size_t B_batch_stride,
+    size_t matrix_stride_out,
     std::vector<array>& copies);

 void steel_matmul(

@ -929,16 +929,28 @@ std::vector<array> Convolution::vjp(
   assert(primals.size() == 2);
   std::vector<array> grads;

-  if (groups_ != 1) {
-    throw std::invalid_argument(
-        "[Convolution] Backward pass not implemented for groups > 1.");
-  }
-
   // Collect info
   auto& in = primals[0];
   auto& wt = primals[1];
   auto& cotan = cotangents[0];

+  auto group_transpose =
+      [this](const array& x, int group_dim, int ax_a, int ax_b) {
+        if (groups_ > 1) {
+          auto shape = x.shape();
+          if (group_dim < 0) {
+            group_dim += shape.size();
+          }
+          shape.insert(shape.begin() + group_dim, groups_);
+          shape[group_dim + 1] = shape[group_dim + 1] / groups_;
+          auto x_trans = swapaxes(
+              reshape(x, std::move(shape), stream()), ax_a, ax_b, stream());
+          return flatten(x_trans, group_dim, group_dim + 1, stream());
+        } else {
+          return swapaxes(x, 0, -1, stream());
+        }
+      };
+
   for (int a : argnums) {
     // Grads for input
     if (a == 0) {
@ -976,8 +988,7 @@ std::vector<array> Convolution::vjp(
         }
       }

-      auto wt_trans = swapaxes(wt, 0, -1, stream());
-
+      auto wt_trans = group_transpose(wt, 0, 1, -1);
       auto grad = conv_general(
           /* const array& input = */ cotan,
           /* const array& weight = */ wt_trans,
@ -986,7 +997,7 @@ std::vector<array> Convolution::vjp(
           /* std::vector<int> padding_hi = */ padding_hi,
           /* std::vector<int> kernel_dilation = */ kernel_dilation_,
           /* std::vector<int> input_dilation = */ kernel_strides_,
-          /* int groups = */ 1,
+          /* int groups = */ groups_,
           /* bool flip = */ !flip_,
           stream());

@ -1020,14 +1031,11 @@ std::vector<array> Convolution::vjp(
       no_dilation &= (input_dilation_[i] == 1) && (kernel_dilation_[i] == 1);
     }

-    if (no_dilation && !flip_) {
+    if (no_dilation && !flip_ && groups_ == 1) {
       auto grad = conv_weight_backward_patches(
           in, wt, cotan, kernel_strides_, padding_, stream());
       grads.push_back(grad);
     } else {
-      auto cotan_trans = swapaxes(cotan, 0, -1, stream());
-      auto in_trans = swapaxes(in, 0, -1, stream());
-
       if (flip_) {
         auto padding = padding_;
         for (int i = 0; i < padding.size(); i++) {
@ -1035,6 +1043,9 @@ std::vector<array> Convolution::vjp(
           padding[i] = wt_size - padding_[i] - 1;
         }

+        auto cotan_trans = group_transpose(cotan, -1, 0, -1);
+        auto in_trans = swapaxes(in, 0, -1, stream());
+
         auto grad_trans = conv_general(
             /* const array& input = */ cotan_trans,
             /* const array& weight = */ in_trans,
@ -1043,11 +1054,14 @@ std::vector<array> Convolution::vjp(
             /* std::vector<int> padding_hi = */ padding,
             /* std::vector<int> kernel_dilation = */ input_dilation_,
             /* std::vector<int> input_dilation = */ kernel_strides_,
-            /* int groups = */ 1,
+            /* int groups = */ groups_,
             /* bool flip = */ false,
             stream());
-        auto grad = swapaxes(grad_trans, 0, -1, stream());
-        grads.push_back(grad_trans);
+        if (groups_ > 1) {
+          grads.push_back(group_transpose(grad_trans, -1, 0, -2));
+        } else {
+          grads.push_back(grad_trans);
+        }
       } else {
         std::vector<int> padding_lo = padding_;
         std::vector<int> padding_hi = padding_;
@ -1058,9 +1072,9 @@ std::vector<array> Convolution::vjp(
           int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
           padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
         }

-        auto in_trans = swapaxes(in, 0, -1, stream());
         auto cotan_trans = swapaxes(cotan, 0, -1, stream());
+        auto in_trans = group_transpose(in, -1, 0, -1);

         auto grad_trans = conv_general(
             /* const array& input = */ in_trans,
             /* const array& weight = */ cotan_trans,
@ -1069,11 +1083,10 @@ std::vector<array> Convolution::vjp(
             /* std::vector<int> padding_hi = */ padding_hi,
             /* std::vector<int> kernel_dilation = */ kernel_strides_,
             /* std::vector<int> input_dilation = */ input_dilation_,
-            /* int groups = */ 1,
+            /* int groups = */ groups_,
             /* bool flip = */ false,
             stream());
-        auto grad = swapaxes(grad_trans, 0, -1, stream());
-        grads.push_back(grad);
+        grads.push_back(swapaxes(grad_trans, 0, -1, stream()));
       }
     }
   }

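The group_transpose helper added in the VJP above splits the grouped channel dimension, swaps the requested axes, and merges the group axis back. A rough Python equivalent of its groups > 1 branch, included here as a sketch for intuition rather than a verbatim port of the C++ lambda:

import mlx.core as mx

def group_transpose(x, group_dim, ax_a, ax_b, groups):
    # Split the grouped dimension into (groups, size // groups), swap the
    # requested axes, then flatten the group axis back into its neighbor.
    shape = list(x.shape)
    if group_dim < 0:
        group_dim += len(shape)
    shape.insert(group_dim, groups)
    shape[group_dim + 1] //= groups
    x_trans = mx.swapaxes(mx.reshape(x, shape), ax_a, ax_b)
    return mx.flatten(x_trans, group_dim, group_dim + 1)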
@ -47,6 +47,13 @@ class TestConv(mlx_tests.MLXTestCase):
         self.assertEqual(c_mx.shape, c_np.shape)
         self.assertTrue(np.allclose(c_mx, c_np, atol=atol))

+    def test_conv_1d_groups_flipped(self):
+        x = mx.broadcast_to(mx.arange(5).astype(mx.float32), (2, 5)).T
+        w = mx.broadcast_to(mx.arange(4).astype(mx.float32), (2, 4))
+        out = mx.conv_general(x[None], w[..., None], flip=True, groups=2)
+        expected = mx.array([4.0, 4.0, 10.0, 10.0]).reshape(1, 2, 2)
+        self.assertTrue(mx.allclose(out, expected))
+
     @unittest.skipIf(not has_torch, "requires Torch")
     def test_torch_conv_1D(self):
         def run_conv1D(
@ -897,6 +904,99 @@ class TestConv(mlx_tests.MLXTestCase):
         expected = mx.array([[dw00, dw01], [dw10, dw11]])
         self.assertTrue(mx.allclose(dw, expected))

+    def test_conv_groups_grad(self):
+        def fn(x, w):
+            num_groups = x.shape[-1] // w.shape[-1]
+            return mx.conv1d(x, w, groups=num_groups)
+
+        def fn_gt(x, w):
+            num_groups = x.shape[-1] // w.shape[-1]
+            group_size = w.shape[-1]
+            ws = w.reshape(num_groups, -1, *w.shape[1:]).split(num_groups)
+            xs = x.reshape(*x.shape[:-1], num_groups, -1).split(num_groups, axis=-2)
+            return mx.concatenate(
+                [mx.conv_general(x.squeeze(-2), w.squeeze(0)) for x, w in zip(xs, ws)],
+                axis=-1,
+            )
+
+        mx.random.seed(3)
+
+        w = mx.random.normal(shape=(2, 3, 1))
+        x = mx.random.normal(shape=(1, 5, 2))
+        cotans = (mx.ones(shape=(1, 3, 2)),)
+        grads = mx.vjp(fn, (x, w), cotans)[1]
+        expected = mx.vjp(fn_gt, (x, w), cotans)[1]
+        self.assertTrue(mx.allclose(expected[0], grads[0]))
+        self.assertTrue(mx.allclose(expected[1], grads[1]))
+
+        w = mx.random.normal(shape=(2, 3, 2))
+        x = mx.random.normal(shape=(1, 5, 4))
+        cotans = (mx.ones(shape=(1, 3, 2)),)
+        grads = mx.vjp(fn, (x, w), cotans)[1]
+        expected = mx.vjp(fn_gt, (x, w), cotans)[1]
+        self.assertTrue(mx.allclose(expected[0], grads[0]))
+        self.assertTrue(mx.allclose(expected[1], grads[1]))
+
+        w = mx.random.normal(shape=(6, 3, 2))
+        x = mx.random.normal(shape=(1, 5, 4))
+        cotans = (mx.ones(shape=(1, 3, 6)),)
+        grads = mx.vjp(fn, (x, w), cotans)[1]
+        expected = mx.vjp(fn_gt, (x, w), cotans)[1]
+        self.assertTrue(mx.allclose(expected[0], grads[0]))
+        self.assertTrue(mx.allclose(expected[1], grads[1]))
+
+        # Test 2D
+        w = mx.random.normal(shape=(2, 3, 3, 1))
+        x = mx.random.normal(shape=(1, 5, 5, 2))
+        cotans = (mx.ones(shape=(1, 3, 3, 2)),)
+        grads = mx.vjp(fn, (x, w), cotans)[1]
+        expected = mx.vjp(fn_gt, (x, w), cotans)[1]
+        self.assertTrue(mx.allclose(expected[0], grads[0]))
+        self.assertTrue(mx.allclose(expected[1], grads[1]))
+
+        # Test with flip
+        def fn(x, w):
+            num_groups = x.shape[-1] // w.shape[-1]
+            return mx.conv_general(x, w, groups=num_groups, flip=True)
+
+        def fn_gt(x, w):
+            num_groups = x.shape[-1] // w.shape[-1]
+            group_size = w.shape[-1]
+            ws = w.reshape(num_groups, -1, *w.shape[1:]).split(num_groups)
+            xs = x.reshape(*x.shape[:-1], num_groups, -1).split(num_groups, axis=-2)
+            return mx.concatenate(
+                [
+                    mx.conv_general(x.squeeze(-2), w.squeeze(0), flip=True)
+                    for x, w in zip(xs, ws)
+                ],
+                axis=-1,
+            )
+
+        w = mx.random.normal(shape=(2, 3, 1))
+        x = mx.random.normal(shape=(1, 5, 2))
+        cotans = (mx.ones(shape=(1, 3, 2)),)
+        grads = mx.vjp(fn, (x, w), cotans)[1]
+        expected = mx.vjp(fn_gt, (x, w), cotans)[1]
+        self.assertTrue(mx.allclose(expected[0], grads[0]))
+        self.assertTrue(mx.allclose(expected[1], grads[1]))
+
+        w = mx.random.normal(shape=(2, 3, 2))
+        x = mx.random.normal(shape=(1, 5, 4))
+        cotans = (mx.ones(shape=(1, 3, 2)),)
+        grads = mx.vjp(fn, (x, w), cotans)[1]
+        expected = mx.vjp(fn_gt, (x, w), cotans)[1]
+        self.assertTrue(mx.allclose(expected[0], grads[0]))
+        self.assertTrue(mx.allclose(expected[1], grads[1]))
+
+        # Test 2D
+        w = mx.random.normal(shape=(2, 3, 3, 1))
+        x = mx.random.normal(shape=(1, 5, 5, 2))
+        cotans = (mx.ones(shape=(1, 3, 3, 2)),)
+        grads = mx.vjp(fn, (x, w), cotans)[1]
+        expected = mx.vjp(fn_gt, (x, w), cotans)[1]
+        self.assertTrue(mx.allclose(expected[0], grads[0]))
+        self.assertTrue(mx.allclose(expected[1], grads[1]))
+

 if __name__ == "__main__":
     unittest.main()