Restyled pooling

This commit is contained in:
Vincent Amato
2025-08-11 23:15:38 -04:00
parent 652a143b64
commit c59b46a488

View File

@@ -451,12 +451,10 @@ class AdaptiveAvgPool1d(_AdaptivePool):
kernel_L = L // output_L
if L % output_L == 0:
# Efficient path for exact division
new_shape = batch_dims + [output_L, kernel_L, C]
x_reshaped = x.reshape(new_shape)
return mx.mean(x_reshaped, axis=-2)
else:
# Manual indexing for non-exact division
stride_L = (L - kernel_L) // (output_L - 1) if output_L > 1 else 1
values = []