mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add "edge" mode to mx.pad (#1309)
* Add edge padding mode * fix pad in pooling * string arg instead of enum
This commit is contained in:
@@ -855,8 +855,8 @@ array conv_weight_backward_patches(
|
||||
// Pad input
|
||||
std::vector<int> padded_axes(in.ndim() - 2, 0);
|
||||
std::iota(padded_axes.begin(), padded_axes.end(), 1);
|
||||
auto in_padded =
|
||||
pad(in, padded_axes, padding, padding, array(0, in.dtype()), s);
|
||||
auto in_padded = pad(
|
||||
in, padded_axes, padding, padding, array(0, in.dtype()), "constant", s);
|
||||
|
||||
// Resolve strided patches
|
||||
|
||||
@@ -2289,6 +2289,7 @@ std::vector<array> Pad::jvp(
|
||||
low_pad_size_,
|
||||
high_pad_size_,
|
||||
array(0, tangents[0].dtype()),
|
||||
"constant",
|
||||
stream())};
|
||||
}
|
||||
|
||||
@@ -3740,7 +3741,8 @@ std::vector<array> BlockMaskedMM::vjp(
|
||||
|
||||
// Pad if needed
|
||||
if ((align_Y != 0) || (align_X != 0)) {
|
||||
r = pad(r, {-2, -1}, {0, 0}, {align_Y, align_X}, zero, stream());
|
||||
r = pad(
|
||||
r, {-2, -1}, {0, 0}, {align_Y, align_X}, zero, "constant", stream());
|
||||
}
|
||||
|
||||
// Reshape
|
||||
|
||||
Reference in New Issue
Block a user