MLX
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
#include <cmath>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <variant>

#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/dtype.h"
#include "mlx/stream.h"
11
12namespace mlx::core {
13
// A Stream, a Device, or nothing (std::monostate). API functions take this as
// `StreamOrDevice s = {}`, so the monostate alternative means "not specified";
// `to_stream` converts a StreamOrDevice into a concrete Stream.
14using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
16
18 public:
20 if (std::holds_alternative<std::monostate>(s)) {
21 throw std::runtime_error(
22 "[StreamContext] Invalid argument, please specify a stream or device.");
23 }
24 auto _s = to_stream(s);
25 set_default_device(_s.device);
27 }
28
31 set_default_stream(_stream);
32 }
33
34 private:
35 Stream _stream;
36};
37
39 inline void print(std::ostream& os, bool val);
40 inline void print(std::ostream& os, int16_t val);
41 inline void print(std::ostream& os, uint16_t val);
42 inline void print(std::ostream& os, int32_t val);
43 inline void print(std::ostream& os, uint32_t val);
44 inline void print(std::ostream& os, int64_t val);
45 inline void print(std::ostream& os, uint64_t val);
46 inline void print(std::ostream& os, float16_t val);
47 inline void print(std::ostream& os, bfloat16_t val);
48 inline void print(std::ostream& os, float val);
49 inline void print(std::ostream& os, complex64_t val);
50
51 bool capitalize_bool{false};
52};
53
// Shared PrintFormatter instance; only declared here, defined out of line.
54extern PrintFormatter global_formatter;
55
57inline Dtype result_type(const array& a, const array& b) {
58 return promote_types(a.dtype(), b.dtype());
59}
60inline Dtype result_type(const array& a, const array& b, const array& c) {
61 return promote_types(result_type(a, b), c.dtype());
62}
63Dtype result_type(const std::vector<array>& arrays);
64
// Returns the shape produced by broadcasting `s1` against `s2`; defined out of
// line. NOTE(review): presumably follows NumPy-style broadcasting rules and
// fails on incompatible shapes — confirm against the definition.
65Shape broadcast_shapes(const Shape& s1, const Shape& s2);
66
// Returns whether all arrays in `arrays` share one shape; defined out of line.
// NOTE(review): behavior for an empty vector is not visible here — confirm.
67bool is_same_shape(const std::vector<array>& arrays);
68
/**
 * Returns the shape dimension as an `int` if it's within the range of `int`.
 *
 * @tparam T the (usually integral) type the dimension is stored in.
 * @param dim the dimension to check.
 * @return `dim` converted to `int`.
 * @throws std::invalid_argument if `dim` is outside
 *   [std::numeric_limits<int>::min(), std::numeric_limits<int>::max()].
 */
template <typename T>
int check_shape_dim(const T dim) {
  constexpr bool is_signed = std::numeric_limits<T>::is_signed;
  // Widest *standard* integer of matching signedness. (`ssize_t` is POSIX-only
  // and unavailable on e.g. MSVC.)
  using Wide = std::conditional_t<is_signed, std::intmax_t, std::uintmax_t>;
  constexpr Wide hi = static_cast<Wide>(std::numeric_limits<int>::max());

  bool in_range = dim <= hi;
  // Only signed types can fall below INT_MIN; for unsigned `T` the lower-bound
  // comparison is not even instantiated.
  if constexpr (is_signed) {
    constexpr Wide lo = static_cast<Wide>(std::numeric_limits<int>::min());
    in_range = in_range && dim >= lo;
  }

  if (!in_range) {
    throw std::invalid_argument(
        "Shape dimension falls outside supported `int` range.");
  }

  return static_cast<int>(dim);
}
84
// Returns the axis normalized to be in the range [0, ndim); defined out of
// line. NOTE(review): presumably negative axes count from the end and
// out-of-range axes raise — confirm against the definition.
90int normalize_axis(int axis, int ndim);
91
// Stream-insertion operators for core MLX types and for shape / stride / int64
// vectors. They are only declared here and defined out of line. Note that the
// `array` overload takes its argument by value.
92std::ostream& operator<<(std::ostream& os, const Device& d);
93std::ostream& operator<<(std::ostream& os, const Stream& s);
94std::ostream& operator<<(std::ostream& os, const Dtype& d);
95std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
96std::ostream& operator<<(std::ostream& os, array a);
97std::ostream& operator<<(std::ostream& os, const Shape& v);
98std::ostream& operator<<(std::ostream& os, const Strides& v);
99std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
100inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
101 return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
102}
103inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
104 return os << static_cast<float>(v);
105}
106inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
107 return os << static_cast<float>(v);
108}
109
/// Returns true when `n` is a positive integral power of two (1, 2, 4, ...).
/// Zero and all negative values return false.
inline bool is_power_of_2(int n) {
  // Test the sign first. For INT_MIN, `n - 1` is signed overflow (UB); in
  // practice it wraps to INT_MAX, and INT_MIN & INT_MAX == 0 would wrongly
  // report INT_MIN as a power of two.
  return n > 0 && (n & (n - 1)) == 0;
}
113
/// Returns the smallest power of two that is greater than or equal to `n`.
///
/// Non-positive inputs return 0.
///
/// @throws std::overflow_error if that power of two is not representable as an
///   `int` (i.e. `n > 2^30` for a 32-bit `int`).
inline int next_power_of_2(int n) {
  if (n <= 0) {
    return 0;
  }
  // Largest power of two that fits in an int.
  constexpr int kMaxPow2 = 1 << (std::numeric_limits<int>::digits - 1);
  if (n > kMaxPow2) {
    throw std::overflow_error(
        "[next_power_of_2] Result does not fit in an int.");
  }
  // Exact integer doubling. pow(2, ceil(log2(n))) goes through floating point,
  // and its out-of-range double->int conversions are undefined behavior.
  int p = 1;
  while (p < n) {
    p <<= 1;
  }
  return p;
}
120
namespace env {

/// Returns the integer setting `name`, or `default_value` as the fallback.
/// Defined out of line. NOTE(review): presumably read from the process
/// environment; parse/error handling lives in the definition — confirm there.
int get_var(const char* name, int default_value);

/// Setting MLX_BFS_MAX_WIDTH (fallback 20). It is looked up on the first call
/// only and the value is cached in a function-local static afterwards.
inline int bfs_max_width() {
  static const int cached_width = get_var("MLX_BFS_MAX_WIDTH", 20);
  return cached_width;
}

/// Setting MLX_MAX_OPS_PER_BUFFER (fallback 10). It is looked up on the first
/// call only and the value is cached in a function-local static afterwards.
inline int max_ops_per_buffer() {
  static const int cached_ops = get_var("MLX_MAX_OPS_PER_BUFFER", 10);
  return cached_ops;
}

} // namespace env
136
137} // namespace mlx::core
Definition array.h:23
Dtype dtype() const
Get the array's data type.
Definition array.h:130
array max(const array &a, bool keepdims, StreamOrDevice s={})
The maximum of all elements of the array.
array min(const array &a, bool keepdims, StreamOrDevice s={})
The minimum of all elements of the array.
array operator<<(const array &a, const array &b)
int get_var(const char *name, int default_value)
int bfs_max_width()
Definition utils.h:125
int max_ops_per_buffer()
Definition utils.h:130
Definition allocator.h:7
int normalize_axis(int axis, int ndim)
Returns the axis normalized to be in the range [0, ndim).
const Device & default_device()
void set_default_device(const Device &d)
Stream to_stream(StreamOrDevice s)
Dtype promote_types(const Dtype &t1, const Dtype &t2)
int next_power_of_2(int n)
Definition utils.h:114
int check_shape_dim(const T dim)
Returns the shape dimension if it's within allowed range.
Definition utils.h:71
Dtype result_type(const array &a, const array &b)
The type from promoting the arrays' types with one another.
Definition utils.h:57
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:14
std::vector< int32_t > Shape
Definition array.h:20
Stream default_stream(Device d)
Get the default stream for the given device.
std::vector< size_t > Strides
Definition array.h:21
bool is_same_shape(const std::vector< array > &arrays)
bool is_power_of_2(int n)
Definition utils.h:110
Shape broadcast_shapes(const Shape &s1, const Shape &s2)
void set_default_stream(Stream s)
Make the stream the default for its device.
PrintFormatter global_formatter
Definition bf16.h:48
Definition fp16.h:21
Definition device.h:7
Definition dtype.h:13
Kind
Definition dtype.h:30
Definition utils.h:38
void print(std::ostream &os, uint32_t val)
void print(std::ostream &os, float val)
void print(std::ostream &os, bool val)
void print(std::ostream &os, int16_t val)
void print(std::ostream &os, uint16_t val)
void print(std::ostream &os, complex64_t val)
void print(std::ostream &os, int64_t val)
void print(std::ostream &os, float16_t val)
void print(std::ostream &os, uint64_t val)
void print(std::ostream &os, int32_t val)
bool capitalize_bool
Definition utils.h:51
void print(std::ostream &os, bfloat16_t val)
Definition utils.h:17
StreamContext(StreamOrDevice s)
Definition utils.h:19
~StreamContext()
Definition utils.h:29
Definition stream.h:9
Device device
Definition stream.h:11
Definition complex.h:34