- METAL_FUNC stride_t elem_to_loc
+ METAL_FUNC IdxT elem_to_loc
(
uint3 elem ,
@@ -430,7 +432,7 @@ template<typename stride_t >
- constant const stride_t * strides ,
+ constant const StrideT * strides ,
@@ -442,62 +444,62 @@ template<typename stride_t >
-
-◆ elem_to_loc_1()
+
+◆ elem_to_loc_1()
-template<typename stride_t >
+template<typename StrideT , typename IdxT = StrideT>
- METAL_FUNC stride_t elem_to_loc_1
+ METAL_FUNC IdxT elem_to_loc_1
(
uint elem ,
- constant const stride_t & stride )
+ constant const StrideT & stride )
-
-◆ elem_to_loc_2()
+
+◆ elem_to_loc_2()
-template<typename stride_t >
+template<typename StrideT , typename IdxT = StrideT>
- METAL_FUNC stride_t elem_to_loc_2
+ METAL_FUNC IdxT elem_to_loc_2
(
uint2 elem ,
- constant const stride_t strides [2] )
+ constant const StrideT strides [2] )
-
-◆ elem_to_loc_2_nd()
+
+◆ elem_to_loc_2_nd()
-template<typename stride_t >
+template<typename StrideT , typename IdxT = StrideT>
- METAL_FUNC ulong2 elem_to_loc_2_nd
+ METAL_FUNC vec< IdxT, 2 > elem_to_loc_2_nd
(
uint3 elem ,
@@ -509,12 +511,12 @@ template<typename stride_t >
- constant const stride_t * a_strides ,
+ constant const StrideT * a_strides ,
- constant const stride_t * b_strides ,
+ constant const StrideT * b_strides ,
@@ -526,37 +528,39 @@ template<typename stride_t >
-
-◆ elem_to_loc_3()
+
+◆ elem_to_loc_3()
-template<typename stride_t >
+template<typename StrideT , typename IdxT = StrideT>
- METAL_FUNC stride_t elem_to_loc_3
+ METAL_FUNC IdxT elem_to_loc_3
(
uint3 elem ,
- constant const stride_t strides [3] )
+ constant const StrideT strides [3] )
-
-◆ elem_to_loc_3_nd()
+
+◆ elem_to_loc_3_nd()
+
+template<typename IdxT = size_t>
- METAL_FUNC ulong3 elem_to_loc_3_nd
+ METAL_FUNC vec< IdxT, 3 > elem_to_loc_3_nd
(
uint3 elem ,
@@ -600,9 +604,9 @@ template<typename stride_t >
diff --git a/docs/build/html/backend_2metal_2kernels_2utils_8h_source.html b/docs/build/html/backend_2metal_2kernels_2utils_8h_source.html
index ddb1325ce..66c8011c3 100644
--- a/docs/build/html/backend_2metal_2kernels_2utils_8h_source.html
+++ b/docs/build/html/backend_2metal_2kernels_2utils_8h_source.html
@@ -96,497 +96,546 @@ $(function(){ initResizable(false); });
-
-
-
-
-
+
+
+
+
+
-
+
+
+
-
-
-
-
18 static const constant U
max = metal::numeric_limits<U>::max();
-
19 static const constant U
min = metal::numeric_limits<U>::min();
-
20 static const constant U
finite_max = metal::numeric_limits<U>::max();
-
21 static const constant U
finite_min = metal::numeric_limits<U>::min();
-
+
+
+
+
+
+
+
+
24 static const constant U
max = metal::numeric_limits<U>::max();
+
25 static const constant U
min = metal::numeric_limits<U>::min();
+
26 static const constant U
finite_max = metal::numeric_limits<U>::max();
+
27 static const constant U
finite_min = metal::numeric_limits<U>::min();
+
-
-
-
24 #define instantiate_default_limit(type) \
-
-
26 struct Limits<type> { \
-
27 static constexpr constant type max = metal::numeric_limits<type>::max(); \
-
28 static constexpr constant type min = metal::numeric_limits<type>::min(); \
-
29 static constexpr constant type finite_max = \
-
30 metal::numeric_limits<type>::max(); \
-
31 static constexpr constant type finite_min = \
-
32 metal::numeric_limits<type>::min(); \
-
+
+
+
30 #define instantiate_default_limit(type) \
+
+
32 struct Limits<type> { \
+
33 static constexpr constant type max = metal::numeric_limits<type>::max(); \
+
34 static constexpr constant type min = metal::numeric_limits<type>::min(); \
+
35 static constexpr constant type finite_max = \
+
36 metal::numeric_limits<type>::max(); \
+
37 static constexpr constant type finite_min = \
+
38 metal::numeric_limits<type>::min(); \
+
-
-
-
-
-
-
-
-
-
-
-
-
44 #define instantiate_float_limit(type) \
-
-
46 struct Limits<type> { \
-
47 static constexpr constant type max = \
-
48 metal::numeric_limits<type>::infinity(); \
-
49 static constexpr constant type min = \
-
50 -metal::numeric_limits<type>::infinity(); \
-
51 static constexpr constant type finite_max = \
-
52 metal::numeric_limits<type>::max(); \
-
53 static constexpr constant type finite_min = \
-
54 -metal::numeric_limits<type>::max(); \
-
-
-
-
-
-
-
-
-
-
-
63 static constexpr constant
bool max =
true ;
-
64 static constexpr constant
bool min =
false ;
-
+
+
+
+
+
+
+
+
+
+
+
+
50 #define instantiate_float_limit(type) \
+
+
52 struct Limits<type> { \
+
53 static constexpr constant type max = \
+
54 metal::numeric_limits<type>::infinity(); \
+
55 static constexpr constant type min = \
+
56 -metal::numeric_limits<type>::infinity(); \
+
57 static constexpr constant type finite_max = \
+
58 metal::numeric_limits<type>::max(); \
+
59 static constexpr constant type finite_min = \
+
60 -metal::numeric_limits<type>::max(); \
+
+
+
+
+
-
-
-
70 metal::numeric_limits<float>::infinity(),
-
71 metal::numeric_limits<float>::infinity());
-
-
73 -metal::numeric_limits<float>::infinity(),
-
74 -metal::numeric_limits<float>::infinity());
-
+
+
69 static constexpr constant
bool max =
true ;
+
70 static constexpr constant
bool min =
false ;
+
+
+
+
+
+
+
+
76 metal::numeric_limits<float>::infinity(),
+
77 metal::numeric_limits<float>::infinity());
+
+
79 -metal::numeric_limits<float>::infinity(),
+
80 -metal::numeric_limits<float>::infinity());
+
-
-
-
-
81 #define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
-
-
-
86 template <
typename str
id e_t>
-
-
-
-
89 constant
const int * shape,
-
90 constant
const stride_t* strides,
-
-
-
93 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
-
94 loc += (elem % shape[i]) * strides[i];
-
-
-
-
+
+
+
87 #define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
+
+
+
+
92 template <
typename Str
id eT,
typename IdxT = Str
id eT>
+
+
+
+
95 constant
const int * shape,
+
96 constant
const StrideT* strides,
+
+
+
99 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
+
100 loc += (elem % shape[i]) * IdxT(strides[i]);
+
+
+
+
-
-
100 template <
typename str
id e_t>
-
-
-
-
103 constant
const int * shape,
-
104 constant
const stride_t* strides,
-
-
-
107 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
-
108 loc += (elem % shape[i]) * strides[i];
-
-
-
-
+
+
106 template <
typename Str
id eT,
typename IdxT = Str
id eT>
+
+
+
+
109 constant
const int * shape,
+
110 constant
const StrideT* strides,
+
+
+
113 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
+
114 loc += (elem % shape[i]) * IdxT(strides[i]);
+
+
+
+
-
-
-
115 template <
typename str
id e_t>
-
-
-
-
118 constant
const int * shape,
-
119 constant
const stride_t* strides,
-
-
121 stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
-
122 for (
int d = ndim - 3; d >= 0; --d) {
-
123 loc += (elem.z % shape[d]) * strides[d];
-
-
-
-
+
+
+
121 template <
typename Str
id eT,
typename IdxT = Str
id eT>
+
+
+
+
124 constant
const int * shape,
+
125 constant
const StrideT* strides,
+
+
+
128 elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
+
129 for (
int d = ndim - 3; d >= 0; --d) {
+
130 loc += (elem.z % shape[d]) * IdxT(strides[d]);
+
+
+
+
-
-
-
-
132 template <
typename str
id e_t>
-
-
133 METAL_FUNC stride_t
elem_to_loc_1 (uint elem, constant
const stride_t& stride) {
-
134 return elem * stride;
-
+
+
+
+
139 template <
typename Str
id eT,
typename IdxT = Str
id eT>
+
+
+
141 return elem * IdxT(stride);
+
-
-
137 template <
typename str
id e_t>
-
-
-
-
140 return elem.x * strides[1] + elem.y * strides[0];
-
-
-
-
143 template <
typename str
id e_t>
-
+
+
144 template <
typename Str
id eT,
typename IdxT = Str
id eT>
-
-
146 return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
+
145 METAL_FUNC IdxT
elem_to_loc_2 (uint2 elem, constant
const StrideT strides[2]) {
+
146 return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
-
-
-
152 template <
typename str
id e_t>
-
-
-
-
155 constant
const int * shape,
-
156 constant
const stride_t* a_strides,
-
157 constant
const stride_t* b_strides,
-
-
-
160 ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
-
161 ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
-
162 for (
int d = ndim - 3; d >= 0; --d) {
-
163 uint l = elem.z % shape[d];
-
164 loc.x += l * a_strides[d];
-
165 loc.y += l * b_strides[d];
-
-
-
-
+
149 template <
typename Str
id eT,
typename IdxT = Str
id eT>
+
+
150 METAL_FUNC IdxT
elem_to_loc_3 (uint3 elem, constant
const StrideT strides[3]) {
+
151 return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
+
152 elem.z * IdxT(strides[0]);
+
-
-
-
-
-
173 constant
const int * shape,
-
174 constant
const size_t * a_strides,
-
175 constant
const size_t * b_strides,
-
176 constant
const size_t * c_strides,
-
-
-
179 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
-
180 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2],
-
181 elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]};
-
182 for (
int d = ndim - 3; d >= 0; --d) {
-
183 uint l = elem.z % shape[d];
-
184 loc.x += l * a_strides[d];
-
185 loc.y += l * b_strides[d];
-
186 loc.z += l * c_strides[d];
-
-
-
-
+
+
+
+
158 template <
typename Str
id eT,
typename IdxT = Str
id eT>
+
+
+
+
161 constant
const int * shape,
+
162 constant
const StrideT* a_strides,
+
163 constant
const StrideT* b_strides,
+
+
+
+
167 elem.x * IdxT(a_strides[ndim - 1]) +
+
168 IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
+
+
170 elem.x * IdxT(b_strides[ndim - 1]) +
+
171 elem.y * IdxT(b_strides[ndim - 2]))};
+
172 for (
int d = ndim - 3; d >= 0; --d) {
+
173 uint l = elem.z % shape[d];
+
174 loc.x += l * IdxT(a_strides[d]);
+
175 loc.y += l * IdxT(b_strides[d]);
+
+
+
+
-
-
-
-
196 template <
int dim,
typename offset_t =
size_t >
-
-
-
-
-
-
-
-
202 void next (
const constant
int * shape,
const constant
size_t * strides) {
-
-
204 offset += strides[dim - 1];
-
-
206 if (
index >= shape[dim - 1]) {
-
-
-
-
-
+
+
181 template <
typename IdxT =
size_t >
+
+
+
+
184 constant
const int * shape,
+
185 constant
const size_t * a_strides,
+
186 constant
const size_t * b_strides,
+
187 constant
const size_t * c_strides,
+
+
+
190 elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
+
191 elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]),
+
192 elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])};
+
193 for (
int d = ndim - 3; d >= 0; --d) {
+
194 uint l = elem.z % shape[d];
+
195 loc.x += l * IdxT(a_strides[d]);
+
196 loc.y += l * IdxT(b_strides[d]);
+
197 loc.z += l * IdxT(c_strides[d]);
+
+
+
+
-
-
-
213 void next (
int n,
const constant
int * shape,
const constant
size_t * strides) {
-
-
215 offset += n * strides[dim - 1];
-
-
217 if (
index >= shape[dim - 1]) {
-
218 int extra =
index - shape[dim - 1];
-
-
-
-
-
223 next (extra, shape, strides);
-
-
-
+
+
+
+
207 template <
int DIM,
typename OffsetT =
size_t ,
bool General = true>
+
+
+
+
+
+
+
+
+
+
+
216 void next (
const constant
int * shape,
const constant
size_t * strides) {
+
+
+
+
+
+
+
+
+
+
+
-
-
+
-
229 location (offset_t,
const constant
int *,
const constant
size_t *,
int ) {
-
-
-
-
-
-
-
234 template <
typename offset_t>
-
-
-
-
-
-
238 void next (
const constant
int *,
const constant
size_t * strides) {
-
-
-
-
-
-
242 void next (
int n,
const constant
int *,
const constant
size_t * strides) {
-
-
-
-
-
-
-
247 location (offset_t,
const constant
int *,
const constant
size_t *,
int ) {
-
-
-
-
+
229 void next (
int n,
const constant
int * shape,
const constant
size_t * strides) {
+
+
+
+
+
+
+
+
+
238 if (extra >= shape[
dim - 1]) {
+
+
240 extra = extra % shape[
dim - 1];
+
+
+
+
+
+
+
247 next (extra, shape, strides);
+
+
+
-
252 template <
typename offset_t>
-
-
-
254 void next (
const constant
int *,
const constant
size_t *) {}
-
255 void next (
int ,
const constant
int *,
const constant
size_t *) {}
+
+
+
-
-
-
-
259 const constant
int * shape,
-
260 const constant
size_t * strides,
-
-
-
+
257 template <
typename OffsetT>
+
+
+
+
+
+
+
+
+
+
265 void next (
const constant
int * shape,
const constant
size_t * strides) {
+
+
+
+
+
270 offset += OffsetT(strides[0]);
+
+
-
-
-
-
-
-
271 template <
typename T,
typename U>
-
-
-
273 return (N + M - 1) / M;
-
-
-
-
-
-
-
278 float xp1 = 1.0f + x;
-
-
+
+
+
274 void next (
int n,
const constant
int * shape,
const constant
size_t * strides) {
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
290 float xp1 = 1.0f +
static_cast< float > (x);
-
-
-
-
-
+
+
+
+
+
+
288 template <
typename OffsetT>
+
+
+
+
+
+
+
+
294 void next (
const constant
int *,
const constant
size_t * strides) {
+
295 offset += OffsetT(strides[0]);
+
-
-
+
+
298 void next (
int n,
const constant
int *,
const constant
size_t * strides) {
+
299 offset += n * OffsetT(strides[0]);
+
-
-
-
-
-
-
306 return as_type<uint64_t>(
-
-
+
+
-
-
-
-
311 return as_type<int64_t>(
-
-
+
-
-
-
-
-
+
+
+
+
312 template <
typename T,
typename U>
+
+
+
314 return (N + M - 1) / M;
+
-
-
-
-
-
-
+
+
+
+
+
319 float xp1 = 1.0f + x;
+
+
+
+
+
+
+
+
+
-
-
-
-
-
+
+
+
+
331 float xp1 = 1.0f +
static_cast< float > (x);
+
+
+
+
+
+
+
+
+
-
-
-
-
-
+
+
+
+
+
+
347 return as_type<uint64_t>(
+
+
-
-
-
-
-
+
+
+
+
352 return as_type<int64_t>(
+
+
-
-
-
-
-
-
+
+
-
-
-
-
-
-
344 as_type<uint2>(data), as_type<uint2>(filling), delta));
-
+
+
-
-
-
-
-
-
350 as_type<uint2>(data), as_type<uint2>(filling), delta));
-
+
+
-
-
-
-
-
355 static_cast< uint32_t
> (data),
static_cast< uint32_t
> (filling), delta);
-
+
+
-
-
-
-
-
-
-
-
-
-
+
+
-
-
-
-
-
+
+
-
-
-
-
-
+
+
+
+
+
+
385 as_type<uint2>(data), as_type<uint2>(filling), delta));
+
-
-
-
-
-
+
+
+
+
+
+
391 as_type<uint2>(data), as_type<uint2>(filling), delta));
+
-
-
-
-
-
-
+
+
+
+
+
396 static_cast< uint32_t
> (data),
static_cast< uint32_t
> (filling), delta);
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
std::vector< ptrdiff_t > stride_t
Definition pocketfft.h:103
-
-
-
static const constant U max
Definition utils.h:18
-
static const constant U finite_max
Definition utils.h:20
-
static const constant U min
Definition utils.h:19
-
static const constant U finite_min
Definition utils.h:21
+
+
+
+
+
+
+
+
static const constant U max
Definition utils.h:24
+
static const constant U finite_max
Definition utils.h:26
+
static const constant U min
Definition utils.h:25
+
static const constant U finite_min
Definition utils.h:27
+
void next(const constant int *, const constant size_t *strides)
Definition utils.h:294
+
LoopedElemToLoc(int)
Definition utils.h:292
+
OffsetT location()
Definition utils.h:302
+
void next(int n, const constant int *, const constant size_t *strides)
Definition utils.h:298
+
OffsetT location()
Definition utils.h:283
+
int dim
Definition utils.h:259
+
void next(int n, const constant int *shape, const constant size_t *strides)
Definition utils.h:274
+
LoopedElemToLoc(int dim)
Definition utils.h:263
+
void next(const constant int *shape, const constant size_t *strides)
Definition utils.h:265
+
+
void next(const constant int *shape, const constant size_t *strides)
Definition utils.h:216
+
LoopedElemToLoc(int dim)
Definition utils.h:214
+
void next(int n, const constant int *shape, const constant size_t *strides)
Definition utils.h:229
+
LoopedElemToLoc< DIM - 1, OffsetT, General > inner_looper
Definition utils.h:210
+
OffsetT location()
Definition utils.h:252
+
int index
Definition utils.h:212
+
OffsetT offset
Definition utils.h:211
+
int dim
Definition utils.h:209
float imag
Definition complex.h:22
float real
Definition complex.h:21
-
void next(int, const constant int *, const constant size_t *)
Definition utils.h:255
-
offset_t location(offset_t idx, const constant int *shape, const constant size_t *strides, int ndim)
Definition utils.h:257
-
void next(const constant int *, const constant size_t *)
Definition utils.h:254
-
offset_t location(offset_t, const constant int *, const constant size_t *, int)
Definition utils.h:247
-
void next(const constant int *, const constant size_t *strides)
Definition utils.h:238
-
void next(int n, const constant int *, const constant size_t *strides)
Definition utils.h:242
-
-
void next(const constant int *shape, const constant size_t *strides)
Definition utils.h:202
-
offset_t offset
Definition utils.h:199
-
int index
Definition utils.h:200
-
looped_elem_to_loc< dim - 1, offset_t > inner_looper
Definition utils.h:198
-
offset_t location(offset_t, const constant int *, const constant size_t *, int)
Definition utils.h:229
-
void next(int n, const constant int *shape, const constant size_t *strides)
Definition utils.h:213
-template<typename T >
-void mlx::core::set_vector_bytes (CommandEncoder &enc, const std::vector< T > &vec, size_t nelems, int idx)
-
-template<typename T >
-void mlx::core::set_vector_bytes (CommandEncoder &enc, const std::vector< T > &vec, int idx)
-
+std::string mlx::core::type_to_name (const Dtype &t)
+
std::string mlx::core::type_to_name (const array &a)
MTL::Size mlx::core::get_block_dims (int dim0, int dim1, int dim2, int pow2=10)
@@ -131,6 +127,12 @@ Functions
std::string mlx::core::get_primitive_string (Primitive *primitive)
+template<typename T >
+void mlx::core::concatenate (std::string &acc, T first)
+
+template<typename T , typename... Args>
+void mlx::core::concatenate (std::string &acc, T first, Args... args)
+
diff --git a/docs/build/html/backend_2metal_2utils_8h_source.html b/docs/build/html/backend_2metal_2utils_8h_source.html
index 0795ecb30..32b2bae63 100644
--- a/docs/build/html/backend_2metal_2utils_8h_source.html
+++ b/docs/build/html/backend_2metal_2utils_8h_source.html
@@ -101,87 +101,82 @@ $(function(){ initResizable(false); });
-
11 using metal::CommandEncoder;
-
-
-
-
-
-
16 const std::vector<T>& vec,
-
-
-
19 enc->setBytes(vec.data(), nelems *
sizeof (T), idx);
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
26 const std::vector<int>& shape,
+
27 const std::vector<size_t>& strides);
+
+
+
+
+
+
33 const std::vector<int>& shape,
+
34 const std::vector<size_t>& strides,
+
+
+
+
+
38 std::string
string = os.str();
+
39 return NS::String::string(
string .c_str(), NS::UTF8StringEncoding);
+
-
-
-
-
-
-
-
+
+
+
+
+
44 std::ostringstream label;
+
45 label <<
"Stream " << index;
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
42 const std::vector<int>& shape,
-
43 const std::vector<size_t>& strides);
-
-
-
-
-
-
49 const std::vector<int>& shape,
-
50 const std::vector<size_t>& strides,
-
-
-
-
-
54 std::string
string = os.str();
-
55 return NS::String::string(
string .c_str(), NS::UTF8StringEncoding);
-
+
+
+
+
51 MTL::CommandBuffer* command_buffer,
+
+
+
54 std::ostringstream label;
+
55 if (
auto cbuf_label = command_buffer->label(); cbuf_label) {
+
56 label << cbuf_label->utf8String();
+
+
58 primitive.
print (label);
+
+
+
-
-
-
-
-
60 std::ostringstream label;
-
61 label <<
"Stream " << index;
-
-
-
-
-
+
+
+
+
-
-
67 MTL::CommandBuffer* command_buffer,
-
-
-
70 std::ostringstream label;
-
71 if (
auto cbuf_label = command_buffer->label(); cbuf_label) {
-
72 label << cbuf_label->utf8String();
-
-
74 primitive.
print (label);
-
-
-
+
+
+
-
-
-
-
+
+
70 template <
typename T,
typename ... Args>
+
+
+
Definition primitives.h:48
@@ -189,15 +184,15 @@ $(function(){ initResizable(false); });
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2=10)
-
void debug_set_primitive_buffer_label(MTL::CommandBuffer *command_buffer, Primitive &primitive)
Definition utils.h:66
-
void set_vector_bytes(CommandEncoder &enc, const std::vector< T > &vec, size_t nelems, int idx)
Definition utils.h:14
-
void debug_set_stream_queue_label(MTL::CommandQueue *queue, int index)
Definition utils.h:58
+
void debug_set_primitive_buffer_label(MTL::CommandBuffer *command_buffer, Primitive &primitive)
Definition utils.h:50
+
void concatenate(std::string &acc, T first)
Definition utils.h:66
+
void debug_set_stream_queue_label(MTL::CommandQueue *queue, int index)
Definition utils.h:42
MTL::Size get_2d_grid_dims(const std::vector< int > &shape, const std::vector< size_t > &strides)
std::string get_primitive_string(Primitive *primitive)
-
NS::String * make_string(std::ostringstream &os)
Definition utils.h:53
-
std::string type_to_name(const array &a)
+
NS::String * make_string(std::ostringstream &os)
Definition utils.h:37
+
std::string type_to_name(const Dtype &t)
-
+
-
+
Go to the source code of this file.
-METAL_FUNC bfloat16_t metal::abs (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::abs (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::acos (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::acos (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::acosh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::acosh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::asin (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::asin (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::asinh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::asinh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::atan (bfloat16_t y_over_x)
+METAL_FUNC bfloat16_t metal::atan (bfloat16_t y_over_x)
-METAL_FUNC bfloat16_t metal::atan2 (bfloat16_t y, bfloat16_t x)
+METAL_FUNC bfloat16_t metal::atan2 (bfloat16_t y, bfloat16_t x)
-METAL_FUNC bfloat16_t metal::atanh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::atanh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::ceil (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::ceil (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::cos (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::cos (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::cosh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::cosh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::cospi (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::cospi (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::divide (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::divide (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::exp (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::exp (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::exp10 (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::exp10 (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::exp2 (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::exp2 (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fabs (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fabs (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fdim (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::fdim (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::floor (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::floor (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fma (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::fma (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::fmax (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::fmax (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::fmax3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::fmax3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::fmedian3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::fmedian3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::fmin (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::fmin (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::fmin3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::fmin3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::fmod (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::fmod (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::fract (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fract (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::frexp (bfloat16_t x, thread int &exp )
+METAL_FUNC bfloat16_t metal::frexp (bfloat16_t x, thread int &exp )
-METAL_FUNC bfloat16_t metal::ldexp (bfloat16_t x, int k)
+METAL_FUNC bfloat16_t metal::ldexp (bfloat16_t x, int k)
-METAL_FUNC bfloat16_t metal::log (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::log (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::log10 (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::log10 (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::log2 (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::log2 (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::max (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::max (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::max3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::max3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::median3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::median3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::min (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::min (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::min3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::min3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::nextafter (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::nextafter (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::pow (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::pow (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::powr (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::powr (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::rint (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::rint (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::round (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::round (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::rsqrt (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::rsqrt (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::sin (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::sin (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::sinh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::sinh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::sinpi (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::sinpi (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::sqrt (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::sqrt (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::tan (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::tan (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::tanh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::tanh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::tanpi (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::tanpi (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::trunc (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::trunc (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::abs (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::abs (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::acos (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::acos (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::acosh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::acosh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::asin (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::asin (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::asinh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::asinh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::atan (bfloat16_t y_over_x)
+METAL_FUNC bfloat16_t metal::fast::atan (bfloat16_t y_over_x)
-METAL_FUNC bfloat16_t metal::fast::atan2 (bfloat16_t y, bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::atan2 (bfloat16_t y, bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::atanh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::atanh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::ceil (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::ceil (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::cos (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::cos (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::cosh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::cosh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::cospi (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::cospi (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::divide (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::fast::divide (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::fast::exp (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::exp (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::exp10 (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::exp10 (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::exp2 (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::exp2 (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::fabs (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::fabs (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::fdim (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::fast::fdim (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::fast::floor (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::floor (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::fma (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::fast::fma (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::fast::fmax (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::fast::fmax (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::fast::fmax3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::fast::fmax3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::fast::fmedian3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::fast::fmedian3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::fast::fmin (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::fast::fmin (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::fast::fmin3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::fast::fmin3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::fast::fmod (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::fast::fmod (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::fast::fract (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::fract (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::frexp (bfloat16_t x, thread int &exp )
+METAL_FUNC bfloat16_t metal::fast::frexp (bfloat16_t x, thread int &exp )
-METAL_FUNC bfloat16_t metal::fast::ldexp (bfloat16_t x, int k)
+METAL_FUNC bfloat16_t metal::fast::ldexp (bfloat16_t x, int k)
-METAL_FUNC bfloat16_t metal::fast::log (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::log (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::log10 (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::log10 (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::log2 (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::log2 (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::max (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::fast::max (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::fast::max3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::fast::max3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::fast::median3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::fast::median3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::fast::min (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::fast::min (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::fast::min3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::fast::min3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::fast::nextafter (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::fast::nextafter (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::fast::pow (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::fast::pow (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::fast::powr (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::fast::powr (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::fast::rint (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::rint (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::round (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::round (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::rsqrt (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::rsqrt (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::sin (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::sin (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::sinh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::sinh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::sinpi (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::sinpi (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::sqrt (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::sqrt (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::tan (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::tan (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::tanh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::tanh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::tanpi (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::tanpi (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::fast::trunc (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::fast::trunc (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::abs (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::abs (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::acos (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::acos (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::acosh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::acosh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::asin (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::asin (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::asinh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::asinh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::atan (bfloat16_t y_over_x)
+METAL_FUNC bfloat16_t metal::precise::atan (bfloat16_t y_over_x)
-METAL_FUNC bfloat16_t metal::precise::atan2 (bfloat16_t y, bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::atan2 (bfloat16_t y, bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::atanh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::atanh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::ceil (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::ceil (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::cos (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::cos (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::cosh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::cosh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::cospi (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::cospi (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::divide (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::precise::divide (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::precise::exp (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::exp (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::exp10 (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::exp10 (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::exp2 (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::exp2 (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::fabs (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::fabs (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::fdim (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::precise::fdim (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::precise::floor (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::floor (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::fma (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::precise::fma (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::precise::fmax (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::precise::fmax (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::precise::fmax3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::precise::fmax3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::precise::fmedian3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::precise::fmedian3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::precise::fmin (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::precise::fmin (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::precise::fmin3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::precise::fmin3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::precise::fmod (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::precise::fmod (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::precise::fract (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::fract (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::frexp (bfloat16_t x, thread int &exp )
+METAL_FUNC bfloat16_t metal::precise::frexp (bfloat16_t x, thread int &exp )
-METAL_FUNC bfloat16_t metal::precise::ldexp (bfloat16_t x, int k)
+METAL_FUNC bfloat16_t metal::precise::ldexp (bfloat16_t x, int k)
-METAL_FUNC bfloat16_t metal::precise::log (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::log (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::log10 (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::log10 (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::log2 (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::log2 (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::max (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::precise::max (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::precise::max3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::precise::max3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::precise::median3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::precise::median3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::precise::min (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::precise::min (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::precise::min3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
+METAL_FUNC bfloat16_t metal::precise::min3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
-METAL_FUNC bfloat16_t metal::precise::nextafter (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::precise::nextafter (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::precise::pow (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::precise::pow (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::precise::powr (bfloat16_t x, bfloat16_t y)
+METAL_FUNC bfloat16_t metal::precise::powr (bfloat16_t x, bfloat16_t y)
-METAL_FUNC bfloat16_t metal::precise::rint (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::rint (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::round (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::round (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::rsqrt (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::rsqrt (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::sin (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::sin (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::sinh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::sinh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::sinpi (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::sinpi (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::sqrt (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::sqrt (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::tan (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::tan (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::tanh (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::tanh (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::tanpi (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::tanpi (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::precise::trunc (bfloat16_t x)
+METAL_FUNC bfloat16_t metal::precise::trunc (bfloat16_t x)
-METAL_FUNC bfloat16_t metal::simd_broadcast (bfloat16_t data, ushort broadcast_lane_id)
+METAL_FUNC bfloat16_t metal::simd_broadcast (bfloat16_t data, ushort broadcast_lane_id)
-METAL_FUNC bfloat16_t metal::simd_shuffle (bfloat16_t data, ushort simd_lane_id)
+METAL_FUNC bfloat16_t metal::simd_shuffle (bfloat16_t data, ushort simd_lane_id)
-METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_down (bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo)
+METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_down (bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo)
-METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_down (bfloat16_t data, bfloat16_t filling_data, ushort delta)
+METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_down (bfloat16_t data, bfloat16_t filling_data, ushort delta)
-METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_up (bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo)
+METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_up (bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo)
-METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_up (bfloat16_t data, bfloat16_t filling_data, ushort delta)
+METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_up (bfloat16_t data, bfloat16_t filling_data, ushort delta)
-METAL_FUNC bfloat16_t metal::simd_shuffle_down (bfloat16_t data, ushort delta)
+METAL_FUNC bfloat16_t metal::simd_shuffle_down (bfloat16_t data, ushort delta)
-METAL_FUNC bfloat16_t metal::simd_shuffle_rotate_down (bfloat16_t data, ushort delta)
+METAL_FUNC bfloat16_t metal::simd_shuffle_rotate_down (bfloat16_t data, ushort delta)
-METAL_FUNC bfloat16_t metal::simd_shuffle_rotate_up (bfloat16_t data, ushort delta)
+METAL_FUNC bfloat16_t metal::simd_shuffle_rotate_up (bfloat16_t data, ushort delta)
-METAL_FUNC bfloat16_t metal::simd_shuffle_up (bfloat16_t data, ushort delta)
+METAL_FUNC bfloat16_t metal::simd_shuffle_up (bfloat16_t data, ushort delta)
-METAL_FUNC bfloat16_t metal::simd_shuffle_xor (bfloat16_t data, ushort mask)
+METAL_FUNC bfloat16_t metal::simd_shuffle_xor (bfloat16_t data, ushort mask)
-METAL_FUNC bfloat16_t metal::simd_max (bfloat16_t data)
+METAL_FUNC bfloat16_t metal::simd_max (bfloat16_t data)
-METAL_FUNC bfloat16_t metal::simd_min (bfloat16_t data)
+METAL_FUNC bfloat16_t metal::simd_min (bfloat16_t data)
-METAL_FUNC bfloat16_t metal::simd_prefix_exclusive_product (bfloat16_t data)
+METAL_FUNC bfloat16_t metal::simd_prefix_exclusive_product (bfloat16_t data)
-METAL_FUNC bfloat16_t metal::simd_prefix_exclusive_sum (bfloat16_t data)
+METAL_FUNC bfloat16_t metal::simd_prefix_exclusive_sum (bfloat16_t data)
-METAL_FUNC bfloat16_t metal::simd_prefix_inclusive_product (bfloat16_t data)
+METAL_FUNC bfloat16_t metal::simd_prefix_inclusive_product (bfloat16_t data)
-METAL_FUNC bfloat16_t metal::simd_prefix_inclusive_sum (bfloat16_t data)
+METAL_FUNC bfloat16_t metal::simd_prefix_inclusive_sum (bfloat16_t data)
-METAL_FUNC bfloat16_t metal::simd_product (bfloat16_t data)
+METAL_FUNC bfloat16_t metal::simd_product (bfloat16_t data)
-METAL_FUNC bfloat16_t metal::simd_sum (bfloat16_t data)
+METAL_FUNC bfloat16_t metal::simd_sum (bfloat16_t data)
-METAL_FUNC bfloat16_t metal::simd_xor (bfloat16_t data)
+METAL_FUNC bfloat16_t metal::simd_xor (bfloat16_t data)
-
-
◆ bfloat16_to_uint16
-
-
-
-
-
- #define bfloat16_to_uint16
- (
- x )
-
-
-
-
-
◆ instantiate_metal_math_funcs
@@ -580,26 +557,6 @@ Functions
-
-
-
-◆ uint16_to_bfloat16
-
-
-
-
-
- #define uint16_to_bfloat16
- (
- x )
-
-
-
-
-
Value:
-
-
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat()
Definition bf16.h:64
-
diff --git a/docs/build/html/bf16__math_8h_source.html b/docs/build/html/bf16__math_8h_source.html
index c7b8985d2..e060315d9 100644
--- a/docs/build/html/bf16__math_8h_source.html
+++ b/docs/build/html/bf16__math_8h_source.html
@@ -95,408 +95,395 @@ $(function(){ initResizable(false); });
-
-
-
-
-
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
35 #define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
-
-
37 METAL_FUNC otype abs(itype x) { \
-
38 return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
-
-
40 METAL_FUNC otype acos(itype x) { \
-
41 return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
-
-
43 METAL_FUNC otype acosh(itype x) { \
-
44 return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
-
-
46 METAL_FUNC otype asin(itype x) { \
-
47 return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
-
-
49 METAL_FUNC otype asinh(itype x) { \
-
50 return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
-
-
52 METAL_FUNC otype atan(itype y_over_x) { \
-
53 return static_cast<otype>( \
-
54 __metal_atan(static_cast<ctype>(y_over_x), mfast)); \
-
-
56 METAL_FUNC otype atan2(itype y, itype x) { \
-
57 return static_cast<otype>( \
-
58 __metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
-
-
60 METAL_FUNC otype atanh(itype x) { \
-
61 return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
-
-
63 METAL_FUNC otype ceil(itype x) { \
-
64 return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
-
-
66 METAL_FUNC otype cos(itype x) { \
-
67 return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
-
-
69 METAL_FUNC otype cosh(itype x) { \
-
70 return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
-
-
72 METAL_FUNC otype cospi(itype x) { \
-
73 return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
-
-
75 METAL_FUNC otype divide(itype x, itype y) { \
-
76 return static_cast<otype>( \
-
77 __metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
-
-
79 METAL_FUNC otype exp(itype x) { \
-
80 return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
-
-
82 METAL_FUNC otype exp10(itype x) { \
-
83 return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
-
-
85 METAL_FUNC otype exp2(itype x) { \
-
86 return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
-
-
88 METAL_FUNC otype fabs(itype x) { \
-
89 return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
-
-
91 METAL_FUNC otype fdim(itype x, itype y) { \
-
92 ctype t = static_cast<ctype>(x - y); \
-
93 return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
-
-
95 METAL_FUNC otype floor(itype x) { \
-
96 return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
-
-
98 METAL_FUNC otype fma(itype x, itype y, itype z) { \
-
99 return static_cast<otype>(__metal_fma( \
-
100 static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
-
-
102 METAL_FUNC otype fmax(itype x, itype y) { \
-
103 return static_cast<otype>( \
-
104 __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
-
-
106 METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
-
107 return static_cast<otype>(__metal_fmax3( \
-
108 static_cast<ctype>(x), \
-
109 static_cast<ctype>(y), \
-
110 static_cast<ctype>(z), \
-
-
-
113 METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
-
114 return static_cast<otype>(__metal_fmedian3( \
-
115 static_cast<ctype>(x), \
-
116 static_cast<ctype>(y), \
-
117 static_cast<ctype>(z), \
-
-
-
120 METAL_FUNC otype fmin(itype x, itype y) { \
-
121 return static_cast<otype>( \
-
122 __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
-
-
124 METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
-
125 return static_cast<otype>(__metal_fmin3( \
-
126 static_cast<ctype>(x), \
-
127 static_cast<ctype>(y), \
-
128 static_cast<ctype>(z), \
-
-
-
131 METAL_FUNC otype fmod(itype x, itype y) { \
-
132 return static_cast<otype>( \
-
133 __metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
-
-
135 METAL_FUNC otype fract(itype x) { \
-
136 return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
-
-
138 METAL_FUNC otype frexp(itype x, thread int& exp) { \
-
139 return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
-
-
141 METAL_FUNC otype ldexp(itype x, int k) { \
-
142 return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
-
-
144 METAL_FUNC otype log(itype x) { \
-
145 return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
-
-
147 METAL_FUNC otype log10(itype x) { \
-
148 return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
-
-
150 METAL_FUNC otype log2(itype x) { \
-
151 return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
-
-
153 METAL_FUNC otype max(itype x, itype y) { \
-
154 return static_cast<otype>( \
-
155 __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
-
-
157 METAL_FUNC otype max3(itype x, itype y, itype z) { \
-
158 return static_cast<otype>(__metal_fmax3( \
-
159 static_cast<ctype>(x), \
-
160 static_cast<ctype>(y), \
-
161 static_cast<ctype>(z), \
-
-
-
164 METAL_FUNC otype median3(itype x, itype y, itype z) { \
-
165 return static_cast<otype>(__metal_fmedian3( \
-
166 static_cast<ctype>(x), \
-
167 static_cast<ctype>(y), \
-
168 static_cast<ctype>(z), \
-
-
-
171 METAL_FUNC otype min(itype x, itype y) { \
-
172 return static_cast<otype>( \
-
173 __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
-
-
175 METAL_FUNC otype min3(itype x, itype y, itype z) { \
-
176 return static_cast<otype>(__metal_fmin3( \
-
177 static_cast<ctype>(x), \
-
178 static_cast<ctype>(y), \
-
179 static_cast<ctype>(z), \
-
-
-
182 METAL_FUNC otype nextafter(itype x, itype y) { \
-
183 return static_cast<otype>( \
-
184 __metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
-
-
186 METAL_FUNC otype pow(itype x, itype y) { \
-
187 return static_cast<otype>( \
-
188 __metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
-
-
190 METAL_FUNC otype powr(itype x, itype y) { \
-
191 return static_cast<otype>( \
-
192 __metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
-
-
194 METAL_FUNC otype rint(itype x) { \
-
195 return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
-
-
197 METAL_FUNC otype round(itype x) { \
-
198 return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
-
-
200 METAL_FUNC otype rsqrt(itype x) { \
-
201 return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
-
-
203 METAL_FUNC otype sin(itype x) { \
-
204 return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
-
-
206 METAL_FUNC otype sinh(itype x) { \
-
207 return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
-
-
209 METAL_FUNC otype sinpi(itype x) { \
-
210 return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
-
-
212 METAL_FUNC otype sqrt(itype x) { \
-
213 return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
-
-
215 METAL_FUNC otype tan(itype x) { \
-
216 return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
-
-
218 METAL_FUNC otype tanh(itype x) { \
-
219 return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
-
-
221 METAL_FUNC otype tanpi(itype x) { \
-
222 return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
-
-
224 METAL_FUNC otype trunc(itype x) { \
-
225 return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
33 #define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
+
+
35 METAL_FUNC otype abs(itype x) { \
+
36 return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
+
+
38 METAL_FUNC otype acos(itype x) { \
+
39 return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
+
+
41 METAL_FUNC otype acosh(itype x) { \
+
42 return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
+
+
44 METAL_FUNC otype asin(itype x) { \
+
45 return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
+
+
47 METAL_FUNC otype asinh(itype x) { \
+
48 return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
+
+
50 METAL_FUNC otype atan(itype y_over_x) { \
+
51 return static_cast<otype>( \
+
52 __metal_atan(static_cast<ctype>(y_over_x), mfast)); \
+
+
54 METAL_FUNC otype atan2(itype y, itype x) { \
+
55 return static_cast<otype>( \
+
56 __metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
+
+
58 METAL_FUNC otype atanh(itype x) { \
+
59 return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
+
+
61 METAL_FUNC otype ceil(itype x) { \
+
62 return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
+
+
64 METAL_FUNC otype cos(itype x) { \
+
65 return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
+
+
67 METAL_FUNC otype cosh(itype x) { \
+
68 return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
+
+
70 METAL_FUNC otype cospi(itype x) { \
+
71 return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
+
+
73 METAL_FUNC otype divide(itype x, itype y) { \
+
74 return static_cast<otype>( \
+
75 __metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
+
77 METAL_FUNC otype exp(itype x) { \
+
78 return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
+
+
80 METAL_FUNC otype exp10(itype x) { \
+
81 return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
+
+
83 METAL_FUNC otype exp2(itype x) { \
+
84 return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
+
+
86 METAL_FUNC otype fabs(itype x) { \
+
87 return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
+
+
89 METAL_FUNC otype fdim(itype x, itype y) { \
+
90 ctype t = static_cast<ctype>(x - y); \
+
91 return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
+
+
93 METAL_FUNC otype floor(itype x) { \
+
94 return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
+
+
96 METAL_FUNC otype fma(itype x, itype y, itype z) { \
+
97 return static_cast<otype>(__metal_fma( \
+
98 static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
+
+
100 METAL_FUNC otype fmax(itype x, itype y) { \
+
101 return static_cast<otype>( \
+
102 __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
+
104 METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
+
105 return static_cast<otype>(__metal_fmax3( \
+
106 static_cast<ctype>(x), \
+
107 static_cast<ctype>(y), \
+
108 static_cast<ctype>(z), \
+
+
+
111 METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
+
112 return static_cast<otype>(__metal_fmedian3( \
+
113 static_cast<ctype>(x), \
+
114 static_cast<ctype>(y), \
+
115 static_cast<ctype>(z), \
+
+
+
118 METAL_FUNC otype fmin(itype x, itype y) { \
+
119 return static_cast<otype>( \
+
120 __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
+
122 METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
+
123 return static_cast<otype>(__metal_fmin3( \
+
124 static_cast<ctype>(x), \
+
125 static_cast<ctype>(y), \
+
126 static_cast<ctype>(z), \
+
+
+
129 METAL_FUNC otype fmod(itype x, itype y) { \
+
130 return static_cast<otype>( \
+
131 __metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
+
133 METAL_FUNC otype fract(itype x) { \
+
134 return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
+
+
136 METAL_FUNC otype frexp(itype x, thread int& exp) { \
+
137 return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
+
+
139 METAL_FUNC otype ldexp(itype x, int k) { \
+
140 return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
+
+
142 METAL_FUNC otype log(itype x) { \
+
143 return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
+
+
145 METAL_FUNC otype log10(itype x) { \
+
146 return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
+
+
148 METAL_FUNC otype log2(itype x) { \
+
149 return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
+
+
151 METAL_FUNC otype max(itype x, itype y) { \
+
152 return static_cast<otype>( \
+
153 __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
+
155 METAL_FUNC otype max3(itype x, itype y, itype z) { \
+
156 return static_cast<otype>(__metal_fmax3( \
+
157 static_cast<ctype>(x), \
+
158 static_cast<ctype>(y), \
+
159 static_cast<ctype>(z), \
+
+
+
162 METAL_FUNC otype median3(itype x, itype y, itype z) { \
+
163 return static_cast<otype>(__metal_fmedian3( \
+
164 static_cast<ctype>(x), \
+
165 static_cast<ctype>(y), \
+
166 static_cast<ctype>(z), \
+
+
+
169 METAL_FUNC otype min(itype x, itype y) { \
+
170 return static_cast<otype>( \
+
171 __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
+
173 METAL_FUNC otype min3(itype x, itype y, itype z) { \
+
174 return static_cast<otype>(__metal_fmin3( \
+
175 static_cast<ctype>(x), \
+
176 static_cast<ctype>(y), \
+
177 static_cast<ctype>(z), \
+
+
+
180 METAL_FUNC otype nextafter(itype x, itype y) { \
+
181 return static_cast<otype>( \
+
182 __metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
+
+
184 METAL_FUNC otype pow(itype x, itype y) { \
+
185 return static_cast<otype>( \
+
186 __metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
+
188 METAL_FUNC otype powr(itype x, itype y) { \
+
189 return static_cast<otype>( \
+
190 __metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
+
192 METAL_FUNC otype rint(itype x) { \
+
193 return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
+
+
195 METAL_FUNC otype round(itype x) { \
+
196 return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
+
+
198 METAL_FUNC otype rsqrt(itype x) { \
+
199 return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
+
+
201 METAL_FUNC otype sin(itype x) { \
+
202 return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
+
+
204 METAL_FUNC otype sinh(itype x) { \
+
205 return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
+
+
207 METAL_FUNC otype sinpi(itype x) { \
+
208 return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
+
+
210 METAL_FUNC otype sqrt(itype x) { \
+
211 return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
+
+
213 METAL_FUNC otype tan(itype x) { \
+
214 return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
+
+
216 METAL_FUNC otype tanh(itype x) { \
+
217 return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
+
+
219 METAL_FUNC otype tanpi(itype x) { \
+
220 return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
+
+
222 METAL_FUNC otype trunc(itype x) { \
+
223 return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
+
+
+
+
-
-
-
-
-
-
-
234 __METAL_MAYBE_FAST_MATH__);
+
+
+
+
+
232 __METAL_MAYBE_FAST_MATH__);
+
+
+
-
-
-
-
-
-
-
-
242 __METAL_FAST_MATH__);
-
-
+
+
+
+
+
240 __METAL_FAST_MATH__);
+
+
+
+
+
-
-
-
-
-
-
-
-
252 __METAL_PRECISE_MATH__);
+
+
+
+
+
250 __METAL_PRECISE_MATH__);
+
+
+
-
+
-
-
-
-
-
262 #define instantiate_metal_simd_comm_funcs( \
-
263 itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
-
-
265 METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
-
266 return ctype_to_otype( \
-
267 __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
-
-
-
270 METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
-
271 return ctype_to_otype( \
-
272 __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
-
-
-
275 METAL_FUNC otype simd_shuffle_and_fill_down( \
-
276 itype data, itype filling_data, ushort delta, ushort modulo) { \
-
277 return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
-
278 itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
-
-
-
281 METAL_FUNC otype simd_shuffle_and_fill_down( \
-
282 itype data, itype filling_data, ushort delta) { \
-
283 return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
-
284 itype_to_ctype(data), \
-
285 itype_to_ctype(filling_data), \
-
-
287 __metal_get_simdgroup_size(ushort()))); \
-
-
-
290 METAL_FUNC otype simd_shuffle_and_fill_up( \
-
291 itype data, itype filling_data, ushort delta, ushort modulo) { \
-
292 return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
-
293 itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
-
-
-
296 METAL_FUNC otype simd_shuffle_and_fill_up( \
-
297 itype data, itype filling_data, ushort delta) { \
-
298 return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
-
299 itype_to_ctype(data), \
-
300 itype_to_ctype(filling_data), \
-
-
302 __metal_get_simdgroup_size(ushort()))); \
-
-
-
305 METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
-
306 return ctype_to_otype( \
-
307 __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
-
-
-
310 METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
-
311 return ctype_to_otype( \
-
312 __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
-
-
-
315 METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
-
316 return ctype_to_otype( \
-
317 __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
-
-
-
320 METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
-
321 return ctype_to_otype( \
-
322 __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
-
-
-
325 METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
-
326 return ctype_to_otype( \
-
327 __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
-
-
-
-
330 #define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
-
-
332 METAL_FUNC otype simd_max(itype data) { \
-
333 return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
-
-
-
336 METAL_FUNC otype simd_min(itype data) { \
-
337 return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
-
-
-
340 METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
-
341 return static_cast<otype>( \
-
342 __metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
-
-
-
345 METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
-
346 return static_cast<otype>( \
-
347 __metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
-
-
-
350 METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
-
351 return static_cast<otype>( \
-
352 __metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
-
-
-
355 METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
-
356 return static_cast<otype>( \
-
357 __metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
-
-
-
360 METAL_FUNC otype simd_product(itype data) { \
-
361 return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
-
-
-
364 METAL_FUNC otype simd_sum(itype data) { \
-
365 return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
-
-
-
368 METAL_FUNC otype simd_xor(itype data) { \
-
369 return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
-
+
+
+
260 #define instantiate_metal_simd_comm_funcs( \
+
261 itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
+
+
263 METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
+
264 return ctype_to_otype( \
+
265 __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
+
+
+
268 METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
+
269 return ctype_to_otype( \
+
270 __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
+
+
+
273 METAL_FUNC otype simd_shuffle_and_fill_down( \
+
274 itype data, itype filling_data, ushort delta, ushort modulo) { \
+
275 return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
+
276 itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
+
+
+
279 METAL_FUNC otype simd_shuffle_and_fill_down( \
+
280 itype data, itype filling_data, ushort delta) { \
+
281 return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
+
282 itype_to_ctype(data), \
+
283 itype_to_ctype(filling_data), \
+
+
285 __metal_get_simdgroup_size(ushort()))); \
+
+
+
288 METAL_FUNC otype simd_shuffle_and_fill_up( \
+
289 itype data, itype filling_data, ushort delta, ushort modulo) { \
+
290 return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
+
291 itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
+
+
+
294 METAL_FUNC otype simd_shuffle_and_fill_up( \
+
295 itype data, itype filling_data, ushort delta) { \
+
296 return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
+
297 itype_to_ctype(data), \
+
298 itype_to_ctype(filling_data), \
+
+
300 __metal_get_simdgroup_size(ushort()))); \
+
+
+
303 METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
+
304 return ctype_to_otype( \
+
305 __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
+
+
+
308 METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
+
309 return ctype_to_otype( \
+
310 __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
+
+
+
313 METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
+
314 return ctype_to_otype( \
+
315 __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
+
+
+
318 METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
+
319 return ctype_to_otype( \
+
320 __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
+
+
+
323 METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
+
324 return ctype_to_otype( \
+
325 __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
+
+
+
+
328 #define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
+
+
330 METAL_FUNC otype simd_max(itype data) { \
+
331 return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
+
+
+
334 METAL_FUNC otype simd_min(itype data) { \
+
335 return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
+
+
+
338 METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
+
339 return static_cast<otype>( \
+
340 __metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
+
+
+
343 METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
+
344 return static_cast<otype>( \
+
345 __metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
+
+
+
348 METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
+
349 return static_cast<otype>( \
+
350 __metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
+
+
+
353 METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
+
354 return static_cast<otype>( \
+
355 __metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
+
+
+
358 METAL_FUNC otype simd_product(itype data) { \
+
359 return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
+
+
+
362 METAL_FUNC otype simd_sum(itype data) { \
+
363 return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
+
+
+
366 METAL_FUNC otype simd_xor(itype data) { \
+
367 return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
+
+
+
-
372 #if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
-
-
374 #define bfloat16_to_uint16(x) as_type<uint16_t>(x)
-
375 #define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
-
-
-
-
379 #define bfloat16_to_uint16(x) x.bits_
-
380 #define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
#define uint16_to_bfloat16(x)
Definition bf16_math.h:380
-
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype)
Definition bf16_math.h:330
-
#define bfloat16_to_uint16(x)
Definition bf16_math.h:379
-
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast)
Definition bf16_math.h:35
-
#define instantiate_metal_simd_comm_funcs( itype, otype, ctype, itype_to_ctype, ctype_to_otype)
Definition bf16_math.h:262
-
-
+
+
+
+
+
+
+
+
+
+
+
+
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype)
Definition bf16_math.h:328
+
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast)
Definition bf16_math.h:33
+
#define instantiate_metal_simd_comm_funcs( itype, otype, ctype, itype_to_ctype, ctype_to_otype)
Definition bf16_math.h:260
+
+
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
Definition binary_ops.h:8
T operator()(T x, T y)
Definition binary_ops.h:10
Definition binary_ops.h:284
diff --git a/docs/build/html/classes.html b/docs/build/html/classes.html
index 31672a75b..77d6a0d14 100644
--- a/docs/build/html/classes.html
+++ b/docs/build/html/classes.html
@@ -91,19 +91,19 @@ $(function(){ initResizable(false); });
A
-Abs Abs (mlx::core )Abs (mlx::core::detail )AccumHelper (mlx::steel )Add Add (mlx::core )Add (mlx::core::detail )add_vec (pocketfft::detail )add_vec< cmplx< T > > (pocketfft::detail )AddMM (mlx::core )AffineQuantize (mlx::core::fast )aligned_allocator (pocketfft::detail::threading )AllGather (mlx::core::distributed )Allocator (mlx::core::allocator )AllReduce (mlx::core::distributed )And Arange (mlx::core )ArcCos ArcCos (mlx::core )ArcCos (mlx::core::detail )ArcCosh ArcCosh (mlx::core )ArcCosh (mlx::core::detail )ArcSin ArcSin (mlx::core )ArcSin (mlx::core::detail )ArcSinh ArcSinh (mlx::core )ArcSinh (mlx::core::detail )ArcTan ArcTan (mlx::core )ArcTan (mlx::core::detail )ArcTan2 ArcTan2 (mlx::core )ArcTan2 (mlx::core::detail )ArcTanh ArcTanh (mlx::core )ArcTanh (mlx::core::detail )ArgPartition (mlx::core )ArgReduce (mlx::core )ArgSort (mlx::core )arr (pocketfft::detail )arr_info (pocketfft::detail )array (mlx::core )array::ArrayIterator (mlx::core )AsStrided (mlx::core )AsType (mlx::core )
+
Abs Abs (mlx::core )Abs (mlx::core::detail )AccumHelper (mlx::steel )Add Add (mlx::core )Add (mlx::core::detail )add_vec (pocketfft::detail )add_vec< cmplx< T > > (pocketfft::detail )AddMM (mlx::core )AffineQuantize (mlx::core::fast )aligned_allocator (pocketfft::detail::threading )AllGather (mlx::core::distributed )Allocator (mlx::core::allocator )AllReduce (mlx::core::distributed )And Arange (mlx::core )ArcCos ArcCos (mlx::core )ArcCos (mlx::core::detail )ArcCosh ArcCosh (mlx::core )ArcCosh (mlx::core::detail )ArcSin ArcSin (mlx::core )ArcSin (mlx::core::detail )ArcSinh ArcSinh (mlx::core )ArcSinh (mlx::core::detail )ArcTan ArcTan (mlx::core )ArcTan (mlx::core::detail )ArcTan2 ArcTan2 (mlx::core )ArcTan2 (mlx::core::detail )ArcTanh ArcTanh (mlx::core )ArcTanh (mlx::core::detail )ArgPartition (mlx::core )ArgReduce (mlx::core )ArgSort (mlx::core )arr (pocketfft::detail )arr_info (pocketfft::detail )array (mlx::core )array::ArrayIterator (mlx::core )AsStrided (mlx::core )AsType (mlx::core )AttnParams (mlx::steel )
B
-BaseMMAFrag (mlx::steel )BaseMMAFrag< T, 8, 8 > (mlx::steel )_MLX_BFloat16::bits_to_bfloat_struct BitwiseAnd BitwiseAnd (mlx::core::detail )BitwiseBinary (mlx::core )BitwiseOr BitwiseOr (mlx::core::detail )BitwiseXor BitwiseXor (mlx::core::detail )BlockLoader (mlx::steel )BlockMaskedMM (mlx::core )BlockMergeSort BlockMMA (mlx::steel )BlockSwizzle (mlx::steel )bool4_or_uint Broadcast (mlx::core )Buffer (mlx::core::allocator )
+
BaseMMAFrag (mlx::steel )BaseMMAFrag< T, 8, 8 > (mlx::steel )_MLX_BFloat16::bits_to_bfloat_struct BitwiseAnd BitwiseAnd (mlx::core::detail )BitwiseBinary (mlx::core )BitwiseOr BitwiseOr (mlx::core::detail )BitwiseXor BitwiseXor (mlx::core::detail )BlockLoader (mlx::steel )BlockLoaderT (mlx::steel )BlockMaskedMM (mlx::core )BlockMergeSort BlockMMA (mlx::steel )BlockSwizzle (mlx::steel )bool4_or_uint Broadcast (mlx::core )Buffer (mlx::core::allocator )
C
-Ceil Ceil (mlx::core )Ceil (mlx::core::detail )cfftp (pocketfft::detail )ChannelHelper (mlx::steel )ChannelHelper< 1 > (mlx::steel )ChannelHelper< 2 > (mlx::steel )ChannelHelper< 3 > (mlx::steel )ChannelHelper< 4 > (mlx::steel )Cholesky (mlx::core )cmplx (pocketfft::detail )cndarr (pocketfft::detail )CommandEncoder (mlx::core::metal )CommonAllocator (mlx::core::allocator )Compiled (mlx::core )complex128_t (mlx::core )complex64_t complex64_t (mlx::core )Concatenate (mlx::core )concurrent_queue (pocketfft::detail::threading )CommandEncoder::ConcurrentContext (mlx::core::metal )Conjugate Conjugate (mlx::core )Conjugate (mlx::core::detail )ContiguousIterator (mlx::core )Conv2DGeneralBaseInfo (mlx::steel )Conv2DGeneralJumpParams (mlx::steel )Conv2DInputBlockLoaderGeneral (mlx::steel )Conv2DInputBlockLoaderLargeFilter (mlx::steel )Conv2DInputBlockLoaderSmallChannels (mlx::steel )Conv2DInputBlockLoaderSmallFilter (mlx::steel )Conv2DWeightBlockLoader (mlx::steel )Conv2DWeightBlockLoaderGeneral (mlx::steel )Conv2DWeightBlockLoaderSmallChannels (mlx::steel )Convolution (mlx::core )Copy (mlx::core )Cos Cos (mlx::core )Cos (mlx::core::detail )Cosh Cosh (mlx::core )Cosh (mlx::core::detail )CumMax CumMin CumProd CumProd< bool > CumSum Custom (mlx::core::fast )CustomKernel (mlx::core::fast )CustomKernelShapeInfo (mlx::core::fast )CustomTransforms (mlx::core )
+
Ceil Ceil (mlx::core )Ceil (mlx::core::detail )cfftp (pocketfft::detail )ChannelHelper (mlx::steel )ChannelHelper< 1 > (mlx::steel )ChannelHelper< 2 > (mlx::steel )ChannelHelper< 3 > (mlx::steel )ChannelHelper< 4 > (mlx::steel )Cholesky (mlx::core )cmplx (pocketfft::detail )cndarr (pocketfft::detail )CommandEncoder (mlx::core::metal )CommonAllocator (mlx::core::allocator )Compiled (mlx::core )complex128_t (mlx::core )complex64_t complex64_t (mlx::core )Concatenate (mlx::core )concurrent_queue (pocketfft::detail::threading )CommandEncoder::ConcurrentContext (mlx::core::metal )Conjugate Conjugate (mlx::core )Conjugate (mlx::core::detail )Contiguous (mlx::core )ContiguousIterator (mlx::core )Conv2DGeneralBaseInfo (mlx::steel )Conv2DGeneralJumpParams (mlx::steel )Conv2DInputBlockLoaderGeneral (mlx::steel )Conv2DInputBlockLoaderLargeFilter (mlx::steel )Conv2DInputBlockLoaderSmallChannels (mlx::steel )Conv2DInputBlockLoaderSmallFilter (mlx::steel )Conv2DWeightBlockLoader (mlx::steel )Conv2DWeightBlockLoaderGeneral (mlx::steel )Conv2DWeightBlockLoaderSmallChannels (mlx::steel )Convolution (mlx::core )Copy (mlx::core )Cos Cos (mlx::core )Cos (mlx::core::detail )Cosh Cosh (mlx::core )Cosh (mlx::core::detail )CShape (mlx::steel )CumMax CumMin CumProd CumProd< bool > CumSum Custom (mlx::core::fast )CustomKernel (mlx::core::fast )CustomKernelShapeInfo (mlx::core::fast )CustomTransforms (mlx::core )
D
-array::Data (mlx::core )DefaultContiguousReduce (mlx::core )DefaultStridedReduce (mlx::core )Depends (mlx::core )Device (mlx::core )Device (mlx::core::metal )DeviceStream (mlx::core::metal )DistPrimitive (mlx::core::distributed )Divide Divide (mlx::core::detail )Divide (mlx::core )DivMod DivMod (mlx::core )Dtype (mlx::core )
+
array::Data (mlx::core )DefaultContiguousReduce (mlx::core )DefaultStridedReduce (mlx::core )Depends (mlx::core )Device (mlx::core )Device (mlx::core::metal )DeviceStream (mlx::core::metal )DistPrimitive (mlx::core::distributed )Divide Divide (mlx::core::detail )Divide (mlx::core )DivMod DivMod (mlx::core )DivOp Dtype (mlx::core )
E
-Eigh (mlx::core )Equal Equal (mlx::core::detail )Equal (mlx::core )Erf Erf (mlx::core::detail )Erf (mlx::core )ErfInv ErfInv (mlx::core::detail )ErfInv (mlx::core )Event (mlx::core )ExecC2C (pocketfft::detail )ExecDcst (pocketfft::detail )ExecHartley (pocketfft::detail )ExecR2R (pocketfft::detail )Exp Exp (mlx::core::detail )Exp (mlx::core )Expm1 Expm1 (mlx::core::detail )Expm1 (mlx::core )
+
Eigh (mlx::core )Equal Equal (mlx::core::detail )Equal (mlx::core )Erf Erf (mlx::core::detail )Erf (mlx::core )ErfInv ErfInv (mlx::core::detail )ErfInv (mlx::core )Event (mlx::core )ExecC2C (pocketfft::detail )ExecDcst (pocketfft::detail )ExecHartley (pocketfft::detail )ExecR2R (pocketfft::detail )Exp Exp (mlx::core::detail )Exp (mlx::core )Expm1 Expm1 (mlx::core::detail )Expm1 (mlx::core )ExpSubOp
F
Fence (mlx::core::metal )FFT (mlx::core )fftblue (pocketfft::detail )FileWriter (mlx::core::io )array::Flags (mlx::core )Floor Floor (mlx::core::detail )Floor (mlx::core )FloorDivide Full (mlx::core )
@@ -121,10 +121,10 @@ $(function(){ initResizable(false); });
KernelMergeSort KernelMultiBlockMergeSort KeySequence (mlx::core::random )
L
-latch (pocketfft::detail::threading )LayerNorm (mlx::core::fast )LayerNormVJP (mlx::core::fast )LeftShift LeftShift (mlx::core::detail )Less Less (mlx::core::detail )Less (mlx::core )LessEqual LessEqual (mlx::core::detail )LessEqual (mlx::core )LessThan Limits Limits< bfloat16_t > Limits< bool > Limits< complex64_t > Limits< float > Limits< half > Limits< int16_t > Limits< int32_t > Limits< int64_t > Limits< int8_t > Limits< uint16_t > Limits< uint32_t > Limits< uint64_t > Limits< uint8_t > Load (mlx::core )Log Log (mlx::core::detail )Log (mlx::core )Log10 Log10 (mlx::core::detail )Log1p Log1p (mlx::core::detail )Log1p (mlx::core )Log2 Log2 (mlx::core::detail )LogAddExp LogAddExp (mlx::core::detail )LogAddExp (mlx::core )LogicalAnd LogicalAnd (mlx::core::detail )LogicalAnd (mlx::core )LogicalNot LogicalNot (mlx::core::detail )LogicalNot (mlx::core )LogicalOr LogicalOr (mlx::core::detail )LogicalOr (mlx::core )LoopAlignment (mlx::steel )looped_elem_to_loc looped_elem_to_loc< 0, offset_t > looped_elem_to_loc< 1, offset_t >
+
latch (pocketfft::detail::threading )LayerNorm (mlx::core::fast )LayerNormVJP (mlx::core::fast )Layout2D (mlx::steel )LeftShift LeftShift (mlx::core::detail )Less Less (mlx::core::detail )Less (mlx::core )LessEqual LessEqual (mlx::core::detail )LessEqual (mlx::core )LessThan Limits Limits< bfloat16_t > Limits< bool > Limits< complex64_t > Limits< float > Limits< half > Limits< int16_t > Limits< int32_t > Limits< int64_t > Limits< int8_t > Limits< uint16_t > Limits< uint32_t > Limits< uint64_t > Limits< uint8_t > Load (mlx::core )Log Log (mlx::core::detail )Log (mlx::core )Log10 Log10 (mlx::core::detail )Log1p Log1p (mlx::core::detail )Log1p (mlx::core )Log2 Log2 (mlx::core::detail )LogAddExp LogAddExp (mlx::core::detail )LogAddExp (mlx::core )LogicalAnd LogicalAnd (mlx::core::detail )LogicalAnd (mlx::core )LogicalNot LogicalNot (mlx::core::detail )LogicalNot (mlx::core )LogicalOr LogicalOr (mlx::core::detail )LogicalOr (mlx::core )LoopAlignment (mlx::steel )LoopedElemToLoc LoopedElemToLoc< 1, OffsetT, false > LoopedElemToLoc< 1, OffsetT, true >
M
-make_void (metal )Matmul (mlx::core )Max Maximum Maximum (mlx::core::detail )Maximum (mlx::core )MetalAllocator (mlx::core::metal )Min Minimum Minimum (mlx::core::detail )Minimum (mlx::core )mlx_atomic mlx_atomic< T, enable_if_t< is_metal_atomic< T > > > MLXConvParams MLXFastAttentionParams MLXScaledDotProductAttentionParams MMATile (mlx::steel )multi_iter (pocketfft::detail )Multiply (mlx::core::detail )Multiply (mlx::core )Multiply
+
make_void (metal )Matmul (mlx::core )Max Maximum Maximum (mlx::core::detail )Maximum (mlx::core )MaxOp MetalAllocator (mlx::core::metal )Min Minimum Minimum (mlx::core::detail )Minimum (mlx::core )mlx_atomic mlx_atomic< T, enable_if_t< is_metal_atomic< T > > > MLXConvParams MMATile (mlx::steel )MulOp multi_iter (pocketfft::detail )Multiply (mlx::core::detail )Multiply (mlx::core )Multiply
N
NaNEqual (mlx::core::detail )NaNEqual ndarr (pocketfft::detail )Negative (mlx::core::detail )Negative (mlx::core )Negative NodeNamer (mlx::core )None NotEqual (mlx::core::detail )NotEqual (mlx::core )NotEqual NumberOfElements (mlx::core )
@@ -142,10 +142,10 @@ $(function(){ initResizable(false); });
RandomBits (mlx::core )Reader (mlx::core::io )BlockLoader::ReadVector (mlx::steel )ReadWriter Real (mlx::core::detail )Real (mlx::core )Real Recv (mlx::core::distributed )Reduce (mlx::core )ReductionPlan (mlx::core )Remainder (mlx::core::detail )Remainder (mlx::core )Remainder Reshape (mlx::core )ResidencySet (mlx::core::metal )RetainGraph (mlx::core::detail )rev_iter (pocketfft::detail )rfftp (pocketfft::detail )RightShift (mlx::core::detail )RightShift RMSNorm (mlx::core::fast )RMSNormVJP (mlx::core::fast )RoPE (mlx::core::fast )Round (mlx::core::detail )Round (mlx::core )Round Rsqrt (mlx::core::detail )Rsqrt
S
-ScaledDotProductAttention (mlx::core::fast )ScaleOp Scan (mlx::core )Scatter (mlx::core )Scheduler (mlx::core::scheduler )Select (mlx::core::detail )Select (mlx::core )Select Send (mlx::core::distributed )Sigmoid (mlx::core::detail )Sigmoid (mlx::core )Sigmoid Sign (mlx::core::detail )Sign (mlx::core )Sign simple_iter (pocketfft::detail )Sin (mlx::core::detail )Sin (mlx::core )Sin sincos_2pibyn (pocketfft::detail )Sinh (mlx::core::detail )Sinh (mlx::core )Sinh Slice (mlx::core )SliceUpdate (mlx::core )Softmax (mlx::core )Sort (mlx::core )Split (mlx::core )Sqrt (mlx::core::detail )Sqrt (mlx::core )Sqrt Square (mlx::core::detail )Square (mlx::core )Square StopGradient (mlx::core )Stream (mlx::core )StreamContext (mlx::core )StreamThread (mlx::core::scheduler )Subtract (mlx::core::detail )Subtract (mlx::core )Subtract Sum SVD (mlx::core )
+
ScaledDotProductAttention (mlx::core::fast )ScaleOp Scan (mlx::core )Scatter (mlx::core )Scheduler (mlx::core::scheduler )Select (mlx::core::detail )Select (mlx::core )Select Send (mlx::core::distributed )Shape2D (mlx::steel )Sigmoid (mlx::core::detail )Sigmoid (mlx::core )Sigmoid Sign (mlx::core::detail )Sign (mlx::core )Sign simple_iter (pocketfft::detail )Sin (mlx::core::detail )Sin (mlx::core )Sin sincos_2pibyn (pocketfft::detail )Sinh (mlx::core::detail )Sinh (mlx::core )Sinh Slice (mlx::core )SliceUpdate (mlx::core )Softmax (mlx::core )Sort (mlx::core )Split (mlx::core )Sqrt (mlx::core::detail )Sqrt (mlx::core )Sqrt Square (mlx::core::detail )Square (mlx::core )Square StopGradient (mlx::core )Stream (mlx::core )StreamContext (mlx::core )StreamThread (mlx::core::scheduler )SubOp Subtract (mlx::core::detail )Subtract (mlx::core )Subtract Sum SumOp SVD (mlx::core )
T
-T_dcst23 (pocketfft::detail )T_dcst4 (pocketfft::detail )T_dct1 (pocketfft::detail )T_dst1 (pocketfft::detail )Tan (mlx::core::detail )Tan (mlx::core )Tan Tanh (mlx::core::detail )Tanh (mlx::core )Tanh thread_pool (pocketfft::detail::threading )ThreadPool ThreadSort TransformAdd (mlx::steel )TransformAxpby (mlx::steel )TransformNone (mlx::steel )Transpose (mlx::core )TypeToDtype (mlx::core )
+
T_dcst23 (pocketfft::detail )T_dcst4 (pocketfft::detail )T_dct1 (pocketfft::detail )T_dst1 (pocketfft::detail )Tan (mlx::core::detail )Tan (mlx::core )Tan Tanh (mlx::core::detail )Tanh (mlx::core )Tanh thread_pool (pocketfft::detail::threading )ThreadPool ThreadSort TransformAdd (mlx::steel )TransformAxpby (mlx::steel )TransformNone (mlx::steel )TransformScale Transpose (mlx::core )TypeToDtype (mlx::core )
U
UnaryPrimitive (mlx::core )Uniform (mlx::core )util (pocketfft::detail )
diff --git a/docs/build/html/classmlx_1_1core_1_1_contiguous-members.html b/docs/build/html/classmlx_1_1core_1_1_contiguous-members.html
new file mode 100644
index 000000000..49a90d86e
--- /dev/null
+++ b/docs/build/html/classmlx_1_1core_1_1_contiguous-members.html
@@ -0,0 +1,129 @@
+
+
+
+
+
+
+
+
MLX: Member List
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+
+
+
+
+
+
+
+
This is the complete list of members for mlx::core::Contiguous , including all inherited members.
+
+ Contiguous (Stream stream, bool allow_col_major)mlx::core::Contiguous inline explicit
+ device ()mlx::core::Primitive inline
+ eval_cpu (const std::vector< array > &inputs, array &out) overridemlx::core::Contiguous virtual
+ mlx::core::UnaryPrimitive::eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitive inline virtual
+ eval_gpu (const std::vector< array > &inputs, array &out) overridemlx::core::Contiguous virtual
+ mlx::core::UnaryPrimitive::eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitive inline virtual
+ is_equivalent (const Primitive &other) const overridemlx::core::Contiguous virtual
+ jvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Contiguous virtual
+ operator= (const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
+ operator= (UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
+ mlx::core::Primitive::operator= (const Primitive &other)=deletemlx::core::Primitive
+ mlx::core::Primitive::operator= (Primitive &&other)=deletemlx::core::Primitive
+ output_shapes (const std::vector< array > &inputs) overridemlx::core::Contiguous inline virtual
+ Primitive (Stream stream)mlx::core::Primitive inline explicit
+ Primitive (const Primitive &other)=deletemlx::core::Primitive
+ Primitive (Primitive &&other)=deletemlx::core::Primitive
+ print (std::ostream &os) overridemlx::core::Contiguous inline virtual
+ stream ()mlx::core::Primitive inline
+ UnaryPrimitive (Stream stream)mlx::core::UnaryPrimitive inline explicit
+ UnaryPrimitive (const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
+ UnaryPrimitive (UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
+ vjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Contiguous virtual
+ vmap (const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Contiguous virtual
+ ~Primitive ()=defaultmlx::core::Primitive virtual
+ ~UnaryPrimitive ()=defaultmlx::core::UnaryPrimitive virtual
+
+
+
+
+
+
diff --git a/docs/build/html/classmlx_1_1core_1_1_contiguous.html b/docs/build/html/classmlx_1_1core_1_1_contiguous.html
new file mode 100644
index 000000000..98a3d0d20
--- /dev/null
+++ b/docs/build/html/classmlx_1_1core_1_1_contiguous.html
@@ -0,0 +1,481 @@
+
+
+
+
+
+
+
+
MLX: mlx::core::Contiguous Class Reference
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+
+
+
+
+
+
+
+
#include <primitives.h >
+
+
+
+
+
+
+
+
+
+
+
+ Contiguous (Stream stream , bool allow_col_major)
+
+void eval_cpu (const std::vector< array > &inputs, array &out) override
+
+void eval_gpu (const std::vector< array > &inputs, array &out) override
+
+virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
+ The primitive must know how to vectorize itself across the given axes.
+
+std::vector< array > jvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
+ The Jacobian-vector product.
+
+std::vector< array > vjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
+ The vector-Jacobian product.
+
+void print (std::ostream &os) override
+ Print the primitive.
+
+std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
+ Get the output shapes of the primitive.
+
+bool is_equivalent (const Primitive &other) const override
+ Equivalence check defaults to false unless overridden by the primitive.
+
+
+ UnaryPrimitive (Stream stream )
+ An abstract base class for a primitive with a single output.
+
+void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
+ A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
+
+void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
+
+virtual ~UnaryPrimitive ()=default
+
+ UnaryPrimitive (const UnaryPrimitive &other)=delete
+
+ UnaryPrimitive (UnaryPrimitive &&other)=delete
+
+UnaryPrimitive & operator= (const UnaryPrimitive &other)=delete
+
+UnaryPrimitive & operator= (UnaryPrimitive &&other)=delete
+
+
+ Primitive (Stream stream )
+
+const Device & device ()
+ The device the primitive will run on.
+
+const Stream & stream ()
+ The stream the primitive will run on.
+
+virtual ~Primitive ()=default
+
+ Primitive (const Primitive &other)=delete
+
+ Primitive (Primitive &&other)=delete
+
+Primitive & operator= (const Primitive &other)=delete
+
+Primitive & operator= (Primitive &&other)=delete
+
+
+
+
+
◆ Contiguous()
+
+
+
+
+
+
+
+
+ mlx::core::Contiguous::Contiguous
+ (
+ Stream stream ,
+
+
+
+
+ bool allow_col_major )
+
+
+
+
+inline explicit
+
+
+
+
+
+
+
+
+
◆ eval_cpu()
+
+
+
+
+
+
+
+
+ void mlx::core::Contiguous::eval_cpu
+ (
+ const std::vector< array > & inputs ,
+
+
+
+
+ array & out )
+
+
+
+
+override virtual
+
+
+
+
+
+
◆ eval_gpu()
+
+
+
+
+
+
+
+
+ void mlx::core::Contiguous::eval_gpu
+ (
+ const std::vector< array > & inputs ,
+
+
+
+
+ array & out )
+
+
+
+
+override virtual
+
+
+
+
+
+
◆ is_equivalent()
+
+
+
+
+
+
+
+
+ bool mlx::core::Contiguous::is_equivalent
+ (
+ const Primitive & other )
+ const
+
+
+
+
+override virtual
+
+
+
+
+
Equivalence check defaults to false unless overridden by the primitive.
+
+
Reimplemented from mlx::core::Primitive .
+
+
+
+
+
◆ jvp()
+
+
+
+
+
+
+
+
+ std::vector< array > mlx::core::Contiguous::jvp
+ (
+ const std::vector< array > & primals ,
+
+
+
+
+ const std::vector< array > & tangents ,
+
+
+
+
+ const std::vector< int > & argnums )
+
+
+
+
+override virtual
+
+
+
+
+
+
◆ output_shapes()
+
+
+
+
+
+
+
+
+ std::vector< std::vector< int > > mlx::core::Contiguous::output_shapes
+ (
+ const std::vector< array > & inputs )
+
+
+
+
+
+inline override virtual
+
+
+
+
+
Get the output shapes of the primitive.
+
This is not required to be implemented by derived classes, in which case it will throw.
+
+
Reimplemented from mlx::core::Primitive .
+
+
+
+
+
◆ print()
+
+
+
+
+
+
+
+
+ void mlx::core::Contiguous::print
+ (
+ std::ostream & os )
+
+
+
+
+
+inline override virtual
+
+
+
+
+
+
◆ vjp()
+
+
+
+
+
+
+
+
+ std::vector< array > mlx::core::Contiguous::vjp
+ (
+ const std::vector< array > & primals ,
+
+
+
+
+ const std::vector< array > & cotangents ,
+
+
+
+
+ const std::vector< int > & argnums ,
+
+
+
+
+ const std::vector< array > & outputs )
+
+
+
+
+override virtual
+
+
+
+
+
+
◆ vmap()
+
+
+
+
+
+
+
+
+ virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Contiguous::vmap
+ (
+ const std::vector< array > & inputs ,
+
+
+
+
+ const std::vector< int > & axes )
+
+
+
+
+override virtual
+
+
+
+
+
The primitive must know how to vectorize itself across the given axes.
+
The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.
+
+
Reimplemented from mlx::core::Primitive .
+
+
+
+
The documentation for this class was generated from the following file:
+
+
+
+
+
+
diff --git a/docs/build/html/classmlx_1_1core_1_1_contiguous.png b/docs/build/html/classmlx_1_1core_1_1_contiguous.png
new file mode 100644
index 000000000..13ba3febe
Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_contiguous.png differ
diff --git a/docs/build/html/classmlx_1_1core_1_1_primitive.html b/docs/build/html/classmlx_1_1core_1_1_primitive.html
index b59585a8d..0a127787f 100644
--- a/docs/build/html/classmlx_1_1core_1_1_primitive.html
+++ b/docs/build/html/classmlx_1_1core_1_1_primitive.html
@@ -380,7 +380,7 @@ Public Member Functions
Equivalence check defaults to false unless overridden by the primitive.
-Reimplemented in mlx::core::Abs , mlx::core::Add , mlx::core::AddMM , mlx::core::Arange , mlx::core::ArcCos , mlx::core::ArcCosh , mlx::core::ArcSin , mlx::core::ArcSinh , mlx::core::ArcTan2 , mlx::core::ArcTan , mlx::core::ArcTanh , mlx::core::ArgPartition , mlx::core::ArgReduce , mlx::core::ArgSort , mlx::core::AsStrided , mlx::core::AsType , mlx::core::BitwiseBinary , mlx::core::BlockMaskedMM , mlx::core::Broadcast , mlx::core::Ceil , mlx::core::Compiled , mlx::core::Concatenate , mlx::core::Conjugate , mlx::core::Convolution , mlx::core::Copy , mlx::core::Cos , mlx::core::Cosh , mlx::core::Divide , mlx::core::DivMod , mlx::core::Eigh , mlx::core::Equal , mlx::core::Erf , mlx::core::ErfInv , mlx::core::Exp , mlx::core::fast::ScaledDotProductAttention , mlx::core::FFT , mlx::core::Floor , mlx::core::Full , mlx::core::Gather , mlx::core::GatherMM , mlx::core::GatherQMM , mlx::core::Greater , mlx::core::GreaterEqual , mlx::core::Hadamard , mlx::core::Imag , mlx::core::Less , mlx::core::LessEqual , mlx::core::Log , mlx::core::LogAddExp , mlx::core::LogicalAnd , mlx::core::LogicalNot , mlx::core::LogicalOr , mlx::core::Matmul , mlx::core::Maximum , mlx::core::Minimum , mlx::core::Multiply , mlx::core::Negative , mlx::core::NotEqual , mlx::core::NumberOfElements , mlx::core::Pad , mlx::core::Partition , mlx::core::Power , mlx::core::QuantizedMatmul , mlx::core::RandomBits , mlx::core::Real , mlx::core::Reduce , mlx::core::Remainder , mlx::core::Reshape , mlx::core::Round , mlx::core::Scan , mlx::core::Scatter , mlx::core::Select , mlx::core::Sigmoid , mlx::core::Sign , mlx::core::Sin , mlx::core::Sinh , mlx::core::Slice , mlx::core::SliceUpdate , mlx::core::Softmax , mlx::core::Sort , mlx::core::Split , mlx::core::Sqrt , mlx::core::Square , mlx::core::StopGradient , mlx::core::Subtract , mlx::core::Tan , mlx::core::Tanh , mlx::core::Transpose , mlx::core::Uniform , and mlx::core::View .
+Reimplemented in mlx::core::Abs , mlx::core::Add , mlx::core::AddMM , mlx::core::Arange , mlx::core::ArcCos , mlx::core::ArcCosh , mlx::core::ArcSin , mlx::core::ArcSinh , mlx::core::ArcTan2 , mlx::core::ArcTan , mlx::core::ArcTanh , mlx::core::ArgPartition , mlx::core::ArgReduce , mlx::core::ArgSort , mlx::core::AsStrided , mlx::core::AsType , mlx::core::BitwiseBinary , mlx::core::BlockMaskedMM , mlx::core::Broadcast , mlx::core::Ceil , mlx::core::Compiled , mlx::core::Concatenate , mlx::core::Conjugate , mlx::core::Contiguous , mlx::core::Convolution , mlx::core::Copy , mlx::core::Cos , mlx::core::Cosh , mlx::core::Divide , mlx::core::DivMod , mlx::core::Eigh , mlx::core::Equal , mlx::core::Erf , mlx::core::ErfInv , mlx::core::Exp , mlx::core::fast::ScaledDotProductAttention , mlx::core::FFT , mlx::core::Floor , mlx::core::Full , mlx::core::Gather , mlx::core::GatherMM , mlx::core::GatherQMM , mlx::core::Greater , mlx::core::GreaterEqual , mlx::core::Hadamard , mlx::core::Imag , mlx::core::Less , mlx::core::LessEqual , mlx::core::Log , mlx::core::LogAddExp , mlx::core::LogicalAnd , mlx::core::LogicalNot , mlx::core::LogicalOr , mlx::core::Matmul , mlx::core::Maximum , mlx::core::Minimum , mlx::core::Multiply , mlx::core::Negative , mlx::core::NotEqual , mlx::core::NumberOfElements , mlx::core::Pad , mlx::core::Partition , mlx::core::Power , mlx::core::QuantizedMatmul , mlx::core::RandomBits , mlx::core::Real , mlx::core::Reduce , mlx::core::Remainder , mlx::core::Reshape , mlx::core::Round , mlx::core::Scan , mlx::core::Scatter , mlx::core::Select , mlx::core::Sigmoid , mlx::core::Sign , mlx::core::Sin , mlx::core::Sinh , mlx::core::Slice , mlx::core::SliceUpdate , mlx::core::Softmax , mlx::core::Sort , mlx::core::Split , mlx::core::Sqrt , mlx::core::Square , mlx::core::StopGradient , mlx::core::Subtract , mlx::core::Tan , mlx::core::Tanh , mlx::core::Transpose , mlx::core::Uniform , and mlx::core::View .
@@ -418,7 +418,7 @@ Public Member Functions
The Jacobian-vector product.
-Reimplemented in mlx::core::Abs , mlx::core::Add , mlx::core::ArcCos , mlx::core::ArcCosh , mlx::core::ArcSin , mlx::core::ArcSinh , mlx::core::ArcTan2 , mlx::core::ArcTan , mlx::core::ArcTanh , mlx::core::ArgPartition , mlx::core::ArgReduce , mlx::core::AsStrided , mlx::core::AsType , mlx::core::BitwiseBinary , mlx::core::Broadcast , mlx::core::Ceil , mlx::core::Compiled , mlx::core::Concatenate , mlx::core::Copy , mlx::core::Cos , mlx::core::Cosh , mlx::core::CustomTransforms , mlx::core::distributed::AllGather , mlx::core::distributed::AllReduce , mlx::core::Divide , mlx::core::DivMod , mlx::core::Equal , mlx::core::Erf , mlx::core::ErfInv , mlx::core::Exp , mlx::core::Expm1 , mlx::core::fast::Custom , mlx::core::FFT , mlx::core::Floor , mlx::core::Full , mlx::core::Gather , mlx::core::GatherQMM , mlx::core::Greater , mlx::core::GreaterEqual , mlx::core::Hadamard , mlx::core::Imag , mlx::core::Less , mlx::core::LessEqual , mlx::core::Log1p , mlx::core::Log , mlx::core::LogAddExp , mlx::core::LogicalAnd , mlx::core::LogicalNot , mlx::core::LogicalOr , mlx::core::Maximum , mlx::core::Minimum , mlx::core::Multiply , mlx::core::Negative , mlx::core::NotEqual , mlx::core::Pad , mlx::core::Partition , mlx::core::Power , mlx::core::QuantizedMatmul , mlx::core::Real , mlx::core::Remainder , mlx::core::Reshape , mlx::core::Round , mlx::core::Scan , mlx::core::Scatter , mlx::core::Select , mlx::core::Sigmoid , mlx::core::Sign , mlx::core::Sin , mlx::core::Sinh , mlx::core::Slice , mlx::core::SliceUpdate , mlx::core::Softmax , mlx::core::Sort , mlx::core::Split , mlx::core::Sqrt , mlx::core::Square , mlx::core::Subtract , mlx::core::Tan , mlx::core::Tanh , and mlx::core::Transpose .
+Reimplemented in mlx::core::Abs , mlx::core::Add , mlx::core::ArcCos , mlx::core::ArcCosh , mlx::core::ArcSin , mlx::core::ArcSinh , mlx::core::ArcTan2 , mlx::core::ArcTan , mlx::core::ArcTanh , mlx::core::ArgPartition , mlx::core::ArgReduce , mlx::core::AsStrided , mlx::core::AsType , mlx::core::BitwiseBinary , mlx::core::Broadcast , mlx::core::Ceil , mlx::core::Compiled , mlx::core::Concatenate , mlx::core::Contiguous , mlx::core::Copy , mlx::core::Cos , mlx::core::Cosh , mlx::core::CustomTransforms , mlx::core::distributed::AllGather , mlx::core::distributed::AllReduce , mlx::core::Divide , mlx::core::DivMod , mlx::core::Equal , mlx::core::Erf , mlx::core::ErfInv , mlx::core::Exp , mlx::core::Expm1 , mlx::core::fast::Custom , mlx::core::FFT , mlx::core::Floor , mlx::core::Full , mlx::core::Gather , mlx::core::GatherQMM , mlx::core::Greater , mlx::core::GreaterEqual , mlx::core::Hadamard , mlx::core::Imag , mlx::core::Less , mlx::core::LessEqual , mlx::core::Log1p , mlx::core::Log , mlx::core::LogAddExp , mlx::core::LogicalAnd , mlx::core::LogicalNot , mlx::core::LogicalOr , mlx::core::Maximum , mlx::core::Minimum , mlx::core::Multiply , mlx::core::Negative , mlx::core::NotEqual , mlx::core::Pad , mlx::core::Partition , mlx::core::Power , mlx::core::QuantizedMatmul , mlx::core::Real , mlx::core::Remainder , mlx::core::Reshape , mlx::core::Round , mlx::core::Scan , mlx::core::Scatter , mlx::core::Select , mlx::core::Sigmoid , mlx::core::Sign , mlx::core::Sin , mlx::core::Sinh , mlx::core::Slice , mlx::core::SliceUpdate , mlx::core::Softmax , mlx::core::Sort , mlx::core::Split , mlx::core::Sqrt , mlx::core::Square , mlx::core::Subtract , mlx::core::Tan , mlx::core::Tanh , and mlx::core::Transpose .
@@ -498,7 +498,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented in mlx::core::Abs , mlx::core::Add , mlx::core::ArcCos , mlx::core::ArcCosh , mlx::core::ArcSin , mlx::core::ArcSinh , mlx::core::ArcTan2 , mlx::core::ArcTan , mlx::core::ArcTanh , mlx::core::ArgPartition , mlx::core::ArgReduce , mlx::core::ArgSort , mlx::core::AsType , mlx::core::BitwiseBinary , mlx::core::Ceil , mlx::core::Compiled , mlx::core::Conjugate , mlx::core::Copy , mlx::core::Cos , mlx::core::Cosh , mlx::core::Divide , mlx::core::DivMod , mlx::core::Eigh , mlx::core::Equal , mlx::core::Erf , mlx::core::ErfInv , mlx::core::Exp , mlx::core::Expm1 , mlx::core::Floor , mlx::core::Greater , mlx::core::GreaterEqual , mlx::core::Hadamard , mlx::core::Imag , mlx::core::Less , mlx::core::LessEqual , mlx::core::Log1p , mlx::core::Log , mlx::core::LogAddExp , mlx::core::LogicalAnd , mlx::core::LogicalNot , mlx::core::LogicalOr , mlx::core::Maximum , mlx::core::Minimum , mlx::core::Multiply , mlx::core::Negative , mlx::core::NotEqual , mlx::core::NumberOfElements , mlx::core::Partition , mlx::core::Power , mlx::core::Real , mlx::core::Reduce , mlx::core::Remainder , mlx::core::Round , mlx::core::Select , mlx::core::Sigmoid , mlx::core::Sign , mlx::core::Sin , mlx::core::Sinh , mlx::core::Softmax , mlx::core::Sort , mlx::core::Sqrt , mlx::core::Square , mlx::core::StopGradient , mlx::core::Subtract , mlx::core::Tan , and mlx::core::Tanh .
+Reimplemented in mlx::core::Abs , mlx::core::Add , mlx::core::ArcCos , mlx::core::ArcCosh , mlx::core::ArcSin , mlx::core::ArcSinh , mlx::core::ArcTan2 , mlx::core::ArcTan , mlx::core::ArcTanh , mlx::core::ArgPartition , mlx::core::ArgReduce , mlx::core::ArgSort , mlx::core::AsType , mlx::core::BitwiseBinary , mlx::core::Ceil , mlx::core::Compiled , mlx::core::Conjugate , mlx::core::Contiguous , mlx::core::Copy , mlx::core::Cos , mlx::core::Cosh , mlx::core::Divide , mlx::core::DivMod , mlx::core::Eigh , mlx::core::Equal , mlx::core::Erf , mlx::core::ErfInv , mlx::core::Exp , mlx::core::Expm1 , mlx::core::Floor , mlx::core::Greater , mlx::core::GreaterEqual , mlx::core::Hadamard , mlx::core::Imag , mlx::core::Less , mlx::core::LessEqual , mlx::core::Log1p , mlx::core::Log , mlx::core::LogAddExp , mlx::core::LogicalAnd , mlx::core::LogicalNot , mlx::core::LogicalOr , mlx::core::Maximum , mlx::core::Minimum , mlx::core::Multiply , mlx::core::Negative , mlx::core::NotEqual , mlx::core::NumberOfElements , mlx::core::Partition , mlx::core::Power , mlx::core::Real , mlx::core::Reduce , mlx::core::Remainder , mlx::core::Round , mlx::core::Select , mlx::core::Sigmoid , mlx::core::Sign , mlx::core::Sin , mlx::core::Sinh , mlx::core::Softmax , mlx::core::Sort , mlx::core::Sqrt , mlx::core::Square , mlx::core::StopGradient , mlx::core::Subtract , mlx::core::Tan , and mlx::core::Tanh .
@@ -527,7 +527,7 @@ Public Member Functions
Print the primitive.
-Implemented in mlx::core::Abs , mlx::core::Add , mlx::core::AddMM , mlx::core::Arange , mlx::core::ArcCos , mlx::core::ArcCosh , mlx::core::ArcSin , mlx::core::ArcSinh , mlx::core::ArcTan2 , mlx::core::ArcTan , mlx::core::ArcTanh , mlx::core::ArgPartition , mlx::core::ArgReduce , mlx::core::ArgSort , mlx::core::AsStrided , mlx::core::AsType , mlx::core::BitwiseBinary , mlx::core::BlockMaskedMM , mlx::core::Broadcast , mlx::core::Ceil , mlx::core::Cholesky , mlx::core::Compiled , mlx::core::Concatenate , mlx::core::Conjugate , mlx::core::Convolution , mlx::core::Copy , mlx::core::Cos , mlx::core::Cosh , mlx::core::CustomTransforms , mlx::core::Depends , mlx::core::distributed::AllReduce , mlx::core::Divide , mlx::core::DivMod , mlx::core::Eigh , mlx::core::Equal , mlx::core::Erf , mlx::core::ErfInv , mlx::core::Exp , mlx::core::Expm1 , mlx::core::FFT , mlx::core::Floor , mlx::core::Full , mlx::core::Gather , mlx::core::GatherMM , mlx::core::GatherQMM , mlx::core::Greater , mlx::core::GreaterEqual , mlx::core::Hadamard , mlx::core::Imag , mlx::core::Inverse , mlx::core::Less , mlx::core::LessEqual , mlx::core::Load , mlx::core::Log1p , mlx::core::Log , mlx::core::LogAddExp , mlx::core::LogicalAnd , mlx::core::LogicalNot , mlx::core::LogicalOr , mlx::core::Matmul , mlx::core::Maximum , mlx::core::Minimum , mlx::core::Multiply , mlx::core::Negative , mlx::core::NotEqual , mlx::core::NumberOfElements , mlx::core::Pad , mlx::core::Partition , mlx::core::Power , mlx::core::QRF , mlx::core::QuantizedMatmul , mlx::core::RandomBits , mlx::core::Real , mlx::core::Reduce , mlx::core::Remainder , mlx::core::Reshape , mlx::core::Round , mlx::core::Scan , mlx::core::Scatter , mlx::core::Select , mlx::core::Sigmoid , mlx::core::Sign , mlx::core::Sin , mlx::core::Sinh , mlx::core::Slice , mlx::core::SliceUpdate , mlx::core::Softmax , mlx::core::Sort , mlx::core::Split , mlx::core::Sqrt , mlx::core::Square , mlx::core::StopGradient , mlx::core::Subtract , mlx::core::SVD , mlx::core::Tan , mlx::core::Tanh , mlx::core::Transpose , mlx::core::Uniform , and mlx::core::View .
+Implemented in mlx::core::Abs , mlx::core::Add , mlx::core::AddMM , mlx::core::Arange , mlx::core::ArcCos , mlx::core::ArcCosh , mlx::core::ArcSin , mlx::core::ArcSinh , mlx::core::ArcTan2 , mlx::core::ArcTan , mlx::core::ArcTanh , mlx::core::ArgPartition , mlx::core::ArgReduce , mlx::core::ArgSort , mlx::core::AsStrided , mlx::core::AsType , mlx::core::BitwiseBinary , mlx::core::BlockMaskedMM , mlx::core::Broadcast , mlx::core::Ceil , mlx::core::Cholesky , mlx::core::Compiled , mlx::core::Concatenate , mlx::core::Conjugate , mlx::core::Contiguous , mlx::core::Convolution , mlx::core::Copy , mlx::core::Cos , mlx::core::Cosh , mlx::core::CustomTransforms , mlx::core::Depends , mlx::core::distributed::AllReduce , mlx::core::Divide , mlx::core::DivMod , mlx::core::Eigh , mlx::core::Equal , mlx::core::Erf , mlx::core::ErfInv , mlx::core::Exp , mlx::core::Expm1 , mlx::core::FFT , mlx::core::Floor , mlx::core::Full , mlx::core::Gather , mlx::core::GatherMM , mlx::core::GatherQMM , mlx::core::Greater , mlx::core::GreaterEqual , mlx::core::Hadamard , mlx::core::Imag , mlx::core::Inverse , mlx::core::Less , mlx::core::LessEqual , mlx::core::Load , mlx::core::Log1p , mlx::core::Log , mlx::core::LogAddExp , mlx::core::LogicalAnd , mlx::core::LogicalNot , mlx::core::LogicalOr , mlx::core::Matmul , mlx::core::Maximum , mlx::core::Minimum , mlx::core::Multiply , mlx::core::Negative , mlx::core::NotEqual , mlx::core::NumberOfElements , mlx::core::Pad , mlx::core::Partition , mlx::core::Power , mlx::core::QRF , mlx::core::QuantizedMatmul , mlx::core::RandomBits , mlx::core::Real , mlx::core::Reduce , mlx::core::Remainder , mlx::core::Reshape , mlx::core::Round , mlx::core::Scan , mlx::core::Scatter , mlx::core::Select , mlx::core::Sigmoid , mlx::core::Sign , mlx::core::Sin , mlx::core::Sinh , mlx::core::Slice , mlx::core::SliceUpdate , mlx::core::Softmax , mlx::core::Sort , mlx::core::Split , mlx::core::Sqrt , mlx::core::Square , mlx::core::StopGradient , mlx::core::Subtract , mlx::core::SVD , mlx::core::Tan , mlx::core::Tanh , mlx::core::Transpose , mlx::core::Uniform , and mlx::core::View .
@@ -597,7 +597,7 @@ Public Member Functions
The vector-Jacobian product.
-Reimplemented in mlx::core::Abs , mlx::core::Add , mlx::core::AddMM , mlx::core::ArcCos , mlx::core::ArcCosh , mlx::core::ArcSin , mlx::core::ArcSinh , mlx::core::ArcTan2 , mlx::core::ArcTan , mlx::core::ArcTanh , mlx::core::ArgPartition , mlx::core::ArgReduce , mlx::core::AsStrided , mlx::core::AsType , mlx::core::BitwiseBinary , mlx::core::BlockMaskedMM , mlx::core::Broadcast , mlx::core::Ceil , mlx::core::Compiled , mlx::core::Concatenate , mlx::core::Convolution , mlx::core::Copy , mlx::core::Cos , mlx::core::Cosh , mlx::core::CustomTransforms , mlx::core::Depends , mlx::core::distributed::AllGather , mlx::core::distributed::AllReduce , mlx::core::Divide , mlx::core::DivMod , mlx::core::Equal , mlx::core::Erf , mlx::core::ErfInv , mlx::core::Exp , mlx::core::Expm1 , mlx::core::fast::Custom , mlx::core::fast::LayerNorm , mlx::core::fast::RMSNorm , mlx::core::fast::RoPE , mlx::core::FFT , mlx::core::Floor , mlx::core::Full , mlx::core::Gather , mlx::core::GatherMM , mlx::core::GatherQMM , mlx::core::Greater , mlx::core::GreaterEqual , mlx::core::Hadamard , mlx::core::Imag , mlx::core::Less , mlx::core::LessEqual , mlx::core::Log1p , mlx::core::Log , mlx::core::LogAddExp , mlx::core::LogicalAnd , mlx::core::LogicalNot , mlx::core::LogicalOr , mlx::core::Matmul , mlx::core::Maximum , mlx::core::Minimum , mlx::core::Multiply , mlx::core::Negative , mlx::core::NotEqual , mlx::core::Pad , mlx::core::Partition , mlx::core::Power , mlx::core::QuantizedMatmul , mlx::core::Real , mlx::core::Reduce , mlx::core::Remainder , mlx::core::Reshape , mlx::core::Round , mlx::core::Scan , mlx::core::Scatter , mlx::core::Select , mlx::core::Sigmoid , mlx::core::Sign , mlx::core::Sin , mlx::core::Sinh , mlx::core::Slice , mlx::core::SliceUpdate , mlx::core::Softmax , mlx::core::Sort , mlx::core::Split , mlx::core::Sqrt , mlx::core::Square , mlx::core::Subtract , mlx::core::Tan , mlx::core::Tanh , and mlx::core::Transpose .
+Reimplemented in mlx::core::Abs , mlx::core::Add , mlx::core::AddMM , mlx::core::ArcCos , mlx::core::ArcCosh , mlx::core::ArcSin , mlx::core::ArcSinh , mlx::core::ArcTan2 , mlx::core::ArcTan , mlx::core::ArcTanh , mlx::core::ArgPartition , mlx::core::ArgReduce , mlx::core::AsStrided , mlx::core::AsType , mlx::core::BitwiseBinary , mlx::core::BlockMaskedMM , mlx::core::Broadcast , mlx::core::Ceil , mlx::core::Compiled , mlx::core::Concatenate , mlx::core::Contiguous , mlx::core::Convolution , mlx::core::Copy , mlx::core::Cos , mlx::core::Cosh , mlx::core::CustomTransforms , mlx::core::Depends , mlx::core::distributed::AllGather , mlx::core::distributed::AllReduce , mlx::core::Divide , mlx::core::DivMod , mlx::core::Equal , mlx::core::Erf , mlx::core::ErfInv , mlx::core::Exp , mlx::core::Expm1 , mlx::core::fast::Custom , mlx::core::fast::LayerNorm , mlx::core::fast::RMSNorm , mlx::core::fast::RoPE , mlx::core::FFT , mlx::core::Floor , mlx::core::Full , mlx::core::Gather , mlx::core::GatherMM , mlx::core::GatherQMM , mlx::core::Greater , mlx::core::GreaterEqual , mlx::core::Hadamard , mlx::core::Imag , mlx::core::Less , mlx::core::LessEqual , mlx::core::Log1p , mlx::core::Log , mlx::core::LogAddExp , mlx::core::LogicalAnd , mlx::core::LogicalNot , mlx::core::LogicalOr , mlx::core::Matmul , mlx::core::Maximum , mlx::core::Minimum , mlx::core::Multiply , mlx::core::Negative , mlx::core::NotEqual , mlx::core::Pad , mlx::core::Partition , mlx::core::Power , mlx::core::QuantizedMatmul , mlx::core::Real , mlx::core::Reduce , mlx::core::Remainder , mlx::core::Reshape , mlx::core::Round , mlx::core::Scan , mlx::core::Scatter , mlx::core::Select , mlx::core::Sigmoid , mlx::core::Sign , mlx::core::Sin , mlx::core::Sinh , mlx::core::Slice , mlx::core::SliceUpdate , mlx::core::Softmax , mlx::core::Sort , mlx::core::Split , mlx::core::Sqrt , mlx::core::Square , mlx::core::Subtract , mlx::core::Tan , mlx::core::Tanh , and mlx::core::Transpose .
@@ -631,7 +631,7 @@ Public Member Functions
The primitive must know how to vectorize itself across the given axes.
The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.
-Reimplemented in mlx::core::Abs , mlx::core::Add , mlx::core::AddMM , mlx::core::ArcCos , mlx::core::ArcCosh , mlx::core::ArcSin , mlx::core::ArcSinh , mlx::core::ArcTan2 , mlx::core::ArcTan , mlx::core::ArcTanh , mlx::core::ArgPartition , mlx::core::ArgReduce , mlx::core::ArgSort , mlx::core::AsType , mlx::core::BitwiseBinary , mlx::core::Broadcast , mlx::core::Ceil , mlx::core::Cholesky , mlx::core::Compiled , mlx::core::Concatenate , mlx::core::Conjugate , mlx::core::Copy , mlx::core::Cos , mlx::core::Cosh , mlx::core::CustomTransforms , mlx::core::distributed::AllGather , mlx::core::distributed::AllReduce , mlx::core::distributed::Send , mlx::core::Divide , mlx::core::DivMod , mlx::core::Eigh , mlx::core::Equal , mlx::core::Erf , mlx::core::ErfInv , mlx::core::Exp , mlx::core::Expm1 , mlx::core::fast::Custom , mlx::core::FFT , mlx::core::Floor , mlx::core::Full , mlx::core::Gather , mlx::core::GatherQMM , mlx::core::Greater , mlx::core::GreaterEqual , mlx::core::Hadamard , mlx::core::Imag , mlx::core::Inverse , mlx::core::Less , mlx::core::LessEqual , mlx::core::Log1p , mlx::core::Log , mlx::core::LogAddExp , mlx::core::LogicalAnd , mlx::core::LogicalNot , mlx::core::LogicalOr , mlx::core::Matmul , mlx::core::Maximum , mlx::core::Minimum , mlx::core::Multiply , mlx::core::Negative , mlx::core::NotEqual , mlx::core::NumberOfElements , mlx::core::Pad , mlx::core::Partition , mlx::core::Power , mlx::core::QuantizedMatmul , mlx::core::RandomBits , mlx::core::Real , mlx::core::Reduce , mlx::core::Remainder , mlx::core::Reshape , mlx::core::Round , mlx::core::Scan , mlx::core::Scatter , mlx::core::Select , mlx::core::Sigmoid , mlx::core::Sign , mlx::core::Sin , mlx::core::Sinh , mlx::core::Slice , mlx::core::SliceUpdate , mlx::core::Softmax , mlx::core::Sort , mlx::core::Split , mlx::core::Sqrt , mlx::core::Square , mlx::core::StopGradient , mlx::core::Subtract , mlx::core::SVD , mlx::core::Tan , mlx::core::Tanh , mlx::core::Transpose , mlx::core::Uniform , and mlx::core::View .
+Reimplemented in mlx::core::Abs , mlx::core::Add , mlx::core::AddMM , mlx::core::ArcCos , mlx::core::ArcCosh , mlx::core::ArcSin , mlx::core::ArcSinh , mlx::core::ArcTan2 , mlx::core::ArcTan , mlx::core::ArcTanh , mlx::core::ArgPartition , mlx::core::ArgReduce , mlx::core::ArgSort , mlx::core::AsType , mlx::core::BitwiseBinary , mlx::core::Broadcast , mlx::core::Ceil , mlx::core::Cholesky , mlx::core::Compiled , mlx::core::Concatenate , mlx::core::Conjugate , mlx::core::Contiguous , mlx::core::Copy , mlx::core::Cos , mlx::core::Cosh , mlx::core::CustomTransforms , mlx::core::distributed::AllGather , mlx::core::distributed::AllReduce , mlx::core::distributed::Send , mlx::core::Divide , mlx::core::DivMod , mlx::core::Eigh , mlx::core::Equal , mlx::core::Erf , mlx::core::ErfInv , mlx::core::Exp , mlx::core::Expm1 , mlx::core::fast::Custom , mlx::core::FFT , mlx::core::Floor , mlx::core::Full , mlx::core::Gather , mlx::core::GatherQMM , mlx::core::Greater , mlx::core::GreaterEqual , mlx::core::Hadamard , mlx::core::Imag , mlx::core::Inverse , mlx::core::Less , mlx::core::LessEqual , mlx::core::Log1p , mlx::core::Log , mlx::core::LogAddExp , mlx::core::LogicalAnd , mlx::core::LogicalNot , mlx::core::LogicalOr , mlx::core::Matmul , mlx::core::Maximum , mlx::core::Minimum , mlx::core::Multiply , mlx::core::Negative , mlx::core::NotEqual , mlx::core::NumberOfElements , mlx::core::Pad , mlx::core::Partition , mlx::core::Power , mlx::core::QuantizedMatmul , mlx::core::RandomBits , mlx::core::Real , mlx::core::Reduce , mlx::core::Remainder , mlx::core::Reshape , mlx::core::Round , mlx::core::Scan , mlx::core::Scatter , mlx::core::Select , mlx::core::Sigmoid , mlx::core::Sign , mlx::core::Sin , mlx::core::Sinh , mlx::core::Slice , mlx::core::SliceUpdate , mlx::core::Softmax , mlx::core::Sort , mlx::core::Split , mlx::core::Sqrt , mlx::core::Square , mlx::core::StopGradient , mlx::core::Subtract , mlx::core::SVD , mlx::core::Tan , mlx::core::Tanh , mlx::core::Transpose , mlx::core::Uniform , and mlx::core::View .
diff --git a/docs/build/html/classmlx_1_1core_1_1_unary_primitive.html b/docs/build/html/classmlx_1_1core_1_1_unary_primitive.html
index 8335de9db..610953b50 100644
--- a/docs/build/html/classmlx_1_1core_1_1_unary_primitive.html
+++ b/docs/build/html/classmlx_1_1core_1_1_unary_primitive.html
@@ -126,73 +126,74 @@ Inheritance diagram for mlx::core::UnaryPrimitive:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -389,7 +390,7 @@ Public Member Functions
-
Implemented in mlx::core::Abs , mlx::core::Add , mlx::core::AddMM , mlx::core::Arange , mlx::core::ArcCos , mlx::core::ArcCosh , mlx::core::ArcSin , mlx::core::ArcSinh , mlx::core::ArcTan2 , mlx::core::ArcTan , mlx::core::ArcTanh , mlx::core::ArgPartition , mlx::core::ArgReduce , mlx::core::ArgSort , mlx::core::AsStrided , mlx::core::AsType , mlx::core::BitwiseBinary , mlx::core::BlockMaskedMM , mlx::core::Broadcast , mlx::core::Ceil , mlx::core::Cholesky , mlx::core::Concatenate , mlx::core::Conjugate , mlx::core::Convolution , mlx::core::Copy , mlx::core::Cos , mlx::core::Cosh , mlx::core::Divide , mlx::core::Equal , mlx::core::Erf , mlx::core::ErfInv , mlx::core::Exp , mlx::core::Expm1 , mlx::core::FFT , mlx::core::Floor , mlx::core::Full , mlx::core::Gather , mlx::core::GatherMM , mlx::core::GatherQMM , mlx::core::Greater , mlx::core::GreaterEqual , mlx::core::Hadamard , mlx::core::Imag , mlx::core::Inverse , mlx::core::Less , mlx::core::LessEqual , mlx::core::Load , mlx::core::Log1p , mlx::core::Log , mlx::core::LogAddExp , mlx::core::LogicalAnd , mlx::core::LogicalNot , mlx::core::LogicalOr , mlx::core::Matmul , mlx::core::Maximum , mlx::core::Minimum , mlx::core::Multiply , mlx::core::Negative , mlx::core::NotEqual , mlx::core::NumberOfElements , mlx::core::Pad , mlx::core::Partition , mlx::core::Power , mlx::core::QuantizedMatmul , mlx::core::RandomBits , mlx::core::Real , mlx::core::Reduce , mlx::core::Remainder , mlx::core::Reshape , mlx::core::Round , mlx::core::Scan , mlx::core::Scatter , mlx::core::Select , mlx::core::Sigmoid , mlx::core::Sign , mlx::core::Sin , mlx::core::Sinh , mlx::core::Slice , mlx::core::SliceUpdate , mlx::core::Softmax , mlx::core::Sort , mlx::core::Sqrt , mlx::core::Square , mlx::core::StopGradient , mlx::core::Subtract , mlx::core::Tan , mlx::core::Tanh , mlx::core::Transpose , mlx::core::Uniform , and mlx::core::View .
+
Implemented in mlx::core::Abs , mlx::core::Add , mlx::core::AddMM , mlx::core::Arange , mlx::core::ArcCos , mlx::core::ArcCosh , mlx::core::ArcSin , mlx::core::ArcSinh , mlx::core::ArcTan2 , mlx::core::ArcTan , mlx::core::ArcTanh , mlx::core::ArgPartition , mlx::core::ArgReduce , mlx::core::ArgSort , mlx::core::AsStrided , mlx::core::AsType , mlx::core::BitwiseBinary , mlx::core::BlockMaskedMM , mlx::core::Broadcast , mlx::core::Ceil , mlx::core::Cholesky , mlx::core::Concatenate , mlx::core::Conjugate , mlx::core::Contiguous , mlx::core::Convolution , mlx::core::Copy , mlx::core::Cos , mlx::core::Cosh , mlx::core::Divide , mlx::core::Equal , mlx::core::Erf , mlx::core::ErfInv , mlx::core::Exp , mlx::core::Expm1 , mlx::core::FFT , mlx::core::Floor , mlx::core::Full , mlx::core::Gather , mlx::core::GatherMM , mlx::core::GatherQMM , mlx::core::Greater , mlx::core::GreaterEqual , mlx::core::Hadamard , mlx::core::Imag , mlx::core::Inverse , mlx::core::Less , mlx::core::LessEqual , mlx::core::Load , mlx::core::Log1p , mlx::core::Log , mlx::core::LogAddExp , mlx::core::LogicalAnd , mlx::core::LogicalNot , mlx::core::LogicalOr , mlx::core::Matmul , mlx::core::Maximum , mlx::core::Minimum , mlx::core::Multiply , mlx::core::Negative , mlx::core::NotEqual , mlx::core::NumberOfElements , mlx::core::Pad , mlx::core::Partition , mlx::core::Power , mlx::core::QuantizedMatmul , mlx::core::RandomBits , mlx::core::Real , mlx::core::Reduce , mlx::core::Remainder , mlx::core::Reshape , mlx::core::Round , mlx::core::Scan , mlx::core::Scatter , mlx::core::Select , mlx::core::Sigmoid , mlx::core::Sign , mlx::core::Sin , mlx::core::Sinh , mlx::core::Slice , mlx::core::SliceUpdate , mlx::core::Softmax , mlx::core::Sort , mlx::core::Sqrt , mlx::core::Square , mlx::core::StopGradient , mlx::core::Subtract , mlx::core::Tan , mlx::core::Tanh , mlx::core::Transpose , mlx::core::Uniform , and mlx::core::View .
@@ -454,7 +455,7 @@ Public Member Functions