Remove redundant assert in losses.py (#281)

This commit is contained in:
Zach Schillaci 2023-12-24 11:39:08 -05:00 committed by GitHub
parent 7365d142a3
commit 22fee5a383
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -133,10 +133,6 @@ def mse_loss(
f"targets shape {targets.shape}."
)
assert (
predictions.shape == targets.shape
), f"Shape of predictions {predictions.shape} and targets {targets.shape} must match"
loss = mx.square(predictions - targets)
return _reduce(loss, reduction)