21 if (std::holds_alternative<std::monostate>(s)) {
22 throw std::runtime_error(
23 "[StreamContext] Invalid argument, please specify a stream or device.");
40 inline void print(std::ostream& os,
bool val);
41 inline void print(std::ostream& os, int16_t val);
42 inline void print(std::ostream& os, uint16_t val);
43 inline void print(std::ostream& os, int32_t val);
44 inline void print(std::ostream& os, uint32_t val);
45 inline void print(std::ostream& os, int64_t val);
46 inline void print(std::ostream& os, uint64_t val);
49 inline void print(std::ostream& os,
float val);
85 const std::string& msg_prefix =
"");
92std::ostream&
operator<<(std::ostream& os,
const std::vector<int>& v);
93std::ostream&
operator<<(std::ostream& os,
const std::vector<int64_t>& v);
95 return os << v.real() << (v.imag() >= 0 ?
"+" :
"") << v.imag() <<
"j";
98 return os << static_cast<float>(v);
101 return os << static_cast<float>(v);
105 return ((n & (n - 1)) == 0) && n != 0;
112 return pow(2, std::ceil(std::log2(n)));
117int get_var(
const char* name,
int default_value);
120 static int bfs_max_width_ =
get_var(
"MLX_BFS_MAX_WIDTH", 20);
121 return bfs_max_width_;
125 static int max_ops_per_buffer_ =
get_var(
"MLX_MAX_OPS_PER_BUFFER", 10);
126 return max_ops_per_buffer_;
Dtype dtype() const
Get the arrays data type.
Definition array.h:131
array operator<<(const array &a, const array &b)
int get_var(const char *name, int default_value)
int bfs_max_width()
Definition utils.h:119
int max_ops_per_buffer()
Definition utils.h:124
bool metal_fast_synch()
Definition utils.h:129
const Device & default_device()
int normalize_axis_index(int axis, int ndim, const std::string &msg_prefix="")
Returns the axis normalized to be in the range [0, ndim).
void set_default_device(const Device &d)
Stream to_stream(StreamOrDevice s)
Dtype promote_types(const Dtype &t1, const Dtype &t2)
int next_power_of_2(int n)
Definition utils.h:108
std::vector< ShapeElem > Shape
Definition array.h:21
Dtype result_type(const array &a, const array &b)
The type from promoting the arrays' types with one another.
Definition utils.h:69
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:15
Stream default_stream(Device d)
Get the default stream for the given device.
struct _MLX_BFloat16 bfloat16_t
Definition half_types.h:34
bool is_power_of_2(int n)
Definition utils.h:104
void abort_with_exception(const std::exception &error)
Print the exception and then abort.
Shape broadcast_shapes(const Shape &s1, const Shape &s2)
void set_default_stream(Stream s)
Make the stream the default for its device.
struct _MLX_Float16 float16_t
Definition half_types.h:17
PrintFormatter & get_global_formatter()
Kind
Definition dtype.h:30
StreamContext(StreamOrDevice s)
Definition utils.h:20
~StreamContext()
Definition utils.h:30
Dtype dtype
Definition utils.h:63
float min
Definition utils.h:64
float max
Definition utils.h:65