Convolution update (#651)

* Init steel conv and update Conv primitive

* Update slow CPU implementation to support flipping and input dilation

* Update Winograd conv routing

Co-authored-by: Awni Hannun <awni@apple.com>
Jagrit Digani authored on 2024-02-28 20:11:16 -08:00, committed by GitHub
parent f5f18b704f, commit 776c3d226d
27 changed files with 2830 additions and 906 deletions

@@ -3081,7 +3081,7 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
-conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: Union[int, Tuple[int, int]] = 1, *, stream: Union[None, Stream, Device] = None) -> array
+conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array
2D convolution over an input with several channels
@@ -3105,6 +3105,114 @@ void init_ops(py::module_& m) {
array: The convolved array.
)pbdoc");
m.def(
"conv_general",
[](const array& input,
const array& weight,
const std::variant<int, std::vector<int>>& stride,
const std::variant<
int,
std::vector<int>,
std::pair<std::vector<int>, std::vector<int>>>& padding,
const std::variant<int, std::vector<int>>& kernel_dilation,
const std::variant<int, std::vector<int>>& input_dilation,
int groups,
bool flip,
StreamOrDevice s) {
std::vector<int> stride_vec;
std::vector<int> padding_lo_vec;
std::vector<int> padding_hi_vec;
std::vector<int> kernel_dilation_vec;
std::vector<int> input_dilation_vec;
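// Normalize int-or-list arguments (and padding's optional (lo, hi)
// pair of lists) into explicit per-axis vectors before dispatch.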
if (auto pv = std::get_if<int>(&stride); pv) {
stride_vec.push_back(*pv);
} else {
stride_vec = std::get<std::vector<int>>(stride);
}
if (auto pv = std::get_if<int>(&padding); pv) {
padding_lo_vec.push_back(*pv);
padding_hi_vec.push_back(*pv);
} else if (auto pv = std::get_if<std::vector<int>>(&padding); pv) {
padding_lo_vec = *pv;
padding_hi_vec = *pv;
} else {
auto [pl, ph] =
std::get<std::pair<std::vector<int>, std::vector<int>>>(padding);
padding_lo_vec = pl;
padding_hi_vec = ph;
}
if (auto pv = std::get_if<int>(&kernel_dilation); pv) {
kernel_dilation_vec.push_back(*pv);
} else {
kernel_dilation_vec = std::get<std::vector<int>>(kernel_dilation);
}
if (auto pv = std::get_if<int>(&input_dilation); pv) {
input_dilation_vec.push_back(*pv);
} else {
input_dilation_vec = std::get<std::vector<int>>(input_dilation);
}
return conv_general(
/* const array& input = */ input,
/* const array& weight = */ weight,
/* std::vector<int> stride = */ stride_vec,
/* std::vector<int> padding_lo = */ padding_lo_vec,
/* std::vector<int> padding_hi = */ padding_hi_vec,
/* std::vector<int> kernel_dilation = */ kernel_dilation_vec,
/* std::vector<int> input_dilation = */ input_dilation_vec,
/* int groups = */ groups,
/* bool flip = */ flip,
s);
},
"input"_a,
"weight"_a,
py::pos_only(),
"stride"_a = 1,
"padding"_a = 0,
"kernel_dilation"_a = 1,
"input_dilation"_a = 1,
"groups"_a = 1,
"flip"_a = false,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
conv_general(input: array, weight: array, /, stride: Union[int, List[int]] = 1, padding: Union[int, List[int], Tuple[List[int], List[int]]] = 0, kernel_dilation: Union[int, List[int]] = 1, input_dilation: Union[int, List[int]] = 1, groups: int = 1, flip: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
General convolution over an input with several channels
.. note::
* Only 1d and 2d convolutions are supported at the moment
* Only the default ``groups=1`` is currently supported.
Args:
input (array): Input array of shape ``(N, ..., C_in)``
weight (array): Weight array of shape ``(C_out, ..., C_in)``
stride (int or list(int), optional): :obj:`list` with kernel strides.
All spatial dimensions get the same stride if
only one number is specified. Default: ``1``.
padding (int, list(int), or tuple(list(int), list(int)), optional):
:obj:`list` with input padding. All spatial dimensions get the same
padding if only one number is specified. Default: ``0``.
kernel_dilation (int or list(int), optional): :obj:`list` with
kernel dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``.
input_dilation (int or list(int), optional): :obj:`list` with
input dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``.
groups (int, optional): Input feature groups. Default: ``1``.
flip (bool, optional): Flip the order in which the spatial dimensions of
the weights are processed. Performs the cross-correlation operator when
``flip`` is ``False`` and the convolution operator otherwise.
Default: ``False``.
Returns:
array: The convolved array.
)pbdoc");
m.def(
"save",
&mlx_save_helper,
"file"_a,

@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
import math
import unittest
@@ -388,13 +388,8 @@ class TestConv(mlx_tests.MLXTestCase):
_, outs_mx = mx.vjp(
f,
-[
-in_mx,
-wt_mx,
-],
-[
-ct_mx,
-],
+[in_mx, wt_mx],
+[ct_mx],
)
pt_grad_in = F.grad.conv1d_input(
in_pt.shape,
@@ -428,18 +423,218 @@ class TestConv(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
for dtype in ("float32",):
-for N, C, O in (
-(1, 1, 1),
-(1, 6, 1),
-(1, 1, 6),
-(4, 32, 64),
-):
-for idim, kdim, stride, padding in (
-((1, 1), (1, 1), (1, 1), (0, 0)),
-((3, 3), (3, 1), (1, 1), (0, 0)),
-((31, 31), (5, 5), (5, 5), (2, 2)),
+for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
+for idim, kdim, stride, padding, dilation in (
+((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
+((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
+((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
+((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
+((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
+((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
):
-run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype)
+run_conv2D_grad(
+N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
+)
def __conv_general_test(
self,
in_shape,
wt_shape,
stride=1,
padding=0,
kernel_dilation=1,
input_dilation=1,
groups=1,
flip=False,
np_dtype=np.float32,
atol=1e-5,
):
with self.subTest(
in_shape=in_shape,
wt_shape=wt_shape,
stride=stride,
padding=padding,
kernel_dilation=kernel_dilation,
input_dilation=input_dilation,
groups=groups,
flip=flip,
np_dtype=np_dtype,
):
scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_pt, wt_pt = map(
lambda x: torch.from_numpy(np.moveaxis(x, -1, 1)).to("cpu"),
(in_np, wt_np),
)
out_mx = mx.conv_general(
in_mx,
wt_mx,
stride=stride,
padding=padding,
kernel_dilation=kernel_dilation,
input_dilation=input_dilation,
groups=groups,
flip=flip,
)
def conv_general_pt(
inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip
):
C = inp.size()[1]
ndim = inp.ndim - 2
map_ints = lambda x: [x] * ndim if isinstance(x, int) else x
stride, padding, kernel_dilation, input_dilation = map(
map_ints, (stride, padding, kernel_dilation, input_dilation)
)
torch_convt_list = (
F.conv_transpose1d,
F.conv_transpose2d,
F.conv_transpose3d,
)
torch_conv_list = (F.conv1d, F.conv2d, F.conv3d)
conv_f = torch_conv_list[ndim - 1]
convt_f = torch_convt_list[ndim - 1]
if flip:
wt = torch.flip(wt, tuple(np.arange(2, wt.ndim)))
if not np.all(np.array(input_dilation) == 1):
# Emulate input dilation: a per-channel 1-tap identity kernel
# with stride d inserts d - 1 zeros between samples.
ones = torch.ones([C] + [1] * (ndim + 1)).to(inp.dtype)
inp = convt_f(inp, ones, stride=input_dilation, groups=C)
return conv_f(
inp,
wt,
stride=stride,
padding=padding,
dilation=kernel_dilation,
groups=groups,
)
out_pt = conv_general_pt(
in_pt,
wt_pt,
stride=stride,
padding=padding,
kernel_dilation=kernel_dilation,
input_dilation=input_dilation,
groups=groups,
flip=flip,
)
out_pt = np.moveaxis(out_pt.numpy(), 1, -1)
self.assertEqual(out_mx.shape, out_pt.shape)
self.assertTrue(np.allclose(out_mx, out_pt, atol=atol))
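
The conv_general_pt helper above emulates input dilation, which the torch conv functions lack, with a grouped transpose convolution whose kernel is a per-channel 1-tap identity filter; stride d then inserts d - 1 zeros between samples. A standalone check of that trick:

import torch
import torch.nn.functional as F

x = torch.arange(1.0, 5.0).reshape(1, 1, 4)  # [[[1., 2., 3., 4.]]]
ones = torch.ones(1, 1, 1)                   # 1-tap identity kernel
out = F.conv_transpose1d(x, ones, stride=2, groups=1)
# out == [[[1., 0., 2., 0., 3., 0., 4.]]], i.e. input dilation of 2
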
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_general(self):
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 5, 16)
stride = (1, 1)
padding = (2, 2)
kernel_dilation = (2, 3)
input_dilation = (1, 1)
flip = False
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 10, 16)
stride = (2, 3)
padding = (0, 0)
kernel_dilation = (3, 2)
input_dilation = (2, 4)
flip = False
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 10, 16)
stride = (2, 2)
padding = (3, 2)
kernel_dilation = (3, 2)
input_dilation = (2, 4)
flip = False
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 10, 16)
stride = (2, 3)
padding = (3, 2)
kernel_dilation = (3, 2)
input_dilation = (2, 5)
flip = False
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 5, 16)
stride = (2, 3)
padding = (0, 0)
kernel_dilation = (3, 1)
input_dilation = (2, 5)
flip = True
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
if __name__ == "__main__":