MLX
Loading...
Searching...
No Matches
Classes | Namespaces | Typedefs | Functions | Variables
utils.h File Reference
#include <variant>
#include "array.h"
#include "device.h"
#include "dtype.h"
#include "stream.h"

Go to the source code of this file.

Classes

struct  mlx::core::StreamContext
 
struct  mlx::core::PrintFormatter
 

Namespaces

namespace  mlx
 
namespace  mlx::core
 

Typedefs

using mlx::core::StreamOrDevice = std::variant<std::monostate, Stream, Device>
 

Functions

Stream mlx::core::to_stream (StreamOrDevice s)
 
Dtype mlx::core::result_type (const array &a, const array &b)
 The type from promoting the arrays' types with one another.
 
Dtype mlx::core::result_type (const array &a, const array &b, const array &c)
 
Dtype mlx::core::result_type (const std::vector< array > &arrays)
 
std::vector< int > mlx::core::broadcast_shapes (const std::vector< int > &s1, const std::vector< int > &s2)
 
bool mlx::core::is_same_shape (const std::vector< array > &arrays)
 
template<typename T >
int mlx::core::check_shape_dim (const T dim)
 Returns the shape dimension if it is within the allowed range.
 
int mlx::core::normalize_axis (int axis, int ndim)
 Returns the axis normalized to be in the range [0, ndim).
 
std::ostream & mlx::core::operator<< (std::ostream &os, const Device &d)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const Stream &s)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const Dtype &d)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const Dtype::Kind &k)
 
std::ostream & mlx::core::operator<< (std::ostream &os, array a)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const std::vector< int > &v)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const std::vector< size_t > &v)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const std::vector< int64_t > &v)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const complex64_t &v)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const float16_t &v)
 
std::ostream & mlx::core::operator<< (std::ostream &os, const bfloat16_t &v)
 
bool mlx::core::is_power_of_2 (int n)
 
int mlx::core::next_power_of_2 (int n)
 

Variables

PrintFormatter mlx::core::global_formatter