mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	only interrupt during an eval
This commit is contained in:
		@@ -2,6 +2,7 @@
 | 
				
			|||||||
#include <algorithm>
 | 
					#include <algorithm>
 | 
				
			||||||
#include <deque>
 | 
					#include <deque>
 | 
				
			||||||
#include <future>
 | 
					#include <future>
 | 
				
			||||||
 | 
					#include <mutex>
 | 
				
			||||||
#include <numeric>
 | 
					#include <numeric>
 | 
				
			||||||
#include <set>
 | 
					#include <set>
 | 
				
			||||||
#include <sstream>
 | 
					#include <sstream>
 | 
				
			||||||
@@ -35,10 +36,41 @@ class Synchronizer : public Primitive {
 | 
				
			|||||||
  DEFINE_PRINT(Synchronize);
 | 
					  DEFINE_PRINT(Synchronize);
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
std::atomic<bool>& interrupt_flag() {
 | 
					class Interrupt {
 | 
				
			||||||
  static std::atomic<bool> interrupt_{false};
 | 
					 private:
 | 
				
			||||||
 | 
					  static std::mutex mutex_;
 | 
				
			||||||
 | 
					  static bool eval_running_;
 | 
				
			||||||
 | 
					  static bool interrupt_;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  Interrupt() {
 | 
				
			||||||
 | 
					    std::unique_lock lk(mutex_);
 | 
				
			||||||
 | 
					    eval_running_ = true;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static bool interrupt() {
 | 
				
			||||||
 | 
					    std::unique_lock lk(mutex_);
 | 
				
			||||||
 | 
					    if (eval_running_) {
 | 
				
			||||||
 | 
					      interrupt_ = true;
 | 
				
			||||||
 | 
					      return true;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return false;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static bool interrupted() {
 | 
				
			||||||
 | 
					    std::unique_lock lk(mutex_);
 | 
				
			||||||
    return interrupt_;
 | 
					    return interrupt_;
 | 
				
			||||||
}
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  ~Interrupt() {
 | 
				
			||||||
 | 
					    std::unique_lock lk(mutex_);
 | 
				
			||||||
 | 
					    eval_running_ = false;
 | 
				
			||||||
 | 
					    interrupt_ = false;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					std::mutex Interrupt::mutex_{};
 | 
				
			||||||
 | 
					bool Interrupt::eval_running_ = false;
 | 
				
			||||||
 | 
					bool Interrupt::interrupt_ = false;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Initialize the static tracing members from transforms_impl.h
 | 
					// Initialize the static tracing members from transforms_impl.h
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
@@ -50,6 +82,8 @@ std::vector<char> detail::InTracing::trace_stack{};
 | 
				
			|||||||
int detail::RetainGraph::tracing_counter{0};
 | 
					int detail::RetainGraph::tracing_counter{0};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
array eval_impl(std::vector<array> outputs, bool async) {
 | 
					array eval_impl(std::vector<array> outputs, bool async) {
 | 
				
			||||||
 | 
					  Interrupt interrupt;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  std::deque<array> tape;
 | 
					  std::deque<array> tape;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Make an effort to choose a good output stream
 | 
					  // Make an effort to choose a good output stream
 | 
				
			||||||
@@ -255,8 +289,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
 | 
				
			|||||||
      arr.detach();
 | 
					      arr.detach();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (interrupt_flag()) {
 | 
					    if (Interrupt::interrupted()) {
 | 
				
			||||||
      interrupt_flag() = false;
 | 
					 | 
				
			||||||
      synchronizer.attach_event(Event{stream});
 | 
					      synchronizer.attach_event(Event{stream});
 | 
				
			||||||
      break;
 | 
					      break;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@@ -274,8 +307,8 @@ array eval_impl(std::vector<array> outputs, bool async) {
 | 
				
			|||||||
  return synchronizer;
 | 
					  return synchronizer;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void interrupt_eval() {
 | 
					bool interrupt_eval() {
 | 
				
			||||||
  interrupt_flag() = true;
 | 
					  return Interrupt::interrupt();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void async_eval(std::vector<array> outputs) {
 | 
					void async_eval(std::vector<array> outputs) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -23,9 +23,12 @@ void eval(Arrays&&... outputs) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					/**
 | 
				
			||||||
 * Interrupt an ongoing eval. Leaves the graph in a valid state.
 | 
					 * Interrupt an ongoing eval.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * Leaves the graph in a valid state. Returns true if an ongoing eval was
 | 
				
			||||||
 | 
					 * interrupted and false otherwise.
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
void interrupt_eval();
 | 
					bool interrupt_eval();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					/**
 | 
				
			||||||
 *  Computes the output and vector-Jacobian product (VJP) of a function.
 | 
					 *  Computes the output and vector-Jacobian product (VJP) of a function.
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -97,7 +97,8 @@ TEST_CASE("test interrupt eval") {
 | 
				
			|||||||
    x = x + 1;
 | 
					    x = x + 1;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  std::thread t([x]() { eval(x); });
 | 
					  std::thread t([x]() { eval(x); });
 | 
				
			||||||
  interrupt_eval();
 | 
					  while (!interrupt_eval()) {
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  t.join();
 | 
					  t.join();
 | 
				
			||||||
  // Check that x is not evaluated
 | 
					  // Check that x is not evaluated
 | 
				
			||||||
  CHECK(!x.is_available());
 | 
					  CHECK(!x.is_available());
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user