2024-04-17 21:16:02 +08:00
|
|
|
// Copyright © 2024 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include <cstdint>
#include <memory>
#include <stdexcept>

#include "mlx/stream.h"
|
|
|
|
|
|
|
|
namespace mlx::core {
|
|
|
|
|
|
|
|
class Event {
|
|
|
|
public:
|
2024-06-13 13:06:49 +08:00
|
|
|
Event() = default;
|
2024-04-17 21:16:02 +08:00
|
|
|
|
|
|
|
Event(const Stream& steam);
|
|
|
|
|
2024-06-13 13:06:49 +08:00
|
|
|
// Wait for the event to be signaled at its current value
|
2024-04-17 21:16:02 +08:00
|
|
|
void wait();
|
|
|
|
|
|
|
|
// Signal the event at its current value
|
|
|
|
void signal();
|
|
|
|
|
2024-10-08 10:13:50 +08:00
|
|
|
// Check if the event has been signaled at its current value
|
|
|
|
bool is_signaled() const;
|
|
|
|
|
2024-04-17 21:16:02 +08:00
|
|
|
// Check if the event is valid
|
2024-06-13 13:06:49 +08:00
|
|
|
bool valid() const {
|
2024-04-17 21:16:02 +08:00
|
|
|
return event_ != nullptr;
|
2024-06-13 13:06:49 +08:00
|
|
|
}
|
2024-04-17 21:16:02 +08:00
|
|
|
|
2024-06-13 13:06:49 +08:00
|
|
|
uint64_t value() const {
|
2024-04-17 21:16:02 +08:00
|
|
|
return value_;
|
2024-06-13 13:06:49 +08:00
|
|
|
}
|
2024-04-17 21:16:02 +08:00
|
|
|
|
|
|
|
void set_value(uint64_t v) {
|
|
|
|
value_ = v;
|
2024-06-13 13:06:49 +08:00
|
|
|
}
|
2024-04-17 21:16:02 +08:00
|
|
|
|
2024-06-13 13:06:49 +08:00
|
|
|
const Stream& stream() const {
|
2024-04-17 21:16:02 +08:00
|
|
|
if (!valid()) {
|
|
|
|
throw std::runtime_error(
|
|
|
|
"[Event::stream] Cannot access stream on invalid event.");
|
|
|
|
}
|
|
|
|
return stream_;
|
2024-06-13 13:06:49 +08:00
|
|
|
}
|
2024-04-17 21:16:02 +08:00
|
|
|
|
2024-06-13 13:06:49 +08:00
|
|
|
const std::shared_ptr<void>& raw_event() const {
|
2024-04-17 21:16:02 +08:00
|
|
|
return event_;
|
2024-06-13 13:06:49 +08:00
|
|
|
}
|
2024-04-17 21:16:02 +08:00
|
|
|
|
|
|
|
private:
|
|
|
|
// Default constructed stream should never be used
|
|
|
|
// since the event is not yet valid
|
|
|
|
Stream stream_{0, Device::cpu};
|
2024-06-13 13:06:49 +08:00
|
|
|
std::shared_ptr<void> event_;
|
2024-04-17 21:16:02 +08:00
|
|
|
uint64_t value_{0};
|
|
|
|
};
|
|
|
|
|
|
|
|
} // namespace mlx::core
|