21    if (std::holds_alternative<std::monostate>(s)) {
 
   22      throw std::runtime_error(
 
   23          "[StreamContext] Invalid argument, please specify a stream or device.");
 
 
 
   40  inline void print(std::ostream& os, 
bool val);
 
   41  inline void print(std::ostream& os, int16_t val);
 
   42  inline void print(std::ostream& os, uint16_t val);
 
   43  inline void print(std::ostream& os, int32_t val);
 
   44  inline void print(std::ostream& os, uint32_t val);
 
   45  inline void print(std::ostream& os, int64_t val);
 
   46  inline void print(std::ostream& os, uint64_t val);
 
   49  inline void print(std::ostream& os, 
float val);
 
   50  inline void print(std::ostream& os, 
double val);
 
 
   86    const std::string& msg_prefix = 
"");
 
   93std::ostream& 
operator<<(std::ostream& os, 
const std::vector<int>& v);
 
   94std::ostream& 
operator<<(std::ostream& os, 
const std::vector<int64_t>& v);
 
   96  return os << v.real() << (v.imag() >= 0 ? 
"+" : 
"") << v.imag() << 
"j";
 
 
   99  return os << static_cast<float>(v);
 
 
  102  return os << static_cast<float>(v);
 
 
  106  return ((n & (n - 1)) == 0) && n != 0;
 
 
  113  return pow(2, std::ceil(std::log2(n)));
 
 
  118int get_var(
const char* name, 
int default_value);
 
  121  static int bfs_max_width_ = 
get_var(
"MLX_BFS_MAX_WIDTH", 20);
 
  122  return bfs_max_width_;
 
 
  126  static int max_ops_per_buffer_ =
 
  127      get_var(
"MLX_MAX_OPS_PER_BUFFER", default_value);
 
  128  return max_ops_per_buffer_;
 
 
  132  static int max_mb_per_buffer_ =
 
  133      get_var(
"MLX_MAX_MB_PER_BUFFER", default_value);
 
  134  return max_mb_per_buffer_;
 
 
 
Dtype dtype() const
Get the arrays data type.
Definition array.h:131
 
array operator<<(const array &a, const array &b)
 
int get_var(const char *name, int default_value)
 
int max_ops_per_buffer(int default_value)
Definition utils.h:125
 
int bfs_max_width()
Definition utils.h:120
 
bool metal_fast_synch()
Definition utils.h:137
 
int max_mb_per_buffer(int default_value)
Definition utils.h:131
 
const Device & default_device()
 
int normalize_axis_index(int axis, int ndim, const std::string &msg_prefix="")
Returns the axis normalized to be in the range [0, ndim).
 
void set_default_device(const Device &d)
 
Stream to_stream(StreamOrDevice s)
 
Dtype promote_types(const Dtype &t1, const Dtype &t2)
 
int next_power_of_2(int n)
Definition utils.h:109
 
std::vector< ShapeElem > Shape
Definition array.h:21
 
Dtype result_type(const array &a, const array &b)
The type from promoting the arrays' types with one another.
Definition utils.h:70
 
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:15
 
Stream default_stream(Device d)
Get the default stream for the given device.
 
struct _MLX_BFloat16 bfloat16_t
Definition half_types.h:34
 
bool is_power_of_2(int n)
Definition utils.h:105
 
void abort_with_exception(const std::exception &error)
Print the exception and then abort.
 
Shape broadcast_shapes(const Shape &s1, const Shape &s2)
 
void set_default_stream(Stream s)
Make the stream the default for its device.
 
struct _MLX_Float16 float16_t
Definition half_types.h:17
 
PrintFormatter & get_global_formatter()
 
Kind
Definition dtype.h:31
 
StreamContext(StreamOrDevice s)
Definition utils.h:20
 
~StreamContext()
Definition utils.h:30
 
double min
Definition utils.h:65
 
Dtype dtype
Definition utils.h:64
 
double max
Definition utils.h:66