Add "edge" mode to mx.pad (#1309)

* Add edge padding mode

* fix pad in pooling

* string arg instead of enum
This commit is contained in:
Alex Barron
2024-08-06 11:23:10 -07:00
committed by GitHub
parent 8c9f0278b9
commit 635ccd9e25
6 changed files with 102 additions and 18 deletions

View File

@@ -855,8 +855,8 @@ array conv_weight_backward_patches(
// Pad input
std::vector<int> padded_axes(in.ndim() - 2, 0);
std::iota(padded_axes.begin(), padded_axes.end(), 1);
auto in_padded =
pad(in, padded_axes, padding, padding, array(0, in.dtype()), s);
auto in_padded = pad(
in, padded_axes, padding, padding, array(0, in.dtype()), "constant", s);
// Resolve strided patches
@@ -2289,6 +2289,7 @@ std::vector<array> Pad::jvp(
low_pad_size_,
high_pad_size_,
array(0, tangents[0].dtype()),
"constant",
stream())};
}
@@ -3740,7 +3741,8 @@ std::vector<array> BlockMaskedMM::vjp(
// Pad if needed
if ((align_Y != 0) || (align_X != 0)) {
r = pad(r, {-2, -1}, {0, 0}, {align_Y, align_X}, zero, stream());
r = pad(
r, {-2, -1}, {0, 0}, {align_Y, align_X}, zero, "constant", stream());
}
// Reshape