Adds device context manager (#679)

This commit is contained in:
Diogo
2024-02-14 17:14:58 -05:00
committed by GitHub
parent ccf1645995
commit 35431a4ac8
15 changed files with 230 additions and 77 deletions

View File

@@ -2,6 +2,8 @@
#pragma once
#include <variant>
#include "array.h"
#include "device.h"
#include "dtype.h"
@@ -9,6 +11,30 @@
namespace mlx::core {
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
Stream to_stream(StreamOrDevice s);
struct StreamContext {
public:
StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) {
if (std::holds_alternative<std::monostate>(s)) {
throw std::runtime_error(
"[StreamContext] Invalid argument, please specify a stream or device.");
}
auto _s = to_stream(s);
set_default_device(_s.device);
set_default_stream(_s);
}
~StreamContext() {
set_default_device(_stream.device);
set_default_stream(_stream);
}
private:
Stream _stream;
};
struct PrintFormatter {
inline void print(std::ostream& os, bool val);
inline void print(std::ostream& os, int16_t val);