Added Symmetric and reflect modes to pad

This commit is contained in:
paramthakkar123
2025-05-07 08:15:35 +05:30
parent 5a1a5d5ed1
commit 4caeb05c64
3 changed files with 151 additions and 20 deletions

View File

@@ -3102,27 +3102,29 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
"def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge', 'symmetric', 'reflect'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Pad an array with a constant value
Args:
a (array): Input array.
pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))): Number of padded
values to add to the edges of each axis:``((before_1, after_1),
(before_2, after_2), ..., (before_N, after_N))``. If a single pair
of integers is passed then ``(before_i, after_i)`` are all the same.
If a single integer or tuple with a single integer is passed then
all axes are extended by the same number on each side.
mode: Padding mode. One of the following strings:
"constant" (default): Pads with a constant value.
"edge": Pads with the edge values of array.
constant_value (array or scalar, optional): Optional constant value
to pad the edges of the array with.
Returns:
array: The padded array.
)pbdoc");
Pad an array with a constant value or using other modes.
Args:
a (array): Input array.
pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))): Number of padded
values to add to the edges of each axis:``((before_1, after_1),
(before_2, after_2), ..., (before_N, after_N))``. If a single pair
of integers is passed then ``(before_i, after_i)`` are all the same.
If a single integer or tuple with a single integer is passed then
all axes are extended by the same number on each side.
mode: Padding mode. One of the following strings:
"constant" (default): Pads with a constant value.
"edge": Pads with the edge values of array.
"symmetric": Pads with the reflection of the array, including the edge values.
"reflect": Pads with the reflection of the array, excluding the edge values.
constant_value (array or scalar, optional): Optional constant value
to pad the edges of the array with.
Returns:
array: The padded array.
)pbdoc");
m.def(
"as_strided",
[](const mx::array& a,