MLX
 
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
#include <cmath>
#include <exception>
#include <stdexcept>
#include <variant>

#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/dtype.h"
#include "mlx/stream.h"
12
13namespace mlx::core {
14
15using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
18
20 public:
22 if (std::holds_alternative<std::monostate>(s)) {
23 throw std::runtime_error(
24 "[StreamContext] Invalid argument, please specify a stream or device.");
25 }
26 auto _s = to_stream(s);
27 set_default_device(_s.device);
29 }
30
32 set_default_device(_stream.device);
33 set_default_stream(_stream);
34 }
35
36 private:
37 Stream _stream;
38};
39
41 inline void print(std::ostream& os, bool val);
42 inline void print(std::ostream& os, int16_t val);
43 inline void print(std::ostream& os, uint16_t val);
44 inline void print(std::ostream& os, int32_t val);
45 inline void print(std::ostream& os, uint32_t val);
46 inline void print(std::ostream& os, int64_t val);
47 inline void print(std::ostream& os, uint64_t val);
48 inline void print(std::ostream& os, float16_t val);
49 inline void print(std::ostream& os, bfloat16_t val);
50 inline void print(std::ostream& os, float val);
51 inline void print(std::ostream& os, double val);
52 inline void print(std::ostream& os, complex64_t val);
53
54 bool capitalize_bool{false};
55};
56
58
/// Prints `error` and then aborts the program.
void abort_with_exception(const std::exception& error);
61
63struct finfo {
64 explicit finfo(Dtype dtype);
66 double min;
67 double max;
68};
69
71inline Dtype result_type(const array& a, const array& b) {
72 return promote_types(a.dtype(), b.dtype());
73}
74inline Dtype result_type(const array& a, const array& b, const array& c) {
75 return promote_types(result_type(a, b), c.dtype());
76}
77Dtype result_type(const std::vector<array>& arrays);
78
79Shape broadcast_shapes(const Shape& s1, const Shape& s2);
80
85 int axis,
86 int ndim,
87 const std::string& msg_prefix = "");
88
89std::ostream& operator<<(std::ostream& os, const Device& d);
90std::ostream& operator<<(std::ostream& os, const Stream& s);
91std::ostream& operator<<(std::ostream& os, const Dtype& d);
92std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
93std::ostream& operator<<(std::ostream& os, array a);
94std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
95std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
96inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
97 return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
98}
99inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
100 return os << static_cast<float>(v);
101}
102inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
103 return os << static_cast<float>(v);
104}
105
/// Returns true iff `n` is a positive power of two (1, 2, 4, ...).
///
/// Non-positive values are never powers of two. Testing `n > 0` first also
/// short-circuits before `n - 1` is evaluated, which would be signed overflow
/// (undefined behaviour) for n == INT_MIN.
inline bool is_power_of_2(int n) {
  return n > 0 && (n & (n - 1)) == 0;
}
109
/// Returns the smallest power of two that is >= `n`.
///
/// Uses integer shifts instead of `pow(2, ceil(log2(n)))`. The floating-point
/// form is undefined for negative `n` (log2 yields NaN, and converting NaN to
/// int is undefined behaviour) and for results that do not fit in an int.
///
/// @param n  Input value.
/// @return   `n` itself when it is already a power of two; 0 when `n <= 0`.
/// @throws std::overflow_error if the result would exceed 2^30, the largest
///         power of two representable as an int.
inline int next_power_of_2(int n) {
  constexpr int kLargestPow2 = 1 << 30;
  if (n <= 0) {
    return 0;
  }
  if (n > kLargestPow2) {
    throw std::overflow_error(
        "[next_power_of_2] Result does not fit in an int.");
  }
  // Bounded by the check above: at most 30 shifts, and p never exceeds 2^30.
  int p = 1;
  while (p < n) {
    p <<= 1;
  }
  return p;
}
116
namespace env {

/// Reads the integer environment variable `name`.
/// NOTE(review): presumably returns `default_value` when the variable is unset
/// or unparsable — confirm against the out-of-line definition.
int get_var(const char* name, int default_value);

/// Value of MLX_BFS_MAX_WIDTH (default 20). Read once and cached in a
/// function-local static.
inline int bfs_max_width() {
  static int cached_width = get_var("MLX_BFS_MAX_WIDTH", 20);
  return cached_width;
}

/// Value of MLX_MAX_OPS_PER_BUFFER. Cached on the first call, so the
/// `default_value` passed on any later call has no effect.
inline int max_ops_per_buffer(int default_value) {
  static int cached_max_ops = get_var("MLX_MAX_OPS_PER_BUFFER", default_value);
  return cached_max_ops;
}

/// Value of MLX_MAX_MB_PER_BUFFER. Cached on the first call, so the
/// `default_value` passed on any later call has no effect.
inline int max_mb_per_buffer(int default_value) {
  static int cached_max_mb = get_var("MLX_MAX_MB_PER_BUFFER", default_value);
  return cached_max_mb;
}

/// True when MLX_METAL_FAST_SYNCH is set to a non-zero value (default 0).
/// Read once and cached.
inline bool metal_fast_synch() {
  static bool enabled = get_var("MLX_METAL_FAST_SYNCH", 0) != 0;
  return enabled;
}

} // namespace env
144
145} // namespace mlx::core
Definition array.h:24
Dtype dtype() const
Get the arrays data type.
Definition array.h:131
array operator<<(const array &a, const array &b)
Definition utils.h:117
int get_var(const char *name, int default_value)
int max_ops_per_buffer(int default_value)
Definition utils.h:126
int bfs_max_width()
Definition utils.h:121
bool metal_fast_synch()
Definition utils.h:138
int max_mb_per_buffer(int default_value)
Definition utils.h:132
Definition allocator.h:7
const Device & default_device()
int normalize_axis_index(int axis, int ndim, const std::string &msg_prefix="")
Returns the axis normalized to be in the range [0, ndim).
void set_default_device(const Device &d)
Stream to_stream(StreamOrDevice s)
Dtype promote_types(const Dtype &t1, const Dtype &t2)
int next_power_of_2(int n)
Definition utils.h:110
std::vector< ShapeElem > Shape
Definition array.h:21
Dtype result_type(const array &a, const array &b)
The type from promoting the arrays' types with one another.
Definition utils.h:71
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:15
Stream default_stream(Device d)
Get the default stream for the given device.
struct _MLX_BFloat16 bfloat16_t
Definition half_types.h:34
bool is_power_of_2(int n)
Definition utils.h:106
void abort_with_exception(const std::exception &error)
Print the exception and then abort.
Shape broadcast_shapes(const Shape &s1, const Shape &s2)
void set_default_stream(Stream s)
Make the stream the default for its device.
struct _MLX_Float16 float16_t
Definition half_types.h:17
PrintFormatter & get_global_formatter()
Definition device.h:7
Definition dtype.h:13
Kind
Definition dtype.h:31
Definition utils.h:40
void print(std::ostream &os, uint32_t val)
void print(std::ostream &os, float val)
void print(std::ostream &os, bool val)
void print(std::ostream &os, double val)
void print(std::ostream &os, int16_t val)
void print(std::ostream &os, uint16_t val)
void print(std::ostream &os, complex64_t val)
void print(std::ostream &os, int64_t val)
void print(std::ostream &os, float16_t val)
void print(std::ostream &os, uint64_t val)
void print(std::ostream &os, int32_t val)
bool capitalize_bool
Definition utils.h:54
void print(std::ostream &os, bfloat16_t val)
StreamContext(StreamOrDevice s)
Definition utils.h:21
~StreamContext()
Definition utils.h:31
Definition stream.h:9
Definition complex.h:35
finfo(Dtype dtype)
double min
Definition utils.h:66
Dtype dtype
Definition utils.h:65
double max
Definition utils.h:67