Adds device context manager (#679)

This commit is contained in:
Diogo
2024-02-14 17:14:58 -05:00
committed by GitHub
parent ccf1645995
commit 35431a4ac8
15 changed files with 230 additions and 77 deletions

View File

@@ -59,16 +59,6 @@ Dtype at_least_float(const Dtype& d) {
} // namespace
Stream to_stream(StreamOrDevice s) {
if (std::holds_alternative<std::monostate>(s)) {
return default_stream(default_device());
} else if (std::holds_alternative<Device>(s)) {
return default_stream(std::get<Device>(s));
} else {
return std::get<Stream>(s);
}
}
array arange(
double start,
double stop,

View File

@@ -3,18 +3,14 @@
#pragma once
#include <optional>
#include <variant>
#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/stream.h"
#include "mlx/utils.h"
namespace mlx::core {
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
Stream to_stream(StreamOrDevice s);
/** Creation operations */
/**

View File

@@ -7,6 +7,16 @@
namespace mlx::core {
Stream to_stream(StreamOrDevice s) {
if (std::holds_alternative<std::monostate>(s)) {
return default_stream(default_device());
} else if (std::holds_alternative<Device>(s)) {
return default_stream(std::get<Device>(s));
} else {
return std::get<Stream>(s);
}
}
void PrintFormatter::print(std::ostream& os, bool val) {
if (capitalize_bool) {
os << (val ? "True" : "False");

View File

@@ -2,6 +2,8 @@
#pragma once
#include <variant>
#include "array.h"
#include "device.h"
#include "dtype.h"
@@ -9,6 +11,30 @@
namespace mlx::core {
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
Stream to_stream(StreamOrDevice s);
struct StreamContext {
public:
StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) {
if (std::holds_alternative<std::monostate>(s)) {
throw std::runtime_error(
"[StreamContext] Invalid argument, please specify a stream or device.");
}
auto _s = to_stream(s);
set_default_device(_s.device);
set_default_stream(_s);
}
~StreamContext() {
set_default_device(_stream.device);
set_default_stream(_stream);
}
private:
Stream _stream;
};
struct PrintFormatter {
inline void print(std::ostream& os, bool val);
inline void print(std::ostream& os, int16_t val);