Remove redundant assert in losses.py

This commit is contained in:
Zachary Schillaci 2023-12-24 11:28:03 -05:00
parent 7365d142a3
commit 185403538a

View File

@ -133,10 +133,6 @@ def mse_loss(
f"targets shape {targets.shape}."
)
assert (
predictions.shape == targets.shape
), f"Shape of predictions {predictions.shape} and targets {targets.shape} must match"
loss = mx.square(predictions - targets)
return _reduce(loss, reduction)