Mirror of https://github.com/ml-explore/mlx.git
Convolution update (#651)

* Init steel conv and update Conv primitive
* Update slow CPU implementation to support flipping and input dilation
* Winograd conv routing

Co-authored-by: Awni Hannun <awni@apple.com>
@@ -3081,7 +3081,7 @@ void init_ops(py::module_& m) {
       py::kw_only(),
       "stream"_a = none,
       R"pbdoc(
-        conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: Union[int, Tuple[int, int]] = 1, *, stream: Union[None, Stream, Device] = None) -> array
+        conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array
 
         2D convolution over an input with several channels
 
@@ -3105,6 +3105,114 @@ void init_ops(py::module_& m) {
         array: The convolved array.
       )pbdoc");
+  m.def(
+      "conv_general",
+      [](const array& input,
+         const array& weight,
+         const std::variant<int, std::vector<int>>& stride,
+         const std::variant<
+             int,
+             std::vector<int>,
+             std::pair<std::vector<int>, std::vector<int>>>& padding,
+         const std::variant<int, std::vector<int>>& kernel_dilation,
+         const std::variant<int, std::vector<int>>& input_dilation,
+         int groups,
+         bool flip,
+         StreamOrDevice s) {
+        std::vector<int> stride_vec;
+        std::vector<int> padding_lo_vec;
+        std::vector<int> padding_hi_vec;
+        std::vector<int> kernel_dilation_vec;
+        std::vector<int> input_dilation_vec;
+
+        if (auto pv = std::get_if<int>(&stride); pv) {
+          stride_vec.push_back(*pv);
+        } else {
+          stride_vec = std::get<std::vector<int>>(stride);
+        }
+
+        if (auto pv = std::get_if<int>(&padding); pv) {
+          padding_lo_vec.push_back(*pv);
+          padding_hi_vec.push_back(*pv);
+        } else if (auto pv = std::get_if<std::vector<int>>(&padding); pv) {
+          padding_lo_vec = *pv;
+          padding_hi_vec = *pv;
+        } else {
+          auto [pl, ph] =
+              std::get<std::pair<std::vector<int>, std::vector<int>>>(padding);
+          padding_lo_vec = pl;
+          padding_hi_vec = ph;
+        }
+
+        if (auto pv = std::get_if<int>(&kernel_dilation); pv) {
+          kernel_dilation_vec.push_back(*pv);
+        } else {
+          kernel_dilation_vec = std::get<std::vector<int>>(kernel_dilation);
+        }
+
+        if (auto pv = std::get_if<int>(&input_dilation); pv) {
+          input_dilation_vec.push_back(*pv);
+        } else {
+          input_dilation_vec = std::get<std::vector<int>>(input_dilation);
+        }
+
+        return conv_general(
+            /* const array& input = */ input,
+            /* const array& weight = */ weight,
+            /* std::vector<int> stride = */ stride_vec,
+            /* std::vector<int> padding_lo = */ padding_lo_vec,
+            /* std::vector<int> padding_hi = */ padding_hi_vec,
+            /* std::vector<int> kernel_dilation = */ kernel_dilation_vec,
+            /* std::vector<int> input_dilation = */ input_dilation_vec,
+            /* int groups = */ groups,
+            /* bool flip = */ flip,
+            s);
+      },
+      "input"_a,
+      "weight"_a,
+      py::pos_only(),
+      "stride"_a = 1,
+      "padding"_a = 0,
+      "kernel_dilation"_a = 1,
+      "input_dilation"_a = 1,
+      "groups"_a = 1,
+      "flip"_a = false,
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        conv_general(input: array, weight: array, /, stride: Union[int, List[int]] = 1, padding: Union[int, List[int], Tuple[List[int], List[int]]] = 0, kernel_dilation: Union[int, List[int]] = 1, input_dilation: Union[int, List[int]] = 1, groups: int = 1, flip: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
+
+        General convolution over an input with several channels
+
+        .. note::
+
+           * Only 1d and 2d convolutions are supported at the moment
+           * Only the default ``groups=1`` is currently supported.
+
+        Args:
+          input (array): Input array of shape ``(N, ..., C_in)``
+          weight (array): Weight array of shape ``(C_out, ..., C_in)``
+          stride (int or list(int), optional): :obj:`list` with kernel strides.
+            All spatial dimensions get the same stride if
+            only one number is specified. Default: ``1``.
+          padding (int, list(int), or tuple(list(int), list(int)), optional):
+            :obj:`list` with input padding. All spatial dimensions get the same
+            padding if only one number is specified. Default: ``0``.
+          kernel_dilation (int or list(int), optional): :obj:`list` with
+            kernel dilation. All spatial dimensions get the same dilation
+            if only one number is specified. Default: ``1``
+          input_dilation (int or list(int), optional): :obj:`list` with
+            input dilation. All spatial dimensions get the same dilation
+            if only one number is specified. Default: ``1``
+          groups (int, optional): Input feature groups. Default: ``1``.
+          flip (bool, optional): Flip the order in which the spatial dimensions of
+            the weights are processed. Performs the cross-correlation operator when
+            ``flip`` is ``False`` and the convolution operator otherwise.
+            Default: ``False``.
+
+        Returns:
+          array: The convolved array.
+      )pbdoc");
   m.def(
       "save",
       &mlx_save_helper,
       "file"_a,
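For orientation between the two files, a minimal sketch of how the new binding is called from Python, assuming the channels-last layout the docstring describes (input ``(N, ..., C_in)``, weight ``(C_out, ..., C_in)``); shapes and values are illustrative, not taken from the commit:

import mlx.core as mx

x = mx.random.normal((2, 32, 32, 16))  # (N, H, W, C_in)
w = mx.random.normal((32, 5, 5, 16))   # (C_out, kH, kW, C_in)

# A scalar broadcasts to every spatial dimension, mirroring the
# std::variant<int, std::vector<int>> normalization in the binding.
y = mx.conv_general(x, w, stride=2, padding=1)

# Asymmetric padding as a (low, high) pair of per-dimension lists,
# matching the std::pair<std::vector<int>, std::vector<int>> case.
y = mx.conv_general(x, w, padding=([1, 2], [3, 4]))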
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import math
 import unittest
@@ -388,13 +388,8 @@ class TestConv(mlx_tests.MLXTestCase):
 
             _, outs_mx = mx.vjp(
                 f,
-                [
-                    in_mx,
-                    wt_mx,
-                ],
-                [
-                    ct_mx,
-                ],
+                [in_mx, wt_mx],
+                [ct_mx],
             )
             pt_grad_in = F.grad.conv1d_input(
                 in_pt.shape,
@@ -428,18 +423,218 @@ class TestConv(mlx_tests.MLXTestCase):
             self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
 
         for dtype in ("float32",):
-            for N, C, O in (
-                (1, 1, 1),
-                (1, 6, 1),
-                (1, 1, 6),
-                (4, 32, 64),
-            ):
-                for idim, kdim, stride, padding in (
-                    ((1, 1), (1, 1), (1, 1), (0, 0)),
-                    ((3, 3), (3, 1), (1, 1), (0, 0)),
-                    ((31, 31), (5, 5), (5, 5), (2, 2)),
+            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
+                for idim, kdim, stride, padding, dilation in (
+                    ((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
+                    ((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
+                    ((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
+                    ((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
+                    ((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
+                    ((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
                 ):
-                    run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype)
+                    run_conv2D_grad(
+                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
+                    )
+
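For readers unfamiliar with the harness, a hedged sketch of the vjp-based gradient check that run_conv2D_grad drives; the function and shapes below are illustrative stand-ins, not the test's own code:

import mlx.core as mx

def f(x, w):
    return mx.conv2d(x, w, stride=(1, 1), padding=(1, 1))

x = mx.random.normal((1, 8, 8, 3))  # (N, H, W, C_in)
w = mx.random.normal((4, 3, 3, 3))  # (C_out, kH, kW, C_in)
ct = mx.ones((1, 8, 8, 4))          # cotangent, same shape as f(x, w)

# mx.vjp returns (outputs, vjps); the vjps are the gradients of f
# with respect to x and w under the given cotangent.
_, (dx, dw) = mx.vjp(f, [x, w], [ct])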
+    def __conv_general_test(
+        self,
+        in_shape,
+        wt_shape,
+        stride=1,
+        padding=0,
+        kernel_dilation=1,
+        input_dilation=1,
+        groups=1,
+        flip=False,
+        np_dtype=np.float32,
+        atol=1e-5,
+    ):
+        with self.subTest(
+            in_shape=in_shape,
+            wt_shape=wt_shape,
+            stride=stride,
+            padding=padding,
+            kernel_dilation=kernel_dilation,
+            input_dilation=input_dilation,
+            groups=groups,
+            flip=flip,
+            np_dtype=np_dtype,
+        ):
+            scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
+            in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
+            wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)
+
+            in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+            in_pt, wt_pt = map(
+                lambda x: torch.from_numpy(np.moveaxis(x, -1, 1)).to("cpu"),
+                (in_np, wt_np),
+            )
+
+            out_mx = mx.conv_general(
+                in_mx,
+                wt_mx,
+                stride=stride,
+                padding=padding,
+                kernel_dilation=kernel_dilation,
+                input_dilation=input_dilation,
+                groups=groups,
+                flip=flip,
+            )
+
+            def conv_general_pt(
+                inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip
+            ):
+                C = inp.size()[1]
+                ndim = inp.ndim - 2
+                map_ints = lambda x: [x] * ndim if isinstance(x, int) else x
+
+                stride, padding, kernel_dilation, input_dilation = map(
+                    map_ints, (stride, padding, kernel_dilation, input_dilation)
+                )
+
+                torch_convt_list = (
+                    F.conv_transpose1d,
+                    F.conv_transpose2d,
+                    F.conv_transpose3d,
+                )
+                torch_conv_list = (F.conv1d, F.conv2d, F.conv3d)
+
+                conv_f = torch_conv_list[ndim - 1]
+                convt_f = torch_convt_list[ndim - 1]
+
+                if flip:
+                    wt = torch.flip(wt, tuple(np.arange(2, wt.ndim)))
+
+                if not np.all(np.array(input_dilation) == 1):
+                    # Insert (d - 1) zeros between input samples with a
+                    # depthwise transposed conv using an all-ones 1x1 kernel.
+                    ones = torch.ones([C] + [1] * (ndim + 1)).to(inp.dtype)
+                    inp = convt_f(inp, ones, stride=input_dilation, groups=C)
+
+                return conv_f(
+                    inp,
+                    wt,
+                    stride=stride,
+                    padding=padding,
+                    dilation=kernel_dilation,
+                    groups=groups,
+                )
+
+            out_pt = conv_general_pt(
+                in_pt,
+                wt_pt,
+                stride=stride,
+                padding=padding,
+                kernel_dilation=kernel_dilation,
+                input_dilation=input_dilation,
+                groups=groups,
+                flip=flip,
+            )
+
+            out_pt = np.moveaxis(out_pt.numpy(), 1, -1)
+
+            self.assertEqual(out_mx.shape, out_pt.shape)
+            self.assertTrue(np.allclose(out_mx, out_pt, atol=atol))
+
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_general(self):
+        in_shape = (2, 32, 32, 16)
+        wt_shape = (32, 5, 5, 16)
+        stride = (1, 1)
+        padding = (2, 2)
+        kernel_dilation = (2, 3)
+        input_dilation = (1, 1)
+        flip = False
+
+        self.__conv_general_test(
+            in_shape,
+            wt_shape,
+            stride,
+            padding,
+            kernel_dilation,
+            input_dilation,
+            flip=flip,
+        )
+
+        in_shape = (2, 32, 32, 16)
+        wt_shape = (32, 5, 10, 16)
+        stride = (2, 3)
+        padding = (0, 0)
+        kernel_dilation = (3, 2)
+        input_dilation = (2, 4)
+        flip = False
+
+        self.__conv_general_test(
+            in_shape,
+            wt_shape,
+            stride,
+            padding,
+            kernel_dilation,
+            input_dilation,
+            flip=flip,
+        )
+
+        in_shape = (2, 32, 32, 16)
+        wt_shape = (32, 5, 10, 16)
+        stride = (2, 2)
+        padding = (3, 2)
+        kernel_dilation = (3, 2)
+        input_dilation = (2, 4)
+        flip = False
+
+        self.__conv_general_test(
+            in_shape,
+            wt_shape,
+            stride,
+            padding,
+            kernel_dilation,
+            input_dilation,
+            flip=flip,
+        )
+
+        in_shape = (2, 32, 32, 16)
+        wt_shape = (32, 5, 10, 16)
+        stride = (2, 3)
+        padding = (3, 2)
+        kernel_dilation = (3, 2)
+        input_dilation = (2, 5)
+        flip = False
+
+        self.__conv_general_test(
+            in_shape,
+            wt_shape,
+            stride,
+            padding,
+            kernel_dilation,
+            input_dilation,
+            flip=flip,
+        )
+
+        in_shape = (2, 32, 32, 16)
+        wt_shape = (32, 5, 5, 16)
+        stride = (2, 3)
+        padding = (0, 0)
+        kernel_dilation = (3, 1)
+        input_dilation = (2, 5)
+        flip = True
+
+        self.__conv_general_test(
+            in_shape,
+            wt_shape,
+            stride,
+            padding,
+            kernel_dilation,
+            input_dilation,
+            flip=flip,
+        )
 
 
 if __name__ == "__main__":
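One detail worth calling out in conv_general_pt above: input dilation is realized with a depthwise transposed convolution whose weight is an all-ones 1x1 kernel, which inserts d - 1 zeros between neighboring input samples. A standalone sketch of that identity (illustrative, not part of the commit):

import torch
import torch.nn.functional as F

x = torch.arange(1.0, 5.0).reshape(1, 1, 4)  # (N, C, L) = (1, 1, 4)
ones = torch.ones(1, 1, 1)                   # weight (C, C/groups, k)

# Stride-2 transposed conv with a 1x1 ones kernel spaces the samples
# two apart and fills the gaps with zeros.
y = F.conv_transpose1d(x, ones, stride=2)
print(y)  # tensor([[[1., 0., 2., 0., 3., 0., 4.]]])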