MLX
 
Loading...
Searching...
No Matches
utils.h File Reference
#include <exception>
#include <variant>
#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/dtype.h"
#include "mlx/stream.h"

Go to the source code of this file.

Classes

struct  mlx::core::StreamContext
 
struct  mlx::core::PrintFormatter
 
struct  mlx::core::finfo
 Holds information about floating-point types. More...
 

Namespaces

namespace  mlx
 
namespace  mlx::core
 
namespace  mlx::core::env
 

Typedefs

using mlx::core::StreamOrDevice = std::variant<std::monostate, Stream, Device>
 

Functions

Stream mlx::core::to_stream (StreamOrDevice s)
 
PrintFormatter & mlx::core::get_global_formatter ()
 
void mlx::core::abort_with_exception (const std::exception &error)
 Print the exception and then abort.
 
Dtype mlx::core::result_type (const array &a, const array &b)
 The type from promoting the arrays' types with one another.
 
Dtype mlx::core::result_type (const array &a, const array &b, const array &c)
 
Dtype mlx::core::result_type (const std::vector< array > &arrays)
 
Shape mlx::core::broadcast_shapes (const Shape &s1, const Shape &s2)
 
int mlx::core::normalize_axis_index (int axis, int ndim, const std::string &msg_prefix="")
 Returns the axis normalized to be in the range [0, ndim).
 
std::ostream & mlx::core::operator<< (std::ostream &os, const Device &d)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const Stream &s)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const Dtype &d)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const Dtype::Kind &k)
 
std::ostream & mlx::core::operator<< (std::ostream &os, array a)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const std::vector< int > &v)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const std::vector< int64_t > &v)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const complex64_t &v)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const float16_t &v)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const bfloat16_t &v)
 
bool mlx::core::is_power_of_2 (int n)
 
int mlx::core::next_power_of_2 (int n)
 
int mlx::core::env::get_var (const char *name, int default_value)
 
int mlx::core::env::bfs_max_width ()
 
int mlx::core::env::max_ops_per_buffer (int default_value)
 
int mlx::core::env::max_mb_per_buffer (int default_value)
 
bool mlx::core::env::metal_fast_synch ()