mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Adds device context manager (#679)
This commit is contained in:
26
mlx/utils.h
26
mlx/utils.h
@@ -2,6 +2,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <variant>
|
||||
|
||||
#include "array.h"
|
||||
#include "device.h"
|
||||
#include "dtype.h"
|
||||
@@ -9,6 +11,30 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
|
||||
Stream to_stream(StreamOrDevice s);
|
||||
|
||||
struct StreamContext {
|
||||
public:
|
||||
StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) {
|
||||
if (std::holds_alternative<std::monostate>(s)) {
|
||||
throw std::runtime_error(
|
||||
"[StreamContext] Invalid argument, please specify a stream or device.");
|
||||
}
|
||||
auto _s = to_stream(s);
|
||||
set_default_device(_s.device);
|
||||
set_default_stream(_s);
|
||||
}
|
||||
|
||||
~StreamContext() {
|
||||
set_default_device(_stream.device);
|
||||
set_default_stream(_stream);
|
||||
}
|
||||
|
||||
private:
|
||||
Stream _stream;
|
||||
};
|
||||
|
||||
struct PrintFormatter {
|
||||
inline void print(std::ostream& os, bool val);
|
||||
inline void print(std::ostream& os, int16_t val);
|
||||
|
||||
Reference in New Issue
Block a user