21 if (std::holds_alternative<std::monostate>(s)) {
22 throw std::runtime_error(
23 "[StreamContext] Invalid argument, please specify a stream or device.");
40 inline void print(std::ostream& os,
bool val);
41 inline void print(std::ostream& os, int16_t val);
42 inline void print(std::ostream& os, uint16_t val);
43 inline void print(std::ostream& os, int32_t val);
44 inline void print(std::ostream& os, uint32_t val);
45 inline void print(std::ostream& os, int64_t val);
46 inline void print(std::ostream& os, uint64_t val);
49 inline void print(std::ostream& os,
float val);
50 inline void print(std::ostream& os,
double val);
86 const std::string& msg_prefix =
"");
93std::ostream&
operator<<(std::ostream& os,
const std::vector<int>& v);
94std::ostream&
operator<<(std::ostream& os,
const std::vector<int64_t>& v);
96 return os << v.real() << (v.imag() >= 0 ?
"+" :
"") << v.imag() <<
"j";
99 return os << static_cast<float>(v);
102 return os << static_cast<float>(v);
106 return ((n & (n - 1)) == 0) && n != 0;
113 return pow(2, std::ceil(std::log2(n)));
118int get_var(
const char* name,
int default_value);
121 static int bfs_max_width_ =
get_var(
"MLX_BFS_MAX_WIDTH", 20);
122 return bfs_max_width_;
126 static int max_ops_per_buffer_ =
127 get_var(
"MLX_MAX_OPS_PER_BUFFER", default_value);
128 return max_ops_per_buffer_;
132 static int max_mb_per_buffer_ =
133 get_var(
"MLX_MAX_MB_PER_BUFFER", default_value);
134 return max_mb_per_buffer_;
Dtype dtype() const
Get the arrays data type.
Definition array.h:131
array operator<<(const array &a, const array &b)
int get_var(const char *name, int default_value)
int max_ops_per_buffer(int default_value)
Definition utils.h:125
int bfs_max_width()
Definition utils.h:120
bool metal_fast_synch()
Definition utils.h:137
int max_mb_per_buffer(int default_value)
Definition utils.h:131
const Device & default_device()
int normalize_axis_index(int axis, int ndim, const std::string &msg_prefix="")
Returns the axis normalized to be in the range [0, ndim).
void set_default_device(const Device &d)
Stream to_stream(StreamOrDevice s)
Dtype promote_types(const Dtype &t1, const Dtype &t2)
int next_power_of_2(int n)
Definition utils.h:109
std::vector< ShapeElem > Shape
Definition array.h:21
Dtype result_type(const array &a, const array &b)
The type from promoting the arrays' types with one another.
Definition utils.h:70
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:15
Stream default_stream(Device d)
Get the default stream for the given device.
struct _MLX_BFloat16 bfloat16_t
Definition half_types.h:34
bool is_power_of_2(int n)
Definition utils.h:105
void abort_with_exception(const std::exception &error)
Print the exception and then abort.
Shape broadcast_shapes(const Shape &s1, const Shape &s2)
void set_default_stream(Stream s)
Make the stream the default for its device.
struct _MLX_Float16 float16_t
Definition half_types.h:17
PrintFormatter & get_global_formatter()
Kind
Definition dtype.h:31
StreamContext(StreamOrDevice s)
Definition utils.h:20
~StreamContext()
Definition utils.h:30
double min
Definition utils.h:65
Dtype dtype
Definition utils.h:64
double max
Definition utils.h:66