From 1074674e32c0f2f873193c7e7f36eb82551aef5f Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 6 Mar 2024 15:39:00 -0800 Subject: [PATCH] Add a maximum graph depth (#797) * add a maximum graph depth * remember how to use C++ --- mlx/array.h | 4 ++-- mlx/transforms.cpp | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/mlx/array.h b/mlx/array.h index fe01cbfd7..f73dd66ff 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -274,7 +274,7 @@ class array { }; /** The depth of the array in the graph. Evaluated arrays have depth 0. */ - uint16_t graph_depth() const { + uint32_t graph_depth() const { return array_desc_->depth; } @@ -389,7 +389,7 @@ class array { uint32_t position{0}; // The depth of the array in the graph. - uint16_t depth{0}; + uint32_t depth{0}; explicit ArrayDesc(const std::vector& shape, Dtype dtype); diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 368c6cff4..1ba403ea1 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -17,6 +17,9 @@ namespace mlx::core { +// Maximum allowed graph depth for eval +constexpr uint32_t max_graph_depth = 100'000; + /* This class is only meant to be used in eval * for synchronizing with the main thread. */ class Synchronizer : public Primitive { @@ -116,6 +119,11 @@ void eval(const std::vector& outputs) { } }; + if (synchronizer.graph_depth() > max_graph_depth) { + throw std::runtime_error( + "[eval] Graph depth exceeded maximum allowed limit." + " Try evaluating the graph more frequently."); + } recurse(synchronizer, false); uintptr_t synch_id = synchronizer.primitive_id(); deps.insert({synch_id, std::shared_future{}});