diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 91743ec04..90eeaa95a 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -727,7 +727,11 @@ void compile_fuse( } }; - if (arr.has_primitive()) { + // This will be the result of the fused operation so it needs + // a) to not be already computed ie have a primitive + // b) that primitive to not be a broadcast since it will unnecessarily + // cast to a contiguous array potentially blowing up memory + if (arr.has_primitive() && !is_broadcast(arr.primitive())) { Stream s = arr.primitive().stream(); recurse(arr, 0, s, arr.shape()); }