20 if (std::holds_alternative<std::monostate>(s)) {
21 throw std::runtime_error(
22 "[StreamContext] Invalid argument, please specify a stream or device.");
39 inline void print(std::ostream& os,
bool val);
40 inline void print(std::ostream& os, int16_t val);
41 inline void print(std::ostream& os, uint16_t val);
42 inline void print(std::ostream& os, int32_t val);
43 inline void print(std::ostream& os, uint32_t val);
44 inline void print(std::ostream& os, int64_t val);
45 inline void print(std::ostream& os, uint64_t val);
48 inline void print(std::ostream& os,
float val);
72 constexpr bool is_signed = std::numeric_limits<T>::is_signed;
73 using U = std::conditional_t<is_signed, ssize_t, size_t>;
74 constexpr U
min =
static_cast<U
>(std::numeric_limits<int>::min());
75 constexpr U
max =
static_cast<U
>(std::numeric_limits<int>::max());
77 if ((is_signed && dim <
min) || dim >
max) {
78 throw std::invalid_argument(
79 "Shape dimension falls outside supported `int` range.");
82 return static_cast<int>(dim);
99std::ostream&
operator<<(std::ostream& os,
const std::vector<int64_t>& v);
101 return os << v.real() << (v.imag() >= 0 ?
"+" :
"") << v.imag() <<
"j";
104 return os << static_cast<float>(v);
107 return os << static_cast<float>(v);
111 return ((n & (n - 1)) == 0) && n != 0;
118 return pow(2, std::ceil(std::log2(n)));
123int get_var(
const char* name,
int default_value);
126 static int bfs_max_width_ =
get_var(
"MLX_BFS_MAX_WIDTH", 20);
127 return bfs_max_width_;
131 static int max_ops_per_buffer_ =
get_var(
"MLX_MAX_OPS_PER_BUFFER", 10);
132 return max_ops_per_buffer_;
Dtype dtype() const
Get the array's data type.
Definition array.h:130
array max(const array &a, bool keepdims, StreamOrDevice s={})
The maximum of all elements of the array.
array min(const array &a, bool keepdims, StreamOrDevice s={})
The minimum of all elements of the array.
array operator<<(const array &a, const array &b)
int get_var(const char *name, int default_value)
int bfs_max_width()
Definition utils.h:125
int max_ops_per_buffer()
Definition utils.h:130
int normalize_axis(int axis, int ndim)
Returns the axis normalized to be in the range [0, ndim).
const Device & default_device()
void set_default_device(const Device &d)
Stream to_stream(StreamOrDevice s)
Dtype promote_types(const Dtype &t1, const Dtype &t2)
int next_power_of_2(int n)
Definition utils.h:114
int check_shape_dim(const T dim)
Returns the shape dimension as an `int` if it lies within the range of `int`; throws `std::invalid_argument` otherwise.
Definition utils.h:71
Dtype result_type(const array &a, const array &b)
The type that results from promoting the two arrays' types with one another.
Definition utils.h:57
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:14
std::vector< int32_t > Shape
Definition array.h:20
Stream default_stream(Device d)
Get the default stream for the given device.
std::vector< size_t > Strides
Definition array.h:21
bool is_same_shape(const std::vector< array > &arrays)
bool is_power_of_2(int n)
Definition utils.h:110
Shape broadcast_shapes(const Shape &s1, const Shape &s2)
void set_default_stream(Stream s)
Make the stream the default for its device.
PrintFormatter global_formatter
Kind
Definition dtype.h:30
StreamContext(StreamOrDevice s)
Definition utils.h:19
~StreamContext()
Definition utils.h:29
Device device
Definition stream.h:11