mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 19:11:17 +08:00
Update smooth_l1_loss in losses.py (#1974)
According the definition of smooth_l1_loss, the line diff = predictions - targets Should be updated to diff = mx.abs(predictions - targets) After the modification, the result is consistent with PyTorch smooth_l1_loss
This commit is contained in:
parent
f90206ad74
commit
95e335db7b
@ -373,7 +373,7 @@ def smooth_l1_loss(
|
||||
f"targets shape {targets.shape}."
|
||||
)
|
||||
|
||||
diff = predictions - targets
|
||||
diff = mx.abs(predictions - targets)
|
||||
loss = mx.where(
|
||||
diff < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user