mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
add test
This commit is contained in:
parent
76def90b73
commit
88cc8e0755
@ -944,8 +944,6 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
wt = arr_copy;
|
wt = arr_copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto padding_ = padding_lo_;
|
|
||||||
|
|
||||||
// 3D conv
|
// 3D conv
|
||||||
if (out.ndim() == 5) {
|
if (out.ndim() == 5) {
|
||||||
conv_3D_gpu(
|
conv_3D_gpu(
|
||||||
@ -954,7 +952,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding_,
|
padding_lo_,
|
||||||
kernel_strides_,
|
kernel_strides_,
|
||||||
kernel_dilation_,
|
kernel_dilation_,
|
||||||
input_dilation_,
|
input_dilation_,
|
||||||
@ -969,7 +967,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding_,
|
padding_lo_,
|
||||||
kernel_strides_,
|
kernel_strides_,
|
||||||
kernel_dilation_,
|
kernel_dilation_,
|
||||||
input_dilation_,
|
input_dilation_,
|
||||||
@ -985,7 +983,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding_,
|
padding_lo_,
|
||||||
kernel_strides_,
|
kernel_strides_,
|
||||||
kernel_dilation_,
|
kernel_dilation_,
|
||||||
input_dilation_,
|
input_dilation_,
|
||||||
|
@ -1088,6 +1088,48 @@ class TestConv(mlx_tests.MLXTestCase):
|
|||||||
atol=2e-5 if dtype == np.float32 else 5e-4,
|
atol=2e-5 if dtype == np.float32 else 5e-4,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@unittest.skipIf(not has_torch, "requires Torch")
|
||||||
|
def test_asymmetric_padding(self):
|
||||||
|
inputs = np.random.normal(size=(2, 8, 8, 8, 3)).astype(np.float32)
|
||||||
|
kernel = np.random.normal(size=(2, 3, 3, 3, 3)).astype(np.float32)
|
||||||
|
strides = (2, 2, 2)
|
||||||
|
|
||||||
|
pt_out = torch.conv3d(
|
||||||
|
torch.permute(torch.tensor(inputs), (0, 4, 1, 2, 3)),
|
||||||
|
torch.permute(torch.tensor(kernel), (0, 4, 1, 2, 3)),
|
||||||
|
stride=strides,
|
||||||
|
padding=2,
|
||||||
|
)
|
||||||
|
pt_out = torch.permute(pt_out, (0, 2, 3, 4, 1))[:, 1:, 1:, 1:, :].numpy()
|
||||||
|
|
||||||
|
mx_out = mx.conv_general(
|
||||||
|
mx.array(inputs),
|
||||||
|
mx.array(kernel),
|
||||||
|
stride=strides,
|
||||||
|
padding=([0, 0, 0], [1, 1, 1]),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3))
|
||||||
|
|
||||||
|
inputs = np.random.normal(size=(2, 10, 10, 3)).astype(np.float32)
|
||||||
|
kernel = np.random.normal(size=(2, 2, 2, 3)).astype(np.float32)
|
||||||
|
|
||||||
|
pt_out = torch.conv2d(
|
||||||
|
torch.permute(torch.tensor(inputs), (0, 3, 1, 2)),
|
||||||
|
torch.permute(torch.tensor(kernel), (0, 3, 1, 2)),
|
||||||
|
stride=1,
|
||||||
|
padding=(1, 0),
|
||||||
|
)
|
||||||
|
pt_out = torch.permute(pt_out, (0, 2, 3, 1))[:, 1:].numpy()
|
||||||
|
|
||||||
|
mx_out = mx.conv_general(
|
||||||
|
mx.array(inputs),
|
||||||
|
mx.array(kernel),
|
||||||
|
stride=1,
|
||||||
|
padding=([0, 0], [1, 0]),
|
||||||
|
)
|
||||||
|
self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user