mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-27 11:31:21 +08:00
Update smooth_l1_loss in losses.py (#1974)
According the definition of smooth_l1_loss, the line diff = predictions - targets Should be updated to diff = mx.abs(predictions - targets) After the modification, the result is consistent with PyTorch smooth_l1_loss
This commit is contained in:
parent
f90206ad74
commit
95e335db7b
@ -373,7 +373,7 @@ def smooth_l1_loss(
|
|||||||
f"targets shape {targets.shape}."
|
f"targets shape {targets.shape}."
|
||||||
)
|
)
|
||||||
|
|
||||||
diff = predictions - targets
|
diff = mx.abs(predictions - targets)
|
||||||
loss = mx.where(
|
loss = mx.where(
|
||||||
diff < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta
|
diff < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user