mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Restyled pooling
This commit is contained in:
@@ -451,12 +451,10 @@ class AdaptiveAvgPool1d(_AdaptivePool):
|
|||||||
kernel_L = L // output_L
|
kernel_L = L // output_L
|
||||||
|
|
||||||
if L % output_L == 0:
|
if L % output_L == 0:
|
||||||
# Efficient path for exact division
|
|
||||||
new_shape = batch_dims + [output_L, kernel_L, C]
|
new_shape = batch_dims + [output_L, kernel_L, C]
|
||||||
x_reshaped = x.reshape(new_shape)
|
x_reshaped = x.reshape(new_shape)
|
||||||
return mx.mean(x_reshaped, axis=-2)
|
return mx.mean(x_reshaped, axis=-2)
|
||||||
else:
|
else:
|
||||||
# Manual indexing for non-exact division
|
|
||||||
stride_L = (L - kernel_L) // (output_L - 1) if output_L > 1 else 1
|
stride_L = (L - kernel_L) // (output_L - 1) if output_L > 1 else 1
|
||||||
|
|
||||||
values = []
|
values = []
|
||||||
|
|||||||
Reference in New Issue
Block a user