Update smooth_l1_loss in losses.py (#1974)

According the definition of smooth_l1_loss, the line 

diff = predictions - targets

Should be updated to 

diff = mx.abs(predictions - targets)

After the modification, the result is consistent with PyTorch smooth_l1_loss
This commit is contained in:
jiyzhang 2025-03-20 11:19:02 +08:00 committed by GitHub
parent f90206ad74
commit 95e335db7b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -373,7 +373,7 @@ def smooth_l1_loss(
f"targets shape {targets.shape}."
)
diff = predictions - targets
diff = mx.abs(predictions - targets)
loss = mx.where(
diff < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta
)