diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 75159335f4..61a44c8120 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -293,26 +293,35 @@ array reshape( for (int i = 0; i < shape.size(); ++i) { if (shape[i] == -1) { if (infer_idx >= 0) { - throw std::invalid_argument("Reshape can only infer one dimension."); + throw std::invalid_argument( + "[reshape] Reshape can only infer one dimension."); } infer_idx = i; } else { size *= shape[i]; } } + + // Infer the shape if (size > 0) { auto q_and_r = std::ldiv(a.size(), size); if (infer_idx >= 0) { shape[infer_idx] = q_and_r.quot; size *= q_and_r.quot; } + } else if (infer_idx >= 0) { + throw std::invalid_argument( + "[reshape] Cannot infer the shape of an empty array"); } + + // Check the the reshaping is valid if (a.size() != size) { std::ostringstream msg; - msg << "Cannot reshape array of size " << a.size() << " into shape " - << shape << "."; + msg << "[reshape] Cannot reshape array of size " << a.size() + << " into shape " << shape << "."; throw std::invalid_argument(msg.str()); } + return array( shape, a.dtype(), std::make_unique(to_stream(s), shape), {a}); }