mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Adds device context manager (#679)
This commit is contained in:
10
mlx/ops.cpp
10
mlx/ops.cpp
@@ -59,16 +59,6 @@ Dtype at_least_float(const Dtype& d) {
|
||||
|
||||
} // namespace
|
||||
|
||||
Stream to_stream(StreamOrDevice s) {
|
||||
if (std::holds_alternative<std::monostate>(s)) {
|
||||
return default_stream(default_device());
|
||||
} else if (std::holds_alternative<Device>(s)) {
|
||||
return default_stream(std::get<Device>(s));
|
||||
} else {
|
||||
return std::get<Stream>(s);
|
||||
}
|
||||
}
|
||||
|
||||
array arange(
|
||||
double start,
|
||||
double stop,
|
||||
|
@@ -3,18 +3,14 @@
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <variant>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/stream.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
|
||||
|
||||
Stream to_stream(StreamOrDevice s);
|
||||
|
||||
/** Creation operations */
|
||||
|
||||
/**
|
||||
|
@@ -7,6 +7,16 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
Stream to_stream(StreamOrDevice s) {
|
||||
if (std::holds_alternative<std::monostate>(s)) {
|
||||
return default_stream(default_device());
|
||||
} else if (std::holds_alternative<Device>(s)) {
|
||||
return default_stream(std::get<Device>(s));
|
||||
} else {
|
||||
return std::get<Stream>(s);
|
||||
}
|
||||
}
|
||||
|
||||
void PrintFormatter::print(std::ostream& os, bool val) {
|
||||
if (capitalize_bool) {
|
||||
os << (val ? "True" : "False");
|
||||
|
26
mlx/utils.h
26
mlx/utils.h
@@ -2,6 +2,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <variant>
|
||||
|
||||
#include "array.h"
|
||||
#include "device.h"
|
||||
#include "dtype.h"
|
||||
@@ -9,6 +11,30 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
|
||||
Stream to_stream(StreamOrDevice s);
|
||||
|
||||
struct StreamContext {
|
||||
public:
|
||||
StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) {
|
||||
if (std::holds_alternative<std::monostate>(s)) {
|
||||
throw std::runtime_error(
|
||||
"[StreamContext] Invalid argument, please specify a stream or device.");
|
||||
}
|
||||
auto _s = to_stream(s);
|
||||
set_default_device(_s.device);
|
||||
set_default_stream(_s);
|
||||
}
|
||||
|
||||
~StreamContext() {
|
||||
set_default_device(_stream.device);
|
||||
set_default_stream(_stream);
|
||||
}
|
||||
|
||||
private:
|
||||
Stream _stream;
|
||||
};
|
||||
|
||||
struct PrintFormatter {
|
||||
inline void print(std::ostream& os, bool val);
|
||||
inline void print(std::ostream& os, int16_t val);
|
||||
|
Reference in New Issue
Block a user