Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-24 09:21:16 +08:00

fix conv grad (#2187)

commit 602f43e3d1, parent a2cadb8218
mlx/primitives.cpp
@@ -1116,13 +1116,11 @@ array conv_weight_backward_patches(
   // Pad input
   std::vector<int> padded_axes(in.ndim() - 2, 0);
   std::iota(padded_axes.begin(), padded_axes.end(), 1);
-  Shape padding_lo_(padding_lo.begin(), padding_lo.end());
-  Shape padding_hi_(padding_hi.begin(), padding_hi.end());
   auto in_padded =
       pad(in,
           padded_axes,
-          padding_lo_,
-          padding_hi_,
+          Shape(padding_lo),
+          Shape(padding_hi),
           array(0, in.dtype()),
           "constant",
           s);
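(Note: `Shape` here is mlx's vector-of-32-bit-ints alias, so `Shape(padding_lo)` performs the same element-wise conversion the removed `padding_lo_`/`padding_hi_` temporaries did; this hunk is a simplification with no behavior change.)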
@@ -1274,8 +1272,14 @@ std::vector<array> Convolution::vjp(
             in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream());
         grads.push_back(grad);
       } else {
-        std::vector<int> padding_lo = padding_lo_;
-        std::vector<int> padding_hi = padding_hi_;
+        auto padding_hi = padding_lo_;
+
+        for (int i = 0; i < padding_hi.size(); ++i) {
+          int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
+          int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
+          int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
+          padding_hi[i] = out_size - in_size + wt_size - padding_hi[i] - 1;
+        }
 
         auto cotan_trans = swapaxes(cotan, 0, -1, stream());
         auto in_trans = group_transpose(in, -1, 0, -1);
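The new loop is the heart of the fix: it recomputes the high padding so that the weight-gradient convolution (shown in the next hunk) yields exactly one output tap per kernel element, even when the forward stride does not evenly divide the input. Below is a minimal sketch of that size bookkeeping in plain Python; the concrete sizes (width 8, kernel 3, stride 2, padding 1) are hypothetical, chosen only to make the arithmetic checkable by hand.

def conv_size(in_size, pad_lo, pad_hi, k, stride, dilation):
    # standard convolution output-size formula
    eff_k = 1 + dilation * (k - 1)
    return (in_size + pad_lo + pad_hi - eff_k) // stride + 1

# forward pass: width 8, kernel 3, stride 2, padding (1, 1), no dilation
cotan_w = conv_size(8, 1, 1, k=3, stride=2, dilation=1)  # -> 4

# weight gradient: slide cotan (dilated by the forward stride) over the
# input with step equal to the forward kernel dilation (1 here), matching
# the conv_general call in the next hunk
in_size = 1 + 1 * (8 - 1)          # extent after input_dilation
out_size = 1 + 2 * (cotan_w - 1)   # cotan extent after dilation by stride
wt_size = 1 + 1 * (3 - 1)          # kernel extent after kernel_dilation
pad_lo = 1
pad_hi = out_size - in_size + wt_size - pad_lo - 1  # the fixed formula -> 0

# the gradient convolution must produce exactly wt.shape taps per dim
assert conv_size(in_size, pad_lo, pad_hi, k=cotan_w, stride=1, dilation=2) == 3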
@@ -1284,7 +1288,7 @@ std::vector<array> Convolution::vjp(
             /* const array& input = */ in_trans,
             /* const array& weight = */ cotan_trans,
             /* std::vector<int> stride = */ kernel_dilation_,
-            /* std::vector<int> padding_lo = */ padding_lo,
+            /* std::vector<int> padding_lo = */ padding_lo_,
             /* std::vector<int> padding_hi = */ padding_hi,
             /* std::vector<int> kernel_dilation = */ kernel_strides_,
             /* std::vector<int> input_dilation = */ input_dilation_,
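With the local `padding_lo` copy removed in the previous hunk, the call site now passes the member `padding_lo_` directly; `padding_hi` remains the local recomputed by the new loop.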
python/tests/test_conv.py
@@ -1130,6 +1130,28 @@ class TestConv(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3))
 
+    def test_basic_grad_shapes(self):
+        def loss_fn(kernel, inputs, strides, groups):
+            return mx.sum(
+                mx.conv_general(
+                    inputs,
+                    kernel,
+                    stride=strides,
+                    groups=groups,
+                )
+            )
+
+        for in_shape, k_shape, strides, groups in [
+            ((3, 5, 4), (6, 2, 2), (2,), 2),
+            ((3, 5, 4), (24, 2, 1), (2,), 4),
+            ((3, 5, 5, 4), (6, 2, 2, 2), (2, 1), 2),
+            ((3, 5, 5, 4), (24, 2, 2, 1), (2, 2), 4),
+        ]:
+            grads = mx.grad(loss_fn)(
+                mx.zeros(k_shape), mx.zeros(in_shape), strides, groups
+            )
+            self.assertEqual(grads.shape, k_shape)
+
 
 if __name__ == "__main__":
     unittest.main()
|
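For reference, the first case of the new test can be reproduced standalone; this is a hypothetical snippet (not part of the commit) using the same shapes, assuming mlx's NWC input layout and (C_out, kernel_w, C_in / groups) weight layout:

import mlx.core as mx

def loss_fn(kernel, inputs):
    # scalar loss, so mx.grad differentiates w.r.t. the kernel (first arg)
    return mx.sum(mx.conv_general(inputs, kernel, stride=(2,), groups=2))

kernel = mx.zeros((6, 2, 2))  # (C_out, kernel_w, C_in / groups)
inputs = mx.zeros((3, 5, 4))  # (batch, width, C_in)
grads = mx.grad(loss_fn)(kernel, inputs)
print(grads.shape)  # (6, 2, 2): the gradient matches the kernel shape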