mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Restyled pooling
This commit is contained in:
@@ -451,12 +451,10 @@ class AdaptiveAvgPool1d(_AdaptivePool):
|
||||
kernel_L = L // output_L
|
||||
|
||||
if L % output_L == 0:
|
||||
# Efficient path for exact division
|
||||
new_shape = batch_dims + [output_L, kernel_L, C]
|
||||
x_reshaped = x.reshape(new_shape)
|
||||
return mx.mean(x_reshaped, axis=-2)
|
||||
else:
|
||||
# Manual indexing for non-exact division
|
||||
stride_L = (L - kernel_L) // (output_L - 1) if output_L > 1 else 1
|
||||
|
||||
values = []
|
||||
|
Reference in New Issue
Block a user