diff --git a/docs/build/html/.buildinfo b/docs/build/html/.buildinfo index 5a4af442a..4cd5f8315 100644 --- a/docs/build/html/.buildinfo +++ b/docs/build/html/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: 0c08faf7a4a5981ee1e4c3cab57ef3b9 +config: 6d31d3d7850f7f8959377483b35af018 tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/docs/build/html/_sources/cpp/ops.rst b/docs/build/html/_sources/cpp/ops.rst index 4d2d1404e..009a10b1e 100644 --- a/docs/build/html/_sources/cpp/ops.rst +++ b/docs/build/html/_sources/cpp/ops.rst @@ -3,4 +3,5 @@ Operations ========== - +.. doxygengroup:: ops + :content-only: diff --git a/docs/build/html/_sources/install.rst b/docs/build/html/_sources/install.rst index f34db7270..252b234e6 100644 --- a/docs/build/html/_sources/install.rst +++ b/docs/build/html/_sources/install.rst @@ -157,7 +157,10 @@ should point to the path to the built metal library. - OFF * - MLX_METAL_DEBUG - OFF - + * - MLX_BUILD_SAFETENSORS + - ON + * - MLX_BUILD_GGUF + - ON .. note:: diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.arctan2.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.arctan2.rst new file mode 100644 index 000000000..2039f68f0 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.arctan2.rst @@ -0,0 +1,6 @@ +mlx.core.arctan2 +================ + +.. currentmodule:: mlx.core + +.. autofunction:: arctan2 \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst index 9ea269f2a..e845b3cf8 100644 --- a/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst @@ -20,6 +20,7 @@ ~array.argmax ~array.argmin ~array.astype + ~array.conj ~array.cos ~array.cummax ~array.cummin diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_and.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_and.rst new file mode 100644 index 000000000..6b8497e5c --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_and.rst @@ -0,0 +1,6 @@ +mlx.core.bitwise\_and +===================== + +.. currentmodule:: mlx.core + +.. autofunction:: bitwise_and \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_or.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_or.rst new file mode 100644 index 000000000..15eb14604 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_or.rst @@ -0,0 +1,6 @@ +mlx.core.bitwise\_or +==================== + +.. currentmodule:: mlx.core + +.. autofunction:: bitwise_or \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_xor.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_xor.rst new file mode 100644 index 000000000..ae41e5f49 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_xor.rst @@ -0,0 +1,6 @@ +mlx.core.bitwise\_xor +===================== + +.. currentmodule:: mlx.core + +.. autofunction:: bitwise_xor \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.block_sparse_mm.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.block_sparse_mm.rst new file mode 100644 index 000000000..72a9dd120 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.block_sparse_mm.rst @@ -0,0 +1,6 @@ +mlx.core.block\_sparse\_mm +========================== + +.. currentmodule:: mlx.core + +.. autofunction:: block_sparse_mm \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.conj.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.conj.rst new file mode 100644 index 000000000..f1dd8954d --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.conj.rst @@ -0,0 +1,6 @@ +mlx.core.conj +============= + +.. currentmodule:: mlx.core + +.. autofunction:: conj \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.conjugate.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.conjugate.rst new file mode 100644 index 000000000..3d3e20560 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.conjugate.rst @@ -0,0 +1,6 @@ +mlx.core.conjugate +================== + +.. currentmodule:: mlx.core + +.. autofunction:: conjugate \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.left_shift.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.left_shift.rst new file mode 100644 index 000000000..a99502501 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.left_shift.rst @@ -0,0 +1,6 @@ +mlx.core.left\_shift +==================== + +.. currentmodule:: mlx.core + +.. autofunction:: left_shift \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.metal.device_info.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.metal.device_info.rst new file mode 100644 index 000000000..1c914a29a --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.metal.device_info.rst @@ -0,0 +1,6 @@ +mlx.core.metal.device\_info +=========================== + +.. currentmodule:: mlx.core.metal + +.. autofunction:: device_info \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.metal.reset_peak_memory.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.metal.reset_peak_memory.rst new file mode 100644 index 000000000..4bdbd144a --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.metal.reset_peak_memory.rst @@ -0,0 +1,6 @@ +mlx.core.metal.reset\_peak\_memory +================================== + +.. currentmodule:: mlx.core.metal + +.. autofunction:: reset_peak_memory \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.right_shift.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.right_shift.rst new file mode 100644 index 000000000..471b61b95 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.right_shift.rst @@ -0,0 +1,6 @@ +mlx.core.right\_shift +===================== + +.. currentmodule:: mlx.core + +.. autofunction:: right_shift \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.clip_grad_norm.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.clip_grad_norm.rst new file mode 100644 index 000000000..ccd4924c5 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.clip_grad_norm.rst @@ -0,0 +1,6 @@ +mlx.optimizers.clip\_grad\_norm +=============================== + +.. currentmodule:: mlx.optimizers + +.. autofunction:: clip_grad_norm \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.utils.tree_reduce.rst b/docs/build/html/_sources/python/_autosummary/mlx.utils.tree_reduce.rst new file mode 100644 index 000000000..0bba35704 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.utils.tree_reduce.rst @@ -0,0 +1,6 @@ +mlx.utils.tree\_reduce +====================== + +.. currentmodule:: mlx.utils + +.. autofunction:: tree_reduce \ No newline at end of file diff --git a/docs/build/html/_sources/python/metal.rst b/docs/build/html/_sources/python/metal.rst index 589ec0a82..cb2cdb38e 100644 --- a/docs/build/html/_sources/python/metal.rst +++ b/docs/build/html/_sources/python/metal.rst @@ -7,8 +7,10 @@ Metal :toctree: _autosummary is_available + device_info get_active_memory get_peak_memory + reset_peak_memory get_cache_memory set_memory_limit set_cache_limit diff --git a/docs/build/html/_sources/python/ops.rst b/docs/build/html/_sources/python/ops.rst index 7795512a0..177332c49 100644 --- a/docs/build/html/_sources/python/ops.rst +++ b/docs/build/html/_sources/python/ops.rst @@ -19,6 +19,7 @@ Operations arcsin arcsinh arctan + arctan2 arctanh argmax argmin @@ -28,11 +29,17 @@ Operations atleast_1d atleast_2d atleast_3d - broadcast_to + bitwise_and + bitwise_or + bitwise_xor block_masked_mm + block_sparse_mm + broadcast_to ceil clip concatenate + conj + conjugate convolve conv1d conv2d @@ -69,6 +76,7 @@ Operations isnan isneginf isposinf + left_shift less less_equal linspace @@ -105,6 +113,7 @@ Operations reciprocal repeat reshape + right_shift round rsqrt save diff --git a/docs/build/html/_sources/python/optimizers.rst b/docs/build/html/_sources/python/optimizers.rst index f437ddc15..84ab933ac 100644 --- a/docs/build/html/_sources/python/optimizers.rst +++ b/docs/build/html/_sources/python/optimizers.rst @@ -1,5 +1,7 @@ .. _optimizers: +.. currentmodule:: mlx.optimizers + Optimizers ========== @@ -34,3 +36,8 @@ model's parameters and the **optimizer state**. optimizers/optimizer optimizers/common_optimizers optimizers/schedulers + +.. autosummary:: + :toctree: _autosummary + + clip_grad_norm diff --git a/docs/build/html/_sources/python/tree_utils.rst b/docs/build/html/_sources/python/tree_utils.rst index dbd0ebce9..6dc60b47d 100644 --- a/docs/build/html/_sources/python/tree_utils.rst +++ b/docs/build/html/_sources/python/tree_utils.rst @@ -20,3 +20,4 @@ return python trees will be using the default python ``dict``, ``list`` and tree_unflatten tree_map tree_map_with_path + tree_reduce diff --git a/docs/build/html/_static/documentation_options.js b/docs/build/html/_static/documentation_options.js index b217302fd..607aaea4c 100644 --- a/docs/build/html/_static/documentation_options.js +++ b/docs/build/html/_static/documentation_options.js @@ -1,5 +1,5 @@ const DOCUMENTATION_OPTIONS = { - VERSION: '0.12.0', + VERSION: '0.13.0', LANGUAGE: 'en', COLLAPSE_INDEX: false, BUILDER: 'html', diff --git a/docs/build/html/allocator_8h.html b/docs/build/html/allocator_8h.html new file mode 100644 index 000000000..8e3b1e763 --- /dev/null +++ b/docs/build/html/allocator_8h.html @@ -0,0 +1,124 @@ + + + + + + + +MLX: mlx/allocator.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Classes | +Namespaces | +Functions
+
allocator.h File Reference
+
+
+
#include <cstdlib>
+
+

Go to the source code of this file.

+ + + + + + + + +

+Classes

class  mlx::core::allocator::Buffer
 
class  mlx::core::allocator::Allocator
 
class  mlx::core::allocator::CommonAllocator
 
+ + + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
namespace  mlx::core::allocator
 
+ + + + + + + + + +

+Functions

Buffer mlx::core::allocator::malloc (size_t size)
 
void mlx::core::allocator::free (Buffer buffer)
 
Buffer mlx::core::allocator::malloc_or_wait (size_t size)
 
Allocatormlx::core::allocator::allocator ()
 
+
+ + + + diff --git a/docs/build/html/allocator_8h_source.html b/docs/build/html/allocator_8h_source.html new file mode 100644 index 000000000..917df829b --- /dev/null +++ b/docs/build/html/allocator_8h_source.html @@ -0,0 +1,191 @@ + + + + + + + +MLX: mlx/allocator.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
allocator.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <cstdlib>
+
6
+
+ +
8
+
9// Simple wrapper around buffer pointers
+
10// WARNING: Only Buffer objects constructed from and those that wrap
+
11// raw pointers from mlx::allocator are supported.
+
+
12class Buffer {
+
13 private:
+
14 void* ptr_;
+
15
+
16 public:
+
17 Buffer(void* ptr) : ptr_(ptr) {};
+
18
+
19 // Get the raw data pointer from the buffer
+
20 void* raw_ptr();
+
21
+
22 // Get the buffer pointer from the buffer
+
+
23 const void* ptr() const {
+
24 return ptr_;
+
25 };
+
+
+
26 void* ptr() {
+
27 return ptr_;
+
28 };
+
+
29};
+
+
30
+
31Buffer malloc(size_t size);
+
32
+
33void free(Buffer buffer);
+
34
+
35// Wait for running tasks to finish and free up memory
+
36// if allocation fails
+ +
38
+
+
39class Allocator {
+
41 public:
+
42 virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
+
43 virtual void free(Buffer buffer) = 0;
+
44
+
45 Allocator() = default;
+
46 Allocator(const Allocator& other) = delete;
+
47 Allocator(Allocator&& other) = delete;
+
48 Allocator& operator=(const Allocator& other) = delete;
+
49 Allocator& operator=(Allocator&& other) = delete;
+
50 virtual ~Allocator() = default;
+
51};
+
+
52
+ +
54
+
+
55class CommonAllocator : public Allocator {
+
57 public:
+
58 virtual Buffer malloc(size_t size, bool allow_swap = false) override;
+
59 virtual void free(Buffer buffer) override;
+
60
+
61 private:
+
62 CommonAllocator() = default;
+ +
64};
+
+
65
+
66} // namespace mlx::core::allocator
+
+
Definition allocator.h:39
+
Allocator & operator=(const Allocator &other)=delete
+
Allocator & operator=(Allocator &&other)=delete
+ +
Allocator(Allocator &&other)=delete
+ +
virtual Buffer malloc(size_t size, bool allow_swap=false)=0
Abstract base class for a memory allocator.
+
Allocator(const Allocator &other)=delete
+
virtual void free(Buffer buffer)=0
+
Definition allocator.h:12
+ +
const void * ptr() const
Definition allocator.h:23
+
Buffer(void *ptr)
Definition allocator.h:17
+
void * ptr()
Definition allocator.h:26
+
Definition allocator.h:55
+
virtual Buffer malloc(size_t size, bool allow_swap=false) override
A general CPU allocator.
+
virtual void free(Buffer buffer) override
+ +
Definition allocator.h:7
+
Buffer malloc(size_t size)
+
void free(Buffer buffer)
+
Buffer malloc_or_wait(size_t size)
+
Allocator & allocator()
+
+ + + + diff --git a/docs/build/html/annotated.html b/docs/build/html/annotated.html new file mode 100644 index 000000000..220efdb5b --- /dev/null +++ b/docs/build/html/annotated.html @@ -0,0 +1,445 @@ + + + + + + + +MLX: Class List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + +
+ +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ +
+
Class List
+
+
+
Here are the classes, structs, unions and interfaces with brief descriptions:
+
[detail level 12345]
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
 Nmetal
 Nmlx
 NMPS
 Npocketfft
 C_MLX_BFloat16
 CAbs
 CAdd
 CAnd
 CArcCos
 CArcCosh
 CArcSin
 CArcSinh
 CArcTan
 CArcTan2
 CArcTanh
 CBitwiseAnd
 CBitwiseOr
 CBitwiseXor
 Cbool4_or_uint
 CCeil
 Ccomplex64_t
 CConjugate
 CCos
 CCosh
 CDivide
 CEqual
 CErf
 CErfInv
 CExp
 CExpm1
 CFloor
 CGreater
 CGreaterEqual
 CIndices
 CLeftShift
 CLess
 CLessEqual
 CLimits
 CLimits< bfloat16_t >
 CLimits< bool >
 CLimits< float >
 CLimits< half >
 CLimits< int16_t >
 CLimits< int32_t >
 CLimits< int64_t >
 CLimits< int8_t >
 CLimits< uint16_t >
 CLimits< uint32_t >
 CLimits< uint64_t >
 CLimits< uint8_t >
 CLog
 CLog10
 CLog1p
 CLog2
 CLogAddExp
 CLogicalAnd
 CLogicalNot
 CLogicalOr
 CMax
 CMaximum
 CMin
 CMinimum
 Cmlx_atomic
 Cmlx_atomic< T, enable_if_t< is_metal_atomic< T > > >
 CMLXConvParams
 CMLXScaledDotProductAttentionParams
 CMultiply
 CNaNEqual
 CNegative
 CNone
 CNotEqual
 COr
 CPower
 CProd
 CRemainder
 CRightShift
 CRound
 CRsqrt
 CSelect
 CSigmoid
 CSign
 CSin
 CSinh
 CSqrt
 CSquare
 CSubtract
 CSum
 CTan
 CTanh
+
+
+ + + + diff --git a/docs/build/html/arange_8h.html b/docs/build/html/arange_8h.html new file mode 100644 index 000000000..804fdf88e --- /dev/null +++ b/docs/build/html/arange_8h.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: mlx/backend/common/arange.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Namespaces | +Functions
+
arange.h File Reference
+
+
+
#include "mlx/allocator.h"
+#include "mlx/array.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+ + + +

+Functions

void mlx::core::arange (const std::vector< array > &inputs, array &out, double start, double step)
 
+
+ + + + diff --git a/docs/build/html/arange_8h_source.html b/docs/build/html/arange_8h_source.html new file mode 100644 index 000000000..2f042106a --- /dev/null +++ b/docs/build/html/arange_8h_source.html @@ -0,0 +1,192 @@ + + + + + + + +MLX: mlx/backend/common/arange.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
arange.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include "mlx/allocator.h"
+
6#include "mlx/array.h"
+
7
+
8namespace mlx::core {
+
9
+
10namespace {
+
11
+
12template <typename T>
+
13void arange(T start, T next, array& out, size_t size) {
+
14 auto ptr = out.data<T>();
+
15 auto step_size = next - start;
+
16 for (int i = 0; i < size; ++i) {
+
17 ptr[i] = start;
+
18 start += step_size;
+
19 }
+
20}
+
21
+
22} // namespace
+
23
+
+
24void arange(
+
25 const std::vector<array>& inputs,
+
26 array& out,
+
27 double start,
+
28 double step) {
+
29 assert(inputs.size() == 0);
+ +
31 switch (out.dtype()) {
+
32 case bool_:
+
33 throw std::runtime_error("Bool type unsupported for arange.");
+
34 break;
+
35 case uint8:
+
36 arange<uint8_t>(start, start + step, out, out.size());
+
37 break;
+
38 case uint16:
+
39 arange<uint16_t>(start, start + step, out, out.size());
+
40 break;
+
41 case uint32:
+
42 arange<uint32_t>(start, start + step, out, out.size());
+
43 break;
+
44 case uint64:
+
45 arange<uint64_t>(start, start + step, out, out.size());
+
46 break;
+
47 case int8:
+
48 arange<int8_t>(start, start + step, out, out.size());
+
49 break;
+
50 case int16:
+
51 arange<int16_t>(start, start + step, out, out.size());
+
52 break;
+
53 case int32:
+
54 arange<int32_t>(start, start + step, out, out.size());
+
55 break;
+
56 case int64:
+
57 arange<int64_t>(start, start + step, out, out.size());
+
58 break;
+
59 case float16:
+
60 arange<float16_t>(start, start + step, out, out.size());
+
61 break;
+
62 case float32:
+
63 arange<float>(start, start + step, out, out.size());
+
64 break;
+
65 case bfloat16:
+
66 arange<bfloat16_t>(start, start + step, out, out.size());
+
67 break;
+
68 case complex64:
+
69 arange<complex64_t>(start, start + step, out, out.size());
+
70 break;
+
71 }
+
72}
+
+
73
+
74} // namespace mlx::core
+ + +
BufferHolder * next
Definition allocator.h:37
+
Definition array.h:20
+
size_t nbytes() const
The number of bytes in the array.
Definition array.h:89
+
size_t size() const
The number of elements in the array.
Definition array.h:84
+
void set_data(allocator::Buffer buffer, deleter_t d=allocator::free)
+
Dtype dtype() const
Get the arrays data type.
Definition array.h:127
+
Buffer malloc_or_wait(size_t size)
+
Definition allocator.h:7
+
constexpr Dtype bool_
Definition dtype.h:60
+
constexpr Dtype uint64
Definition dtype.h:65
+
constexpr Dtype uint16
Definition dtype.h:63
+
void arange(const std::vector< array > &inputs, array &out, double start, double step)
Definition arange.h:24
+
constexpr Dtype bfloat16
Definition dtype.h:74
+
constexpr Dtype int32
Definition dtype.h:69
+
constexpr Dtype float32
Definition dtype.h:73
+
constexpr Dtype int16
Definition dtype.h:68
+
constexpr Dtype int8
Definition dtype.h:67
+
constexpr Dtype int64
Definition dtype.h:70
+
constexpr Dtype uint8
Definition dtype.h:62
+
constexpr Dtype float16
Definition dtype.h:72
+
constexpr Dtype uint32
Definition dtype.h:64
+
constexpr Dtype complex64
Definition dtype.h:75
+
+ + + + diff --git a/docs/build/html/array_8h.html b/docs/build/html/array_8h.html new file mode 100644 index 000000000..bdfcf22d8 --- /dev/null +++ b/docs/build/html/array_8h.html @@ -0,0 +1,138 @@ + + + + + + + +MLX: mlx/array.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Classes | +Namespaces | +Typedefs | +Variables
+
array.h File Reference
+
+
+
#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <vector>
+#include "mlx/allocator.h"
+#include "mlx/dtype.h"
+#include "mlx/event.h"
+
+

Go to the source code of this file.

+ + + + + + + + + + +

+Classes

class  mlx::core::array
 
struct  mlx::core::array::ArrayIterator
 
struct  mlx::core::array::Data
 
struct  mlx::core::array::Flags
 
+ + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+ + + + + + +

+Typedefs

using mlx::core::deleter_t = std::function<void(allocator::Buffer)>
 
template<typename... T>
using mlx::core::enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>
 
+ + + + + + + +

+Variables

template<typename T >
constexpr bool mlx::core::is_array_v
 
template<typename... T>
constexpr bool mlx::core::is_arrays_v = (is_array_v<T> && ...)
 
+
+ + + + diff --git a/docs/build/html/array_8h_source.html b/docs/build/html/array_8h_source.html new file mode 100644 index 000000000..4b15bc560 --- /dev/null +++ b/docs/build/html/array_8h_source.html @@ -0,0 +1,842 @@ + + + + + + + +MLX: mlx/array.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
array.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2#pragma once
+
3
+
4#include <algorithm>
+
5#include <cstdint>
+
6#include <functional>
+
7#include <memory>
+
8#include <vector>
+
9
+
10#include "mlx/allocator.h"
+
11#include "mlx/dtype.h"
+
12#include "mlx/event.h"
+
13
+
14namespace mlx::core {
+
15
+
16// Forward declaration
+
17class Primitive;
+
18using deleter_t = std::function<void(allocator::Buffer)>;
+
19
+
+
20class array {
+
21 /* An array is really a node in a graph. It contains a shared ArrayDesc
+
22 * object */
+
23
+
24 public:
+
26 template <typename T>
+
27 explicit array(T val, Dtype dtype = TypeToDtype<T>());
+
28
+
29 /* Special case since std::complex can't be implicitly converted to other
+
30 * types. */
+
31 explicit array(const std::complex<float>& val, Dtype dtype = complex64);
+
32
+
33 template <typename It>
+
34 array(
+
35 It data,
+
36 std::vector<int> shape,
+
37 Dtype dtype =
+
38 TypeToDtype<typename std::iterator_traits<It>::value_type>());
+
39
+
40 template <typename T>
+
41 array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
+
42
+
43 /* Special case so empty lists default to float32. */
+
44 array(std::initializer_list<float> data);
+
45
+
46 /* Special case so array({}, type) is an empty array. */
+
47 array(std::initializer_list<int> data, Dtype dtype);
+
48
+
49 template <typename T>
+
50 array(
+
51 std::initializer_list<T> data,
+
52 std::vector<int> shape,
+ +
54
+
55 /* Build an array from a buffer */
+ + +
58 std::vector<int> shape,
+ +
60 deleter_t deleter = allocator::free);
+
61
+
63 array& operator=(const array& other) && = delete;
+
64 array& operator=(array&& other) && = delete;
+
65
+
67 array& operator=(array&& other) & = default;
+
68 array(const array& other) = default;
+
69 array(array&& other) = default;
+
70
+
+
71 array& operator=(const array& other) & {
+
72 if (this->id() != other.id()) {
+
73 this->array_desc_ = other.array_desc_;
+
74 }
+
75 return *this;
+
76 };
+
+
77
+
+
79 size_t itemsize() const {
+
80 return size_of(dtype());
+
81 };
+
+
82
+
+
84 size_t size() const {
+
85 return array_desc_->size;
+
86 };
+
+
87
+
+
89 size_t nbytes() const {
+
90 return size() * itemsize();
+
91 };
+
+
92
+
+
94 size_t ndim() const {
+
95 return array_desc_->shape.size();
+
96 };
+
+
97
+
+
99 const std::vector<int>& shape() const {
+
100 return array_desc_->shape;
+
101 };
+
+
102
+
+
108 int shape(int dim) const {
+
109 return shape().at(dim < 0 ? dim + ndim() : dim);
+
110 };
+
+
111
+
+
113 const std::vector<size_t>& strides() const {
+
114 return array_desc_->strides;
+
115 };
+
+
116
+
+
122 size_t strides(int dim) const {
+
123 return strides().at(dim < 0 ? dim + ndim() : dim);
+
124 };
+
+
125
+
+
127 Dtype dtype() const {
+
128 return array_desc_->dtype;
+
129 };
+
+
130
+
132 void eval();
+
133
+
135 template <typename T>
+
136 T item();
+
137
+
138 template <typename T>
+
139 T item() const;
+
140
+
+ +
142 using iterator_category = std::random_access_iterator_tag;
+
143 using difference_type = size_t;
+
144 using value_type = const array;
+ +
146
+
147 explicit ArrayIterator(const array& arr, int idx = 0);
+
148
+ +
150
+
+ +
152 idx += diff;
+
153 return *this;
+
154 }
+
+
155
+
+ +
157 idx++;
+
158 return *this;
+
159 }
+
+
160
+
+
161 friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
+
162 return a.arr.id() == b.arr.id() && a.idx == b.idx;
+
163 };
+
+
+
164 friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
+
165 return !(a == b);
+
166 };
+
+
167
+
168 private:
+
169 const array& arr;
+
170 int idx;
+
171 };
+
+
172
+
+ +
174 return ArrayIterator(*this);
+
175 }
+
+
+ +
177 return ArrayIterator(*this, shape(0));
+
178 }
+
+
179
+ +
187 std::vector<int> shape,
+
188 Dtype dtype,
+
189 std::shared_ptr<Primitive> primitive,
+
190 std::vector<array> inputs);
+
191
+
192 static std::vector<array> make_arrays(
+
193 std::vector<std::vector<int>> shapes,
+
194 const std::vector<Dtype>& dtypes,
+
195 const std::shared_ptr<Primitive>& primitive,
+
196 const std::vector<array>& inputs);
+
197
+
+
199 std::uintptr_t id() const {
+
200 return reinterpret_cast<std::uintptr_t>(array_desc_.get());
+
201 }
+
+
202
+
+
204 std::uintptr_t primitive_id() const {
+
205 return reinterpret_cast<std::uintptr_t>(array_desc_->primitive.get());
+
206 }
+
+
207
+
+
208 struct Data {
+ + + +
213 // Not copyable
+
214 Data(const Data& d) = delete;
+
215 Data& operator=(const Data& d) = delete;
+
+ +
217 d(buffer);
+
218 }
+
+
219 };
+
+
220
+
+
221 struct Flags {
+
222 // True if there are no gaps in the underlying data. Each item
+
223 // in the underlying data buffer belongs to at least one index.
+
224 bool contiguous : 1;
+
225
+ + +
228 };
+
+
229
+
+ +
232 return *(array_desc_->primitive);
+
233 };
+
+
234
+
+
236 std::shared_ptr<Primitive>& primitive_ptr() const {
+
237 return array_desc_->primitive;
+
238 };
+
+
239
+
+
241 bool has_primitive() const {
+
242 return array_desc_->primitive != nullptr;
+
243 };
+
+
244
+
+
246 const std::vector<array>& inputs() const {
+
247 return array_desc_->inputs;
+
248 };
+
+
249
+
+
250 std::vector<array>& inputs() {
+
251 return array_desc_->inputs;
+
252 }
+
+
253
+
+
255 bool is_donatable() const {
+
256 return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
+
257 }
+
+
258
+
+
260 const std::vector<array>& siblings() const {
+
261 return array_desc_->siblings;
+
262 };
+
+
263
+
+
265 std::vector<array>& siblings() {
+
266 return array_desc_->siblings;
+
267 };
+
+
268
+
+
269 void set_siblings(std::vector<array> siblings, uint16_t position) {
+
270 array_desc_->siblings = std::move(siblings);
+
271 array_desc_->position = position;
+
272 }
+
+
273
+
+
276 std::vector<array> outputs() const {
+
277 auto idx = array_desc_->position;
+
278 std::vector<array> outputs;
+
279 outputs.reserve(siblings().size() + 1);
+
280 outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx);
+
281 outputs.push_back(*this);
+
282 outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
+
283 return outputs;
+
284 };
+
+
285
+
287 void detach();
+
288
+
+
290 const Flags& flags() const {
+
291 return array_desc_->flags;
+
292 };
+
+
293
+
+
295 size_t data_size() const {
+
296 return array_desc_->data_size;
+
297 };
+
+
298
+
+ +
300 return array_desc_->data->buffer;
+
301 };
+
+
+
302 const allocator::Buffer& buffer() const {
+
303 return array_desc_->data->buffer;
+
304 };
+
+
305
+
306 // Return a copy of the shared pointer
+
307 // to the array::Data struct
+
+
308 std::shared_ptr<Data> data_shared_ptr() const {
+
309 return array_desc_->data;
+
310 }
+
+
311 // Return a raw pointer to the arrays data
+
312 template <typename T>
+
+
313 T* data() {
+
314 return static_cast<T*>(array_desc_->data_ptr);
+
315 };
+
+
316
+
317 template <typename T>
+
+
318 const T* data() const {
+
319 return static_cast<T*>(array_desc_->data_ptr);
+
320 };
+
+
321
+ +
323
+
+
324 bool is_available() const {
+
325 return status() == Status::available;
+
326 }
+
+
+
327 const Status status() const {
+
328 return array_desc_->status;
+
329 }
+
+
330
+
+
331 void set_status(Status s) const {
+
332 array_desc_->status = s;
+
333 }
+
+
334
+
335 // Get the array's shared event
+
+
336 Event& event() const {
+
337 return array_desc_->event;
+
338 }
+
+
339
+
340 // Attach an event to a not yet evaluated array
+
+
341 void attach_event(Event e) const {
+
342 array_desc_->event = std::move(e);
+
343 }
+
+
344
+
345 // Mark the array as a tracer array (true) or not.
+
+ +
347 array_desc_->is_tracer = is_tracer;
+
348 }
+
+
349 // Check if the array is a tracer array
+
350 bool is_tracer() const;
+
351
+ +
353
+ + +
356 size_t data_size,
+
357 std::vector<size_t> strides,
+
358 Flags flags,
+ +
360
+ +
362 const array& other,
+
363 const std::vector<size_t>& strides,
+
364 Flags flags,
+
365 size_t data_size,
+
366 size_t offset = 0);
+
367
+
368 void copy_shared_buffer(const array& other);
+
369
+ +
371 array other,
+
372 const std::vector<size_t>& strides,
+
373 Flags flags,
+
374 size_t data_size,
+
375 size_t offset = 0);
+
376
+ +
378
+
+
379 void overwrite_descriptor(const array& other) {
+
380 array_desc_ = other.array_desc_;
+
381 }
+
+
382
+ +
384
+
385 private:
+
386 // Initialize the arrays data
+
387 template <typename It>
+
388 void init(const It src);
+
389
+
390 struct ArrayDesc {
+
391 std::vector<int> shape;
+
392 std::vector<size_t> strides;
+
393 size_t size;
+
394 Dtype dtype;
+
395 std::shared_ptr<Primitive> primitive;
+
396
+
397 Status status;
+
398
+
399 // An event on the array used for synchronization
+
400 Event event;
+
401
+
402 // Indicates an array is being used in a graph transform
+
403 // and should not be detached from the graph
+
404 bool is_tracer{false};
+
405
+
406 // This is a shared pointer so that *different* arrays
+
407 // can share the underlying data buffer.
+
408 std::shared_ptr<Data> data;
+
409
+
410 // Properly offset data pointer
+
411 void* data_ptr{nullptr};
+
412
+
413 // The size in elements of the data buffer the array accesses
+
414 // This can be different than the actual size of the array if it
+
415 // has been broadcast or irregularly strided.
+
416 size_t data_size;
+
417
+
418 // Contains useful meta data about the array
+
419 Flags flags;
+
420
+
421 std::vector<array> inputs;
+
422 // An array to keep track of the siblings from a multi-output
+
423 // primitive.
+
424 std::vector<array> siblings;
+
425 // The arrays position in the output list
+
426 uint32_t position{0};
+
427
+
428 explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
+
429
+
430 explicit ArrayDesc(
+
431 std::vector<int> shape,
+
432 Dtype dtype,
+
433 std::shared_ptr<Primitive> primitive,
+
434 std::vector<array> inputs);
+
435
+
436 ~ArrayDesc();
+
437
+
438 private:
+
439 // Initialize size, strides, and other metadata
+
440 void init();
+
441 };
+
442
+
443 // The ArrayDesc contains the details of the materialized array including the
+
444 // shape, strides, the data type. It also includes
+
445 // the primitive which knows how to compute the array's data from its inputs
+
446 // and the list of array's inputs for the primitive.
+
447 std::shared_ptr<ArrayDesc> array_desc_;
+
448};
+
+
449
+
450template <typename T>
+
+
451array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
+
452 : array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
+
453 init(&val);
+
454}
+
+
455
+
456template <typename It>
+
+ +
458 It data,
+
459 std::vector<int> shape,
+
460 Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
+
461 array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
+
462 init(data);
+
463}
+
+
464
+
465template <typename T>
+
+ +
467 std::initializer_list<T> data,
+
468 Dtype dtype /* = TypeToDtype<T>() */)
+
469 : array_desc_(std::make_shared<ArrayDesc>(
+
470 std::vector<int>{static_cast<int>(data.size())},
+
471 dtype)) {
+
472 init(data.begin());
+
473}
+
+
474
+
475template <typename T>
+
+ +
477 std::initializer_list<T> data,
+
478 std::vector<int> shape,
+
479 Dtype dtype /* = TypeToDtype<T>() */)
+
480 : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
+
481 if (data.size() != size()) {
+
482 throw std::invalid_argument(
+
483 "Data size and provided shape mismatch in array construction.");
+
484 }
+
485 init(data.begin());
+
486}
+
+
487
+
488template <typename T>
+
+ +
490 if (size() != 1) {
+
491 throw std::invalid_argument("item can only be called on arrays of size 1.");
+
492 }
+
493 eval();
+
494 return *data<T>();
+
495}
+
+
496
+
497template <typename T>
+
+
498T array::item() const {
+
499 if (size() != 1) {
+
500 throw std::invalid_argument("item can only be called on arrays of size 1.");
+
501 }
+
502 if (status() == Status::unscheduled) {
+
503 throw std::invalid_argument(
+
504 "item() const can only be called on evaled arrays");
+
505 }
+
506 const_cast<array*>(this)->eval();
+
507 return *data<T>();
+
508}
+
+
509
+
510template <typename It>
+
511void array::init(It src) {
+ +
513 switch (dtype()) {
+
514 case bool_:
+
515 std::copy(src, src + size(), data<bool>());
+
516 break;
+
517 case uint8:
+
518 std::copy(src, src + size(), data<uint8_t>());
+
519 break;
+
520 case uint16:
+
521 std::copy(src, src + size(), data<uint16_t>());
+
522 break;
+
523 case uint32:
+
524 std::copy(src, src + size(), data<uint32_t>());
+
525 break;
+
526 case uint64:
+
527 std::copy(src, src + size(), data<uint64_t>());
+
528 break;
+
529 case int8:
+
530 std::copy(src, src + size(), data<int8_t>());
+
531 break;
+
532 case int16:
+
533 std::copy(src, src + size(), data<int16_t>());
+
534 break;
+
535 case int32:
+
536 std::copy(src, src + size(), data<int32_t>());
+
537 break;
+
538 case int64:
+
539 std::copy(src, src + size(), data<int64_t>());
+
540 break;
+
541 case float16:
+
542 std::copy(src, src + size(), data<float16_t>());
+
543 break;
+
544 case float32:
+
545 std::copy(src, src + size(), data<float>());
+
546 break;
+
547 case bfloat16:
+
548 std::copy(src, src + size(), data<bfloat16_t>());
+
549 break;
+
550 case complex64:
+
551 std::copy(src, src + size(), data<complex64_t>());
+
552 break;
+
553 }
+
554}
+
555
+
556/* Utilities for determining whether a template parameter is array. */
+
557template <typename T>
+
558inline constexpr bool is_array_v =
+
559 std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;
+
560
+
561template <typename... T>
+
562inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
+
563
+
564template <typename... T>
+
565using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
+
566
+
567} // namespace mlx::core
+ +
Definition event.h:11
+
Definition primitives.h:48
+
Definition allocator.h:12
+
Definition array.h:20
+
void attach_event(Event e) const
Definition array.h:341
+
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:290
+
Event & event() const
Definition array.h:336
+
static std::vector< array > make_arrays(std::vector< std::vector< int > > shapes, const std::vector< Dtype > &dtypes, const std::shared_ptr< Primitive > &primitive, const std::vector< array > &inputs)
+
const std::vector< size_t > & strides() const
The strides of the array.
Definition array.h:113
+
Status
Definition array.h:322
+
@ available
Definition array.h:322
+
@ unscheduled
Definition array.h:322
+
@ scheduled
Definition array.h:322
+
void set_data(allocator::Buffer buffer, size_t data_size, std::vector< size_t > strides, Flags flags, deleter_t d=allocator::free)
+
void eval()
Evaluate the array.
+
void copy_shared_buffer(const array &other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)
+
const std::vector< array > & inputs() const
The array's inputs.
Definition array.h:246
+
array(const array &other)=default
+
std::vector< array > outputs() const
The outputs of the array's primitive (i.e.
Definition array.h:276
+ +
size_t nbytes() const
The number of bytes in the array.
Definition array.h:89
+
void move_shared_buffer(array other)
+
array(std::initializer_list< float > data)
+
bool is_donatable() const
True indicates the arrays buffer is safe to reuse.
Definition array.h:255
+
const std::vector< int > & shape() const
The shape of the array as a vector of integers.
Definition array.h:99
+
std::shared_ptr< Primitive > & primitive_ptr() const
A shared pointer to the array's primitive.
Definition array.h:236
+
int shape(int dim) const
Get the size of the corresponding dimension.
Definition array.h:108
+
size_t ndim() const
The number of dimensions of the array.
Definition array.h:94
+
size_t size() const
The number of elements in the array.
Definition array.h:84
+
array(allocator::Buffer data, std::vector< int > shape, Dtype dtype, deleter_t deleter=allocator::free)
+
array & operator=(array &&other) &&=delete
+
array & operator=(const array &other) &
Definition array.h:71
+
ArrayIterator end() const
Definition array.h:176
+
array(std::initializer_list< int > data, Dtype dtype)
+
void set_data(allocator::Buffer buffer, deleter_t d=allocator::free)
+
const allocator::Buffer & buffer() const
Definition array.h:302
+
void set_status(Status s) const
Definition array.h:331
+
array(const std::complex< float > &val, Dtype dtype=complex64)
+
std::vector< array > & siblings()
The array's siblings.
Definition array.h:265
+
T * data()
Definition array.h:313
+
array(T val, Dtype dtype=TypeToDtype< T >())
Construct a scalar array with zero dimensions.
Definition array.h:451
+
ArrayIterator begin() const
Definition array.h:173
+
Primitive & primitive() const
The array's primitive.
Definition array.h:231
+
void detach()
Detach the array from the graph.
+
array & operator=(const array &other) &&=delete
Assignment to rvalue does not compile.
+
void set_siblings(std::vector< array > siblings, uint16_t position)
Definition array.h:269
+
T item()
Get the value from a scalar array.
Definition array.h:489
+
size_t strides(int dim) const
Get the stride of the corresponding dimension.
Definition array.h:122
+
void copy_shared_buffer(const array &other)
+
void overwrite_descriptor(const array &other)
Definition array.h:379
+
const T * data() const
Definition array.h:318
+
bool has_primitive() const
Check if the array has an attached primitive or is a leaf node.
Definition array.h:241
+
allocator::Buffer & buffer()
Definition array.h:299
+
array(array &&other)=default
+
std::shared_ptr< Data > data_shared_ptr() const
Definition array.h:308
+
void move_shared_buffer(array other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)
+
const std::vector< array > & siblings() const
The array's siblings.
Definition array.h:260
+
std::vector< array > & inputs()
Definition array.h:250
+
array & operator=(array &&other) &=default
Default copy and move constructors otherwise.
+
array(std::vector< int > shape, Dtype dtype, std::shared_ptr< Primitive > primitive, std::vector< array > inputs)
The following methods should be used with caution.
+
const Status status() const
Definition array.h:327
+
std::uintptr_t id() const
A unique identifier for an array.
Definition array.h:199
+
Dtype dtype() const
Get the arrays data type.
Definition array.h:127
+
bool is_available() const
Definition array.h:324
+
void set_tracer(bool is_tracer)
Definition array.h:346
+
size_t itemsize() const
The size of the array's datatype in bytes.
Definition array.h:79
+
std::uintptr_t primitive_id() const
A unique identifier for an arrays primitive.
Definition array.h:204
+
bool is_tracer() const
+
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:295
+ + +
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
+
Buffer malloc(size_t size)
+
void free(Buffer buffer)
+
Definition allocator.h:7
+
constexpr bool is_array_v
Definition array.h:558
+
constexpr Dtype bool_
Definition dtype.h:60
+
std::function< void(allocator::Buffer)> deleter_t
Definition array.h:18
+
constexpr Dtype uint64
Definition dtype.h:65
+
constexpr Dtype uint16
Definition dtype.h:63
+
constexpr Dtype bfloat16
Definition dtype.h:74
+
constexpr Dtype int32
Definition dtype.h:69
+
constexpr Dtype float32
Definition dtype.h:73
+
constexpr Dtype int16
Definition dtype.h:68
+
constexpr Dtype int8
Definition dtype.h:67
+
constexpr Dtype int64
Definition dtype.h:70
+
constexpr bool is_arrays_v
Definition array.h:562
+
constexpr Dtype uint8
Definition dtype.h:62
+
constexpr Dtype float16
Definition dtype.h:72
+
constexpr Dtype uint32
Definition dtype.h:64
+
uint8_t size_of(const Dtype &t)
Definition dtype.h:95
+
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:565
+
constexpr Dtype complex64
Definition dtype.h:75
+
Definition dtype.h:15
+
Definition dtype.h:102
+
Definition array.h:141
+ +
friend bool operator==(const ArrayIterator &a, const ArrayIterator &b)
Definition array.h:161
+
std::random_access_iterator_tag iterator_category
Definition array.h:142
+
ArrayIterator & operator++()
Definition array.h:156
+
friend bool operator!=(const ArrayIterator &a, const ArrayIterator &b)
Definition array.h:164
+
ArrayIterator(const array &arr, int idx=0)
+
size_t difference_type
Definition array.h:143
+
const array value_type
Definition array.h:144
+
ArrayIterator & operator+(difference_type diff)
Definition array.h:151
+
Definition array.h:208
+
~Data()
Definition array.h:216
+
deleter_t d
Definition array.h:210
+
Data(const Data &d)=delete
+
Data & operator=(const Data &d)=delete
+
Data(allocator::Buffer buffer, deleter_t d=allocator::free)
Definition array.h:211
+
allocator::Buffer buffer
Definition array.h:209
+
Definition array.h:221
+
bool row_contiguous
Definition array.h:226
+
bool col_contiguous
Definition array.h:227
+
bool contiguous
Definition array.h:224
+
+ + + + diff --git a/docs/build/html/atomic_8h.html b/docs/build/html/atomic_8h.html new file mode 100644 index 000000000..f9df80082 --- /dev/null +++ b/docs/build/html/atomic_8h.html @@ -0,0 +1,521 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/atomic.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Classes | +Functions | +Variables
+
atomic.h File Reference
+
+
+
#include <metal_atomic>
+#include <metal_stdlib>
+#include "mlx/backend/metal/kernels/bf16.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Classes

struct  mlx_atomic< T, typename >
 
struct  mlx_atomic< T, enable_if_t< is_metal_atomic< T > > >
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Functions

template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC T mlx_atomic_load_explicit (device mlx_atomic< T > *object, uint offset)
 
template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC void mlx_atomic_store_explicit (device mlx_atomic< T > *object, T val, uint offset)
 
template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC void mlx_atomic_fetch_and_explicit (device mlx_atomic< T > *object, T val, uint offset)
 
template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC void mlx_atomic_fetch_or_explicit (device mlx_atomic< T > *object, T val, uint offset)
 
template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC void mlx_atomic_fetch_min_explicit (device mlx_atomic< T > *object, T val, uint offset)
 
template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC void mlx_atomic_fetch_max_explicit (device mlx_atomic< T > *object, T val, uint offset)
 
template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC void mlx_atomic_fetch_add_explicit (device mlx_atomic< T > *object, T val, uint offset)
 
template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC void mlx_atomic_fetch_mul_explicit (device mlx_atomic< T > *object, T val, uint offset)
 
template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit (device mlx_atomic< T > *object, thread T *expected, T val, uint offset)
 
template<>
METAL_FUNC void mlx_atomic_fetch_min_explicit< float > (device mlx_atomic< float > *object, float val, uint offset)
 
template<>
METAL_FUNC void mlx_atomic_fetch_max_explicit< float > (device mlx_atomic< float > *object, float val, uint offset)
 
template<typename T , enable_if_t<!is_metal_atomic< T >, bool > = true>
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit (device mlx_atomic< T > *object, thread uint *expected, uint val, uint offset)
 
+ + + + +

+Variables

template<typename T >
constexpr constant bool is_metal_atomic
 
+

Function Documentation

+ +

◆ mlx_atomic_compare_exchange_weak_explicit() [1/2]

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit (device mlx_atomic< T > * object,
thread T * expected,
T val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_compare_exchange_weak_explicit() [2/2]

+ +
+
+
+template<typename T , enable_if_t<!is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit (device mlx_atomic< T > * object,
thread uint * expected,
uint val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_fetch_add_explicit()

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_fetch_add_explicit (device mlx_atomic< T > * object,
T val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_fetch_and_explicit()

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_fetch_and_explicit (device mlx_atomic< T > * object,
T val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_fetch_max_explicit()

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_fetch_max_explicit (device mlx_atomic< T > * object,
T val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_fetch_max_explicit< float >()

+ +
+
+
+template<>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_fetch_max_explicit< float > (device mlx_atomic< float > * object,
float val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_fetch_min_explicit()

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_fetch_min_explicit (device mlx_atomic< T > * object,
T val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_fetch_min_explicit< float >()

+ +
+
+
+template<>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_fetch_min_explicit< float > (device mlx_atomic< float > * object,
float val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_fetch_mul_explicit()

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_fetch_mul_explicit (device mlx_atomic< T > * object,
T val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_fetch_or_explicit()

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_fetch_or_explicit (device mlx_atomic< T > * object,
T val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_load_explicit()

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + +
METAL_FUNC T mlx_atomic_load_explicit (device mlx_atomic< T > * object,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_store_explicit()

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_store_explicit (device mlx_atomic< T > * object,
T val,
uint offset )
+
+ +
+
+

Variable Documentation

+ +

◆ is_metal_atomic

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + +
constexpr constant bool is_metal_atomic
+
+constexpr
+
+Initial value:
= _disjunction<
+
is_same<T, int>,
+
is_same<T, uint>,
+
is_same<T, ulong>,
+
is_same<T, float>>::value
+
+
+
+
+ + + + diff --git a/docs/build/html/atomic_8h_source.html b/docs/build/html/atomic_8h_source.html new file mode 100644 index 000000000..b1e920a46 --- /dev/null +++ b/docs/build/html/atomic_8h_source.html @@ -0,0 +1,478 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/atomic.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
atomic.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <metal_atomic>
+
6#include <metal_stdlib>
+ +
8
+
9using namespace metal;
+
10
+
12// Atomic utils
+
14
+
15#pragma METAL internals : enable
+
16template <typename T>
+
17constexpr constant bool is_metal_atomic = _disjunction<
+
18 is_same<T, int>,
+
19 is_same<T, uint>,
+
20 is_same<T, ulong>,
+
21 is_same<T, float>>::value;
+
22
+
23#pragma METAL internals : disable
+
24
+
25template <typename T, typename = void>
+
+
26struct mlx_atomic {
+
27 atomic<uint> val;
+
28};
+
+
29
+
30template <typename T>
+
+
31struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
+
32 atomic<T> val;
+
33};
+
+
34
+
36// Native metal atomics
+
38
+
39template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
40METAL_FUNC T
+
+
41mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
+
42 return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
+
43}
+
+
44
+
45template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
46METAL_FUNC void
+
+
47mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
+
48 atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
+
49}
+
+
50
+
51template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
+ +
53 device mlx_atomic<T>* object,
+
54 T val,
+
55 uint offset) {
+
56 atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
+
57}
+
+
58
+
59template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
60METAL_FUNC void
+
+
61mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
+
62 atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
+
63}
+
+
64
+
65template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
+ +
67 device mlx_atomic<T>* object,
+
68 T val,
+
69 uint offset) {
+
70 atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
+
71}
+
+
72
+
73template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
+ +
75 device mlx_atomic<T>* object,
+
76 T val,
+
77 uint offset) {
+
78 atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
+
79}
+
+
80
+
81template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
+ +
83 device mlx_atomic<T>* object,
+
84 T val,
+
85 uint offset) {
+
86 atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
+
87}
+
+
88
+
89template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
+ +
91 device mlx_atomic<T>* object,
+
92 T val,
+
93 uint offset) {
+
94 T expected = mlx_atomic_load_explicit(object, offset);
+ +
96 object, &expected, val * expected, offset)) {
+
97 }
+
98}
+
+
99
+
100template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
+ +
102 device mlx_atomic<T>* object,
+
103 thread T* expected,
+
104 T val,
+
105 uint offset) {
+
106 return atomic_compare_exchange_weak_explicit(
+
107 &(object[offset].val),
+
108 expected,
+
109 val,
+
110 memory_order_relaxed,
+
111 memory_order_relaxed);
+
112}
+
+
113
+
114// Specialization for float since it does not atomic_fetch_min_explicit
+
115template <>
+
+ +
117 device mlx_atomic<float>* object,
+
118 float val,
+
119 uint offset) {
+
120 float expected = mlx_atomic_load_explicit(object, offset);
+
121 while (val < expected) {
+ +
123 object, &expected, val, offset)) {
+
124 return;
+
125 }
+
126 }
+
127}
+
+
128
+
129// Specialization for float since it does not atomic_fetch_max_explicit
+
130template <>
+
+ +
132 device mlx_atomic<float>* object,
+
133 float val,
+
134 uint offset) {
+
135 float expected = mlx_atomic_load_explicit(object, offset);
+
136 while (val > expected) {
+ +
138 object, &expected, val, offset)) {
+
139 return;
+
140 }
+
141 }
+
142}
+
+
143
+
145// Custom atomics
+
147
+
148namespace {
+
149
+
150template <typename T>
+
151constexpr constant uint packing_size = sizeof(uint) / sizeof(T);
+
152
+
153template <typename T>
+
154union uint_or_packed {
+
155 T val[packing_size<T>];
+
156 uint bits;
+
157};
+
158
+
159template <typename T, typename Op>
+
160struct mlx_atomic_update_helper {
+
161 uint operator()(uint_or_packed<T> init, T update, uint elem_offset) {
+
162 Op op;
+
163 init.val[elem_offset] = op(update, init.val[elem_offset]);
+
164 return init.bits;
+
165 }
+
166};
+
167
+
168template <typename T, typename Op>
+
169METAL_FUNC void mlx_atomic_update_and_store(
+
170 device mlx_atomic<T>* object,
+
171 T update,
+
172 uint offset) {
+
173 uint pack_offset = offset / packing_size<T>;
+
174 uint elem_offset = offset % packing_size<T>;
+
175
+
176 mlx_atomic_update_helper<T, Op> helper;
+
177 uint_or_packed<T> expected;
+
178 expected.bits =
+
179 atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
+
180
+
181 while (Op::condition(update, expected.val[elem_offset]) &&
+ +
183 object,
+
184 &(expected.bits),
+
185 helper(expected, update, elem_offset),
+
186 pack_offset)) {
+
187 }
+
188}
+
189
+
190template <typename T>
+
191struct __None {
+
192 static bool condition(T a, T b) {
+
193#pragma unused(a)
+
194#pragma unused(b)
+
195 return true;
+
196 }
+
197
+
198 T operator()(T a, T b) {
+
199#pragma unused(b)
+
200 return a;
+
201 }
+
202};
+
203
+
204template <typename T>
+
205struct __Add {
+
206 static bool condition(T a, T b) {
+
207#pragma unused(a)
+
208#pragma unused(b)
+
209 return true;
+
210 }
+
211
+
212 T operator()(T a, T b) {
+
213 return a + b;
+
214 }
+
215};
+
216
+
217template <typename T>
+
218struct __Mul {
+
219 static bool condition(T a, T b) {
+
220#pragma unused(a)
+
221 return b != 0;
+
222 }
+
223
+
224 T operator()(T a, T b) {
+
225 return a * b;
+
226 }
+
227};
+
228
+
229template <typename T>
+
230struct __Max {
+
231 static bool condition(T a, T b) {
+
232 return a > b;
+
233 }
+
234
+
235 T operator()(T a, T b) {
+
236 return max(a, b);
+
237 }
+
238};
+
239
+
240template <typename T>
+
241struct __Min {
+
242 static bool condition(T a, T b) {
+
243 return a < b;
+
244 }
+
245
+
246 T operator()(T a, T b) {
+
247 return min(a, b);
+
248 }
+
249};
+
250
+
251} // namespace
+
252
+
253template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
254METAL_FUNC T
+
255mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
+
256 uint pack_offset = offset / sizeof(T);
+
257 uint elem_offset = offset % sizeof(T);
+
258 uint_or_packed<T> packed_val;
+
259 packed_val.bits =
+
260 atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
+
261 return packed_val.val[elem_offset];
+
262}
+
263
+
264template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
265METAL_FUNC void
+
266mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
+
267 mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
+
268}
+
269
+
270template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
271METAL_FUNC void mlx_atomic_fetch_and_explicit(
+
272 device mlx_atomic<T>* object,
+
273 T val,
+
274 uint offset) {
+
275 uint pack_offset = offset / packing_size<T>;
+
276 uint elem_offset = offset % packing_size<T>;
+
277 uint_or_packed<T> identity;
+
278 identity.bits = __UINT32_MAX__;
+
279 identity.val[elem_offset] = val;
+
280
+
281 atomic_fetch_and_explicit(
+
282 &(object[pack_offset].val), identity.bits, memory_order_relaxed);
+
283}
+
284
+
285template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
286METAL_FUNC void
+
287mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
+
288 uint pack_offset = offset / packing_size<T>;
+
289 uint elem_offset = offset % packing_size<T>;
+
290 uint_or_packed<T> identity;
+
291 identity.bits = 0;
+
292 identity.val[elem_offset] = val;
+
293
+
294 atomic_fetch_or_explicit(
+
295 &(object[pack_offset].val), identity.bits, memory_order_relaxed);
+
296}
+
297
+
298template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
299METAL_FUNC void mlx_atomic_fetch_min_explicit(
+
300 device mlx_atomic<T>* object,
+
301 T val,
+
302 uint offset) {
+
303 mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
+
304}
+
305
+
306template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
307METAL_FUNC void mlx_atomic_fetch_max_explicit(
+
308 device mlx_atomic<T>* object,
+
309 T val,
+
310 uint offset) {
+
311 mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
+
312}
+
313
+
314template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
315METAL_FUNC void mlx_atomic_fetch_add_explicit(
+
316 device mlx_atomic<T>* object,
+
317 T val,
+
318 uint offset) {
+
319 mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
+
320}
+
321
+
322template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
323METAL_FUNC void mlx_atomic_fetch_mul_explicit(
+
324 device mlx_atomic<T>* object,
+
325 T val,
+
326 uint offset) {
+
327 mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
+
328}
+
329
+
330template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
+ +
332 device mlx_atomic<T>* object,
+
333 thread uint* expected,
+
334 uint val,
+
335 uint offset) {
+
336 return atomic_compare_exchange_weak_explicit(
+
337 &(object[offset].val),
+
338 expected,
+
339 val,
+
340 memory_order_relaxed,
+
341 memory_order_relaxed);
+
342}
+
+
METAL_FUNC void mlx_atomic_fetch_add_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:82
+
METAL_FUNC void mlx_atomic_fetch_max_explicit< float >(device mlx_atomic< float > *object, float val, uint offset)
Definition atomic.h:131
+
METAL_FUNC void mlx_atomic_fetch_and_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:52
+
METAL_FUNC T mlx_atomic_load_explicit(device mlx_atomic< T > *object, uint offset)
Definition atomic.h:41
+
METAL_FUNC void mlx_atomic_store_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:47
+
constexpr constant bool is_metal_atomic
Definition atomic.h:17
+
METAL_FUNC void mlx_atomic_fetch_or_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:61
+
METAL_FUNC void mlx_atomic_fetch_min_explicit< float >(device mlx_atomic< float > *object, float val, uint offset)
Definition atomic.h:116
+
METAL_FUNC void mlx_atomic_fetch_max_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:74
+
METAL_FUNC void mlx_atomic_fetch_min_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:66
+
METAL_FUNC void mlx_atomic_fetch_mul_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:90
+
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread T *expected, T val, uint offset)
Definition atomic.h:101
+ +
Op op
Definition binary.h:139
+
array identity(int n, Dtype dtype, StreamOrDevice s={})
Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.
+
Definition bf16.h:265
+
METAL_FUNC bfloat16_t min(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:234
+
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:234
+
array bits(const std::vector< int > &shape, int width, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate an array with type uint32 filled with random bits.
+ +
Definition atomic.h:26
+
atomic< uint > val
Definition atomic.h:27
+
+ + + + diff --git a/docs/build/html/backend_2accelerate_2utils_8h.html b/docs/build/html/backend_2accelerate_2utils_8h.html new file mode 100644 index 000000000..60d5dacd1 --- /dev/null +++ b/docs/build/html/backend_2accelerate_2utils_8h.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: mlx/backend/accelerate/utils.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Namespaces | +Functions
+
utils.h File Reference
+
+
+
#include <vecLib/BNNS/bnns.h>
+#include "mlx/dtype.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+ + + +

+Functions

BNNSDataType mlx::core::to_bnns_dtype (Dtype mlx_dtype)
 
+
+ + + + diff --git a/docs/build/html/backend_2accelerate_2utils_8h_source.html b/docs/build/html/backend_2accelerate_2utils_8h_source.html new file mode 100644 index 000000000..02e8f28ac --- /dev/null +++ b/docs/build/html/backend_2accelerate_2utils_8h_source.html @@ -0,0 +1,134 @@ + + + + + + + +MLX: mlx/backend/accelerate/utils.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
utils.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <vecLib/BNNS/bnns.h>
+
6#include "mlx/dtype.h"
+
7
+
8namespace mlx::core {
+
9
+
+
10BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
+
11 uint32_t size_bits = size_of(mlx_dtype) * 8;
+
12 switch (kindof(mlx_dtype)) {
+
13 case Dtype::Kind::b:
+
14 return BNNSDataTypeBoolean;
+
15 case Dtype::Kind::u:
+
16 return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
+
17 case Dtype::Kind::i:
+
18 return BNNSDataType(BNNSDataTypeIntBit | size_bits);
+
19 case Dtype::Kind::f:
+
20 return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
+
21 case Dtype::Kind::V:
+
22 return BNNSDataTypeBFloat16;
+
23 case Dtype::Kind::c:
+
24 throw std::invalid_argument("BNNS does not support complex types");
+
25 }
+
26}
+
+
27
+
28} // namespace mlx::core
+ +
Definition allocator.h:7
+
BNNSDataType to_bnns_dtype(Dtype mlx_dtype)
Definition utils.h:10
+
Dtype::Kind kindof(const Dtype &t)
+
uint8_t size_of(const Dtype &t)
Definition dtype.h:95
+
Definition dtype.h:15
+ + + + + + +
+ + + + diff --git a/docs/build/html/backend_2common_2ops_8h.html b/docs/build/html/backend_2common_2ops_8h.html new file mode 100644 index 000000000..c976bb466 --- /dev/null +++ b/docs/build/html/backend_2common_2ops_8h.html @@ -0,0 +1,234 @@ + + + + + + + +MLX: mlx/backend/common/ops.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Classes | +Namespaces | +Functions
+
ops.h File Reference
+
+
+
#include <stdint.h>
+#include <cmath>
+#include <complex>
+
+

Go to the source code of this file.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Classes

union  mlx::core::detail::IntOrFloat
 
struct  mlx::core::detail::Abs
 
struct  mlx::core::detail::ArcCos
 
struct  mlx::core::detail::ArcCosh
 
struct  mlx::core::detail::ArcSin
 
struct  mlx::core::detail::ArcSinh
 
struct  mlx::core::detail::ArcTan
 
struct  mlx::core::detail::ArcTan2
 
struct  mlx::core::detail::ArcTanh
 
struct  mlx::core::detail::Ceil
 
struct  mlx::core::detail::Conjugate
 
struct  mlx::core::detail::Cos
 
struct  mlx::core::detail::Cosh
 
struct  mlx::core::detail::Erf
 
struct  mlx::core::detail::ErfInv
 
struct  mlx::core::detail::Exp
 
struct  mlx::core::detail::Expm1
 
struct  mlx::core::detail::Floor
 
struct  mlx::core::detail::Log
 
struct  mlx::core::detail::Log2
 
struct  mlx::core::detail::Log10
 
struct  mlx::core::detail::Log1p
 
struct  mlx::core::detail::LogicalNot
 
struct  mlx::core::detail::Negative
 
struct  mlx::core::detail::Round
 
struct  mlx::core::detail::Sigmoid
 
struct  mlx::core::detail::Sign
 
struct  mlx::core::detail::Sin
 
struct  mlx::core::detail::Sinh
 
struct  mlx::core::detail::Square
 
struct  mlx::core::detail::Sqrt
 
struct  mlx::core::detail::Rsqrt
 
struct  mlx::core::detail::Tan
 
struct  mlx::core::detail::Tanh
 
struct  mlx::core::detail::Add
 
struct  mlx::core::detail::Divide
 
struct  mlx::core::detail::Remainder
 
struct  mlx::core::detail::Equal
 
struct  mlx::core::detail::NaNEqual
 
struct  mlx::core::detail::Greater
 
struct  mlx::core::detail::GreaterEqual
 
struct  mlx::core::detail::Less
 
struct  mlx::core::detail::LessEqual
 
struct  mlx::core::detail::Maximum
 
struct  mlx::core::detail::Minimum
 
struct  mlx::core::detail::LogAddExp
 
struct  mlx::core::detail::Multiply
 
struct  mlx::core::detail::NotEqual
 
struct  mlx::core::detail::Power
 
struct  mlx::core::detail::Subtract
 
struct  mlx::core::detail::LogicalAnd
 
struct  mlx::core::detail::LogicalOr
 
struct  mlx::core::detail::Select
 
struct  mlx::core::detail::BitwiseAnd
 
struct  mlx::core::detail::BitwiseOr
 
struct  mlx::core::detail::BitwiseXor
 
struct  mlx::core::detail::LeftShift
 
struct  mlx::core::detail::RightShift
 
+ + + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
namespace  mlx::core::detail
 
+ + + + + + + +

+Functions

float mlx::core::detail::fast_exp (float x)
 
float mlx::core::detail::fast_erf (float a)
 
float mlx::core::detail::fast_erfinv (float a)
 
+
+ + + + diff --git a/docs/build/html/backend_2common_2ops_8h_source.html b/docs/build/html/backend_2common_2ops_8h_source.html new file mode 100644 index 000000000..1b6cebe96 --- /dev/null +++ b/docs/build/html/backend_2common_2ops_8h_source.html @@ -0,0 +1,1218 @@ + + + + + + + +MLX: mlx/backend/common/ops.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
ops.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4#include <stdint.h>
+
5#include <cmath>
+
6#include <complex>
+
7
+
+ +
9
+
10namespace {
+
11constexpr float inf = std::numeric_limits<float>::infinity();
+
12} // namespace
+
13
+
+
14typedef union {
+
15 int i;
+
16 float f;
+ +
+
18
+
+
19inline float fast_exp(float x) {
+
20 if (x == -std::numeric_limits<float>::infinity()) {
+
21 return 0.0f;
+
22 } else if (x == std::numeric_limits<float>::infinity() || std::isnan(x)) {
+
23 return x;
+
24 }
+
25 x *= 1.442695; // multiply with log_2(e)
+
26 float ipart, fpart;
+
27 IntOrFloat epart;
+
28 x = std::max(-80.f, std::min(x, 80.f));
+
29 ipart = std::floor(x + 0.5);
+
30 fpart = x - ipart;
+
31
+
32 x = 1.535336188319500e-4f;
+
33 x = x * fpart + 1.339887440266574e-3f;
+
34 x = x * fpart + 9.618437357674640e-3f;
+
35 x = x * fpart + 5.550332471162809e-2f;
+
36 x = x * fpart + 2.402264791363012e-1f;
+
37 x = x * fpart + 6.931472028550421e-1f;
+
38 x = x * fpart + 1.000000000000000f;
+
39
+
40 // generate 2**ipart in the floating point representation using integer
+
41 // bitshifting
+
42 epart.i = (int(ipart) + 127) << 23;
+
43
+
44 return epart.f * x;
+
45}
+
+
46
+
+
47inline float fast_erf(float a) {
+
48 float r, s, t, u;
+
49 t = std::abs(a);
+
50 s = a * a;
+
51 if (t > 0.927734375f) {
+
52 // maximum error 0.99527 ulp
+
53 r = std::fma(
+
54 -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
+
55 u = std::fma(
+
56 -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
+
57 r = std::fma(r, s, u);
+
58 r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
+
59 r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
+
60 r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
+
61 r = std::fma(r, t, -t);
+
62 // TODO, replace with expm1 when implemented
+
63 r = 1.0f - std::exp(r);
+
64 r = std::copysign(r, a);
+
65 } else {
+
66 // maximum error 0.98929 ulp
+
67 r = -5.96761703e-4f; // -0x1.38e000p-11
+
68 r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
+
69 r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
+
70 r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
+
71 r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
+
72 r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
+
73 r = std::fma(r, a, a);
+
74 }
+
75 return r;
+
76}
+
+
77
+
+
78inline float fast_erfinv(float a) {
+
79 auto t = std::fma(a, 0.0f - a, 1.0f);
+
80 t = std::log(t);
+
81 float p;
+
82 if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
+
83 p = 3.03697567e-10f; // 0x1.4deb44p-32
+
84 p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
+
85 p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
+
86 p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
+
87 p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
+
88 p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
+
89 p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
+
90 p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
+
91 p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
+
92 } else { // maximum ulp error = 2.35002
+
93 p = 5.43877832e-9f; // 0x1.75c000p-28
+
94 p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
+
95 p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
+
96 p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
+
97 p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
+
98 p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
+
99 p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
+
100 p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
+
101 p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
+
102 p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
+
103 }
+
104 return a * p;
+
105}
+
+
106
+
+
107struct Abs {
+
108 template <typename T>
+
+
109 T operator()(T x) {
+
110 return std::abs(x);
+
111 };
+
+
+
112 uint8_t operator()(uint8_t x) {
+
113 return x;
+
114 };
+
+
+
115 uint16_t operator()(uint16_t x) {
+
116 return x;
+
117 };
+
+
+
118 uint32_t operator()(uint32_t x) {
+
119 return x;
+
120 };
+
+
+
121 uint64_t operator()(uint64_t x) {
+
122 return x;
+
123 };
+
+
+
124 bool operator()(bool x) {
+
125 return x;
+
126 };
+
+
127};
+
+
128
+
+
129struct ArcCos {
+
130 template <typename T>
+
+
131 T operator()(T x) {
+
132 return std::acos(x);
+
133 };
+
+
134};
+
+
135
+
+
136struct ArcCosh {
+
137 template <typename T>
+
+
138 T operator()(T x) {
+
139 return std::acosh(x);
+
140 };
+
+
141};
+
+
142
+
+
143struct ArcSin {
+
144 template <typename T>
+
+
145 T operator()(T x) {
+
146 return std::asin(x);
+
147 };
+
+
148};
+
+
149
+
+
150struct ArcSinh {
+
151 template <typename T>
+
+
152 T operator()(T x) {
+
153 return std::asinh(x);
+
154 };
+
+
155};
+
+
156
+
+
157struct ArcTan {
+
158 template <typename T>
+
+
159 T operator()(T x) {
+
160 return std::atan(x);
+
161 };
+
+
162};
+
+
163
+
+
164struct ArcTan2 {
+
165 template <typename T>
+
+
166 T operator()(T y, T x) {
+
167 return std::atan2(y, x);
+
168 };
+
+
169};
+
+
170
+
+
171struct ArcTanh {
+
172 template <typename T>
+
+
173 T operator()(T x) {
+
174 return std::atanh(x);
+
175 };
+
+
176};
+
+
177
+
+
178struct Ceil {
+
179 template <typename T>
+
+
180 T operator()(T x) {
+
181 return std::ceil(x);
+
182 };
+
+
+
183 int8_t operator()(int8_t x) {
+
184 return x;
+
185 };
+
+
+
186 int16_t operator()(int16_t x) {
+
187 return x;
+
188 };
+
+
+
189 int32_t operator()(int32_t x) {
+
190 return x;
+
191 };
+
+
+
192 int64_t operator()(int64_t x) {
+
193 return x;
+
194 };
+
+
+
195 uint8_t operator()(uint8_t x) {
+
196 return x;
+
197 };
+
+
+
198 uint16_t operator()(uint16_t x) {
+
199 return x;
+
200 };
+
+
+
201 uint32_t operator()(uint32_t x) {
+
202 return x;
+
203 };
+
+
+
204 uint64_t operator()(uint64_t x) {
+
205 return x;
+
206 };
+
+
+
207 bool operator()(bool x) {
+
208 return x;
+
209 };
+
+
210};
+
+
211
+
+
212struct Conjugate {
+
+ +
214 return std::conj(x);
+
215 }
+
+
216};
+
+
217
+
+
218struct Cos {
+
219 template <typename T>
+
+
220 T operator()(T x) {
+
221 return std::cos(x);
+
222 };
+
+
223};
+
+
224
+
+
225struct Cosh {
+
226 template <typename T>
+
+
227 T operator()(T x) {
+
228 return std::cosh(x);
+
229 };
+
+
230};
+
+
231
+
+
232struct Erf {
+
233 template <typename T>
+
+
234 T operator()(T x) {
+
235 return static_cast<T>(fast_erf(static_cast<float>(x)));
+
236 };
+
+
237};
+
+
238
+
+
239struct ErfInv {
+
240 template <typename T>
+
+
241 T operator()(T x) {
+
242 return static_cast<T>(fast_erfinv(static_cast<float>(x)));
+
243 };
+
+
244};
+
+
245
+
+
246struct Exp {
+
247 template <typename T>
+
+
248 T operator()(T x) {
+
249 return fast_exp(x);
+
250 };
+
+
251
+
+ +
253 return std::exp(x);
+
254 }
+
+
255};
+
+
256
+
+
257struct Expm1 {
+
258 template <typename T>
+
+
259 T operator()(T x) {
+
260 return expm1(x);
+
261 };
+
+
262};
+
+
263
+
+
264struct Floor {
+
265 template <typename T>
+
+
266 T operator()(T x) {
+
267 return std::floor(x);
+
268 };
+
+
+
269 int8_t operator()(int8_t x) {
+
270 return x;
+
271 };
+
+
+
272 int16_t operator()(int16_t x) {
+
273 return x;
+
274 };
+
+
+
275 int32_t operator()(int32_t x) {
+
276 return x;
+
277 };
+
+
+
278 int64_t operator()(int64_t x) {
+
279 return x;
+
280 };
+
+
+
281 uint8_t operator()(uint8_t x) {
+
282 return x;
+
283 };
+
+
+
284 uint16_t operator()(uint16_t x) {
+
285 return x;
+
286 };
+
+
+
287 uint32_t operator()(uint32_t x) {
+
288 return x;
+
289 };
+
+
+
290 uint64_t operator()(uint64_t x) {
+
291 return x;
+
292 };
+
+
+
293 bool operator()(bool x) {
+
294 return x;
+
295 };
+
+
296};
+
+
297
+
+
298struct Log {
+
299 template <typename T>
+
+
300 T operator()(T x) {
+
301 return std::log(x);
+
302 };
+
+
303};
+
+
304
+
+
305struct Log2 {
+
306 template <typename T>
+
+
307 T operator()(T x) {
+
308 return std::log2(x);
+
309 };
+
+
310};
+
+
311
+
+
312struct Log10 {
+
313 template <typename T>
+
+
314 T operator()(T x) {
+
315 return std::log10(x);
+
316 };
+
+
317};
+
+
318
+
+
319struct Log1p {
+
320 template <typename T>
+
+
321 T operator()(T x) {
+
322 return log1p(x);
+
323 };
+
+
324};
+
+
325
+
+ +
327 template <typename T>
+
+
328 T operator()(T x) {
+
329 return !x;
+
330 };
+
+
331};
+
+
332
+
+
333struct Negative {
+
334 template <typename T>
+
+
335 T operator()(T x) {
+
336 return -x;
+
337 };
+
+
338};
+
+
339
+
+
340struct Round {
+
341 template <typename T>
+
+
342 T operator()(T x) {
+
343 return std::rint(x);
+
344 }
+
+
345
+
+ +
347 return {std::rint(x.real()), std::rint(x.imag())};
+
348 }
+
+
349};
+
+
350
+
+
351struct Sigmoid {
+
352 template <typename T>
+
+
353 T operator()(T x) {
+
354 auto one = static_cast<decltype(x)>(1.0);
+
355 return one / (one + fast_exp(-x));
+
356 }
+
+
357};
+
+
358
+
+
359struct Sign {
+
360 template <typename T>
+
+
361 T operator()(T x) {
+
362 return (x > T(0)) - (x < T(0));
+
363 }
+
+
+
364 uint8_t operator()(uint8_t x) {
+
365 return x != 0;
+
366 }
+
+
+
367 uint16_t operator()(uint16_t x) {
+
368 return x != 0;
+
369 }
+
+
+
370 uint32_t operator()(uint32_t x) {
+
371 return x != 0;
+
372 }
+
+
+
373 uint64_t operator()(uint64_t x) {
+
374 return x != 0;
+
375 }
+
+
376};
+
+
377
+
+
378struct Sin {
+
379 template <typename T>
+
+
380 T operator()(T x) {
+
381 return std::sin(x);
+
382 };
+
+
383};
+
+
384
+
+
385struct Sinh {
+
386 template <typename T>
+
+
387 T operator()(T x) {
+
388 return std::sinh(x);
+
389 };
+
+
390};
+
+
391
+
+
392struct Square {
+
393 template <typename T>
+
+
394 T operator()(T x) {
+
395 return x * x;
+
396 };
+
+
397};
+
+
398
+
+
399struct Sqrt {
+
400 template <typename T>
+
+
401 T operator()(T x) {
+
402 return std::sqrt(x);
+
403 };
+
+
404};
+
+
405
+
+
406struct Rsqrt {
+
407 template <typename T>
+
+
408 T operator()(T x) {
+
409 return static_cast<decltype(x)>(1.0) / std::sqrt(x);
+
410 };
+
+
411};
+
+
412
+
+
413struct Tan {
+
414 template <typename T>
+
+
415 T operator()(T x) {
+
416 return std::tan(x);
+
417 };
+
+
418};
+
+
419
+
+
420struct Tanh {
+
421 template <typename T>
+
+
422 T operator()(T x) {
+
423 return std::tanh(x);
+
424 };
+
+
425};
+
+
426
+
+
427struct Add {
+
428 template <typename T>
+
+
429 T operator()(T x, T y) {
+
430 return x + y;
+
431 }
+
+
432};
+
+
433
+
+
434struct Divide {
+
435 template <typename T>
+
+
436 T operator()(T x, T y) {
+
437 return x / y;
+
438 }
+
+
439};
+
+
440
+
+
441struct Remainder {
+
442 template <typename T>
+
+
443 std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
+
444 T numerator,
+
445 T denominator) {
+
446 return numerator % denominator;
+
447 }
+
+
448
+
449 template <typename T>
+
+
450 std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
+
451 T numerator,
+
452 T denominator) {
+
453 auto r = numerator % denominator;
+
454 if (r != 0 && (r < 0 != denominator < 0))
+
455 r += denominator;
+
456 return r;
+
457 }
+
+
458
+
459 template <typename T>
+
+
460 std::enable_if_t<!std::is_integral_v<T>, T> operator()(
+
461 T numerator,
+
462 T denominator) {
+
463 auto r = std::fmod(numerator, denominator);
+
464 if (r != 0 && (r < 0 != denominator < 0)) {
+
465 r += denominator;
+
466 }
+
467 return r;
+
468 }
+
+
469
+
+ +
471 return numerator % denominator;
+
472 }
+
+
473};
+
+
474
+
+
475struct Equal {
+
476 template <typename T>
+
+
477 bool operator()(T x, T y) {
+
478 return x == y;
+
479 }
+
+
480};
+
+
481
+
+
482struct NaNEqual {
+
483 template <typename T>
+
+
484 bool operator()(T x, T y) {
+
485 return x == y || (std::isnan(x) && std::isnan(y));
+
486 }
+
+
487};
+
+
488
+
+
489struct Greater {
+
490 template <typename T>
+
+
491 bool operator()(T x, T y) {
+
492 return x > y;
+
493 }
+
+
494};
+
+
495
+
+ +
497 template <typename T>
+
+
498 bool operator()(T x, T y) {
+
499 return x >= y;
+
500 }
+
+
501};
+
+
502
+
+
503struct Less {
+
504 template <typename T>
+
+
505 bool operator()(T x, T y) {
+
506 return x < y;
+
507 }
+
+
508};
+
+
509
+
+
510struct LessEqual {
+
511 template <typename T>
+
+
512 bool operator()(T x, T y) {
+
513 return x <= y;
+
514 }
+
+
515};
+
+
516
+
+
517struct Maximum {
+
518 template <typename T>
+
+
519 std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
+
520 return (x > y) ? x : y;
+
521 }
+
+
522
+
523 template <typename T>
+
+
524 std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
+
525 if (std::isnan(x)) {
+
526 return x;
+
527 }
+
528 return (x > y) ? x : y;
+
529 }
+
+
530};
+
+
531
+
+
532struct Minimum {
+
533 template <typename T>
+
+
534 std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
+
535 return x < y ? x : y;
+
536 }
+
+
537
+
538 template <typename T>
+
+
539 std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
+
540 if (std::isnan(x)) {
+
541 return x;
+
542 }
+
543 return x < y ? x : y;
+
544 }
+
+
545};
+
+
546
+
+
547struct LogAddExp {
+
548 template <typename T>
+
+
549 T operator()(T x, T y) {
+
550 constexpr float inf = std::numeric_limits<float>::infinity();
+
551 auto maxval = Maximum()(x, y);
+
552 auto minval = Minimum()(x, y);
+
553 return (minval == -inf || maxval == inf)
+
554 ? maxval
+
555 : static_cast<decltype(x)>(
+
556 maxval + std::log1p(fast_exp(minval - maxval)));
+
557 };
+
+
558};
+
+
559
+
+
560struct Multiply {
+
561 template <typename T>
+
+
562 T operator()(T x, T y) {
+
563 return x * y;
+
564 }
+
+
565};
+
+
566
+
+
567struct NotEqual {
+
568 template <typename T>
+
+
569 bool operator()(T x, T y) {
+
570 return x != y;
+
571 }
+
+
572};
+
+
573
+
+
574struct Power {
+
575 template <typename T>
+
+
576 std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
+
577 return std::pow(base, exp);
+
578 }
+
+
579
+
580 template <typename T>
+
+
581 std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
+
582 T res = 1;
+
583 while (exp) {
+
584 if (exp & 1) {
+
585 res *= base;
+
586 }
+
587 exp >>= 1;
+
588 base *= base;
+
589 }
+
590 return res;
+
591 }
+
+
592};
+
+
593
+
+
594struct Subtract {
+
595 template <typename T>
+
+
596 T operator()(T x, T y) {
+
597 return x - y;
+
598 }
+
+
599};
+
+
600
+
+ +
602 template <typename T>
+
+
603 T operator()(T x, T y) {
+
604 return x && y;
+
605 };
+
+
606};
+
+
607
+
+
608struct LogicalOr {
+
609 template <typename T>
+
+
610 T operator()(T x, T y) {
+
611 return x || y;
+
612 };
+
+
613};
+
+
614
+
+
615struct Select {
+
616 template <typename T>
+
+
617 T operator()(bool condition, T x, T y) {
+
618 return condition ? x : y;
+
619 }
+
+
620};
+
+
621
+
+ +
623 template <typename T>
+
+
624 T operator()(T x, T y) {
+
625 return x & y;
+
626 };
+
+
627};
+
+
628
+
+
629struct BitwiseOr {
+
630 template <typename T>
+
+
631 T operator()(T x, T y) {
+
632 return x | y;
+
633 };
+
+
634};
+
+
635
+
+ +
637 template <typename T>
+
+
638 T operator()(T x, T y) {
+
639 return x ^ y;
+
640 };
+
+
641};
+
+
642
+
+
643struct LeftShift {
+
644 template <typename T>
+
+
645 T operator()(T x, T y) {
+
646 return x << y;
+
647 };
+
+
648};
+
+
649
+
+ +
651 template <typename T>
+
+
652 T operator()(T x, T y) {
+
653 return x >> y;
+
654 };
+
+
655};
+
+
656
+
657} // namespace mlx::core::detail
+
+
array log1p(const array &a, StreamOrDevice s={})
Natural logarithm of one plus elements in the array: log(1 + a).
+
array expm1(const array &a, StreamOrDevice s={})
Computes the expm1 function of the elements of an array.
+
array exp(const array &a, StreamOrDevice s={})
Exponential of the elements of an array.
+
Definition ops.h:8
+
float fast_exp(float x)
Definition ops.h:19
+
float fast_erf(float a)
Definition ops.h:47
+
float fast_erfinv(float a)
Definition ops.h:78
+
Definition complex.h:34
+
Definition ops.h:107
+
T operator()(T x)
Definition ops.h:109
+
uint8_t operator()(uint8_t x)
Definition ops.h:112
+
uint64_t operator()(uint64_t x)
Definition ops.h:121
+
uint16_t operator()(uint16_t x)
Definition ops.h:115
+
bool operator()(bool x)
Definition ops.h:124
+
uint32_t operator()(uint32_t x)
Definition ops.h:118
+
Definition ops.h:427
+
T operator()(T x, T y)
Definition ops.h:429
+
Definition ops.h:129
+
T operator()(T x)
Definition ops.h:131
+
Definition ops.h:136
+
T operator()(T x)
Definition ops.h:138
+
Definition ops.h:143
+
T operator()(T x)
Definition ops.h:145
+
Definition ops.h:150
+
T operator()(T x)
Definition ops.h:152
+
Definition ops.h:164
+
T operator()(T y, T x)
Definition ops.h:166
+
Definition ops.h:157
+
T operator()(T x)
Definition ops.h:159
+
Definition ops.h:171
+
T operator()(T x)
Definition ops.h:173
+
Definition ops.h:622
+
T operator()(T x, T y)
Definition ops.h:624
+
Definition ops.h:629
+
T operator()(T x, T y)
Definition ops.h:631
+
Definition ops.h:636
+
T operator()(T x, T y)
Definition ops.h:638
+
Definition ops.h:178
+
uint8_t operator()(uint8_t x)
Definition ops.h:195
+
T operator()(T x)
Definition ops.h:180
+
uint32_t operator()(uint32_t x)
Definition ops.h:201
+
int8_t operator()(int8_t x)
Definition ops.h:183
+
int16_t operator()(int16_t x)
Definition ops.h:186
+
bool operator()(bool x)
Definition ops.h:207
+
uint16_t operator()(uint16_t x)
Definition ops.h:198
+
uint64_t operator()(uint64_t x)
Definition ops.h:204
+
int32_t operator()(int32_t x)
Definition ops.h:189
+
int64_t operator()(int64_t x)
Definition ops.h:192
+
Definition ops.h:212
+
complex64_t operator()(complex64_t x)
Definition ops.h:213
+
Definition ops.h:218
+
T operator()(T x)
Definition ops.h:220
+
Definition ops.h:225
+
T operator()(T x)
Definition ops.h:227
+
Definition ops.h:434
+
T operator()(T x, T y)
Definition ops.h:436
+
Definition ops.h:475
+
bool operator()(T x, T y)
Definition ops.h:477
+
Definition ops.h:232
+
T operator()(T x)
Definition ops.h:234
+
Definition ops.h:239
+
T operator()(T x)
Definition ops.h:241
+
Definition ops.h:246
+
T operator()(T x)
Definition ops.h:248
+
complex64_t operator()(complex64_t x)
Definition ops.h:252
+
Definition ops.h:257
+
T operator()(T x)
Definition ops.h:259
+
Definition ops.h:264
+
T operator()(T x)
Definition ops.h:266
+
uint32_t operator()(uint32_t x)
Definition ops.h:287
+
uint16_t operator()(uint16_t x)
Definition ops.h:284
+
uint8_t operator()(uint8_t x)
Definition ops.h:281
+
int32_t operator()(int32_t x)
Definition ops.h:275
+
int64_t operator()(int64_t x)
Definition ops.h:278
+
bool operator()(bool x)
Definition ops.h:293
+
int8_t operator()(int8_t x)
Definition ops.h:269
+
uint64_t operator()(uint64_t x)
Definition ops.h:290
+
int16_t operator()(int16_t x)
Definition ops.h:272
+ +
bool operator()(T x, T y)
Definition ops.h:498
+
Definition ops.h:489
+
bool operator()(T x, T y)
Definition ops.h:491
+
Definition ops.h:643
+
T operator()(T x, T y)
Definition ops.h:645
+
Definition ops.h:510
+
bool operator()(T x, T y)
Definition ops.h:512
+
Definition ops.h:503
+
bool operator()(T x, T y)
Definition ops.h:505
+
Definition ops.h:312
+
T operator()(T x)
Definition ops.h:314
+
Definition ops.h:319
+
T operator()(T x)
Definition ops.h:321
+
Definition ops.h:305
+
T operator()(T x)
Definition ops.h:307
+
Definition ops.h:547
+
T operator()(T x, T y)
Definition ops.h:549
+
Definition ops.h:298
+
T operator()(T x)
Definition ops.h:300
+
Definition ops.h:601
+
T operator()(T x, T y)
Definition ops.h:603
+
Definition ops.h:326
+
T operator()(T x)
Definition ops.h:328
+
Definition ops.h:608
+
T operator()(T x, T y)
Definition ops.h:610
+
Definition ops.h:517
+
std::enable_if_t< std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:519
+
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:524
+
Definition ops.h:532
+
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:539
+
std::enable_if_t< std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:534
+
Definition ops.h:560
+
T operator()(T x, T y)
Definition ops.h:562
+
Definition ops.h:482
+
bool operator()(T x, T y)
Definition ops.h:484
+
Definition ops.h:333
+
T operator()(T x)
Definition ops.h:335
+
Definition ops.h:567
+
bool operator()(T x, T y)
Definition ops.h:569
+
Definition ops.h:574
+
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T base, T exp)
Definition ops.h:576
+
std::enable_if_t< std::is_integral_v< T >, T > operator()(T base, T exp)
Definition ops.h:581
+
Definition ops.h:441
+
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T numerator, T denominator)
Definition ops.h:460
+
std::enable_if_t< std::is_integral_v< T > &!std::is_signed_v< T >, T > operator()(T numerator, T denominator)
Definition ops.h:443
+
std::enable_if_t< std::is_integral_v< T > &std::is_signed_v< T >, T > operator()(T numerator, T denominator)
Definition ops.h:450
+
complex64_t operator()(complex64_t numerator, complex64_t denominator)
Definition ops.h:470
+
Definition ops.h:650
+
T operator()(T x, T y)
Definition ops.h:652
+
Definition ops.h:340
+
T operator()(T x)
Definition ops.h:342
+
complex64_t operator()(complex64_t x)
Definition ops.h:346
+
Definition ops.h:406
+
T operator()(T x)
Definition ops.h:408
+
Definition ops.h:615
+
T operator()(bool condition, T x, T y)
Definition ops.h:617
+
Definition ops.h:351
+
T operator()(T x)
Definition ops.h:353
+
Definition ops.h:359
+
uint64_t operator()(uint64_t x)
Definition ops.h:373
+
T operator()(T x)
Definition ops.h:361
+
uint8_t operator()(uint8_t x)
Definition ops.h:364
+
uint16_t operator()(uint16_t x)
Definition ops.h:367
+
uint32_t operator()(uint32_t x)
Definition ops.h:370
+
Definition ops.h:378
+
T operator()(T x)
Definition ops.h:380
+
Definition ops.h:385
+
T operator()(T x)
Definition ops.h:387
+
Definition ops.h:399
+
T operator()(T x)
Definition ops.h:401
+
Definition ops.h:392
+
T operator()(T x)
Definition ops.h:394
+
Definition ops.h:594
+
T operator()(T x, T y)
Definition ops.h:596
+
Definition ops.h:413
+
T operator()(T x)
Definition ops.h:415
+
Definition ops.h:420
+
T operator()(T x)
Definition ops.h:422
+
uint32_t u
Definition bf16.h:17
+ +
float f
Definition ops.h:16
+
int i
Definition ops.h:15
+
+ + + + diff --git a/docs/build/html/backend_2common_2utils_8h.html b/docs/build/html/backend_2common_2utils_8h.html new file mode 100644 index 000000000..8789ebdcf --- /dev/null +++ b/docs/build/html/backend_2common_2utils_8h.html @@ -0,0 +1,121 @@ + + + + + + + +MLX: mlx/backend/common/utils.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Namespaces | +Functions
+
utils.h File Reference
+
+
+
#include <vector>
+#include "mlx/array.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+ + + + + + + + + + + + + + + + + +

+Functions

template<typename stride_t >
stride_t mlx::core::elem_to_loc (int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
 
size_t mlx::core::elem_to_loc (int elem, const array &a)
 
template<typename stride_t >
std::tuple< std::vector< int >, std::vector< std::vector< stride_t > > > mlx::core::collapse_contiguous_dims (const std::vector< int > &shape, const std::vector< std::vector< stride_t > > strides)
 
std::tuple< std::vector< int >, std::vector< std::vector< size_t > > > mlx::core::collapse_contiguous_dims (const std::vector< array > &xs)
 
template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
auto mlx::core::collapse_contiguous_dims (Arrays &&... xs)
 
template<typename stride_t >
auto mlx::core::check_contiguity (const std::vector< int > &shape, const std::vector< stride_t > &strides)
 
+
+ + + + diff --git a/docs/build/html/backend_2common_2utils_8h_source.html b/docs/build/html/backend_2common_2utils_8h_source.html new file mode 100644 index 000000000..b2b4439da --- /dev/null +++ b/docs/build/html/backend_2common_2utils_8h_source.html @@ -0,0 +1,236 @@ + + + + + + + +MLX: mlx/backend/common/utils.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
utils.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <vector>
+
6
+
7#include "mlx/array.h"
+
8
+
9namespace mlx::core {
+
10
+
11template <typename stride_t>
+
+
12inline stride_t elem_to_loc(
+
13 int elem,
+
14 const std::vector<int>& shape,
+
15 const std::vector<stride_t>& strides) {
+
16 stride_t loc = 0;
+
17 for (int i = shape.size() - 1; i >= 0; --i) {
+
18 auto q_and_r = ldiv(elem, shape[i]);
+
19 loc += q_and_r.rem * strides[i];
+
20 elem = q_and_r.quot;
+
21 }
+
22 return loc;
+
23}
+
+
24
+
+
25inline size_t elem_to_loc(int elem, const array& a) {
+
26 if (a.flags().row_contiguous) {
+
27 return elem;
+
28 }
+
29 return elem_to_loc(elem, a.shape(), a.strides());
+
30}
+
+
31
+
32// Collapse dims that are contiguous to possibly route to a better kernel
+
33// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
+
34// should return {{2, 4}, {{1, 2}}}.
+
35//
+
36// When multiple arrays are passed they should all have the same shape. The
+
37// collapsed axes are also the same so one shape is returned.
+
38template <typename stride_t>
+
39inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
+
+ +
41 const std::vector<int>& shape,
+
42 const std::vector<std::vector<stride_t>> strides) {
+
43 // Make a vector that has axes separated with -1. Collapse all axes between
+
44 // -1.
+
45 std::vector<int> to_collapse;
+
46 if (shape.size() > 0) {
+
47 to_collapse.push_back(0);
+
48 for (int i = 1; i < shape.size(); i++) {
+
49 bool contiguous = true;
+
50 for (const std::vector<stride_t>& st : strides) {
+
51 if (st[i] * shape[i] != st[i - 1]) {
+
52 contiguous = false;
+
53 }
+
54 if (!contiguous) {
+
55 break;
+
56 }
+
57 }
+
58 if (!contiguous) {
+
59 to_collapse.push_back(-1);
+
60 }
+
61 to_collapse.push_back(i);
+
62 }
+
63 to_collapse.push_back(-1);
+
64 }
+
65
+
66 std::vector<int> out_shape;
+
67 std::vector<std::vector<stride_t>> out_strides(strides.size());
+
68 for (int i = 0; i < to_collapse.size(); i++) {
+
69 int current_shape = shape[to_collapse[i]];
+
70 while (to_collapse[++i] != -1) {
+
71 current_shape *= shape[to_collapse[i]];
+
72 }
+
73 out_shape.push_back(current_shape);
+
74 for (int j = 0; j < strides.size(); j++) {
+
75 const std::vector<stride_t>& st = strides[j];
+
76 out_strides[j].push_back(st[to_collapse[i - 1]]);
+
77 }
+
78 }
+
79
+
80 return std::make_tuple(out_shape, out_strides);
+
81}
+
+
82
+
83inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
+
+
84collapse_contiguous_dims(const std::vector<array>& xs) {
+
85 std::vector<std::vector<size_t>> strides;
+
86 for (auto& x : xs) {
+
87 strides.emplace_back(x.strides());
+
88 }
+
89 return collapse_contiguous_dims(xs[0].shape(), strides);
+
90}
+
+
91
+
92template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
+
+
93inline auto collapse_contiguous_dims(Arrays&&... xs) {
+ +
95 std::vector<array>{std::forward<Arrays>(xs)...});
+
96}
+
+
97
+
98template <typename stride_t>
+
+
99inline auto check_contiguity(
+
100 const std::vector<int>& shape,
+
101 const std::vector<stride_t>& strides) {
+
102 size_t data_size = 1;
+
103 size_t f_stride = 1;
+
104 size_t b_stride = 1;
+
105 bool is_row_contiguous = true;
+
106 bool is_col_contiguous = true;
+
107
+
108 for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
+
109 is_row_contiguous &= strides[i] == f_stride || shape[i] == 1;
+
110 is_col_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
+
111 f_stride *= shape[i];
+
112 b_stride *= shape[ri];
+
113 if (strides[i] > 0) {
+
114 data_size *= shape[i];
+
115 }
+
116 }
+
117
+
118 return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
+
119}
+
+
120
+
121} // namespace mlx::core
+ +
Definition array.h:20
+
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:290
+
const std::vector< size_t > & strides() const
The strides of the array.
Definition array.h:113
+
const std::vector< int > & shape() const
The shape of the array as a vector of integers.
Definition array.h:99
+
Definition allocator.h:7
+
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
+
auto check_contiguity(const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:99
+
std::tuple< std::vector< int >, std::vector< std::vector< stride_t > > > collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< stride_t > > strides)
Definition utils.h:40
+
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:565
+
bool row_contiguous
Definition array.h:226
+
+ + + + diff --git a/docs/build/html/backend_2metal_2allocator_8h.html b/docs/build/html/backend_2metal_2allocator_8h.html new file mode 100644 index 000000000..a8d38c66a --- /dev/null +++ b/docs/build/html/backend_2metal_2allocator_8h.html @@ -0,0 +1,161 @@ + + + + + + + +MLX: mlx/backend/metal/allocator.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Classes | +Namespaces | +Functions
+
allocator.h File Reference
+
+
+
#include <map>
+#include <mutex>
+#include <vector>
+#include "mlx/allocator.h"
+#include "mlx/backend/metal/device.h"
+
+

Go to the source code of this file.

+ + + + +

+Classes

class  mlx::core::metal::MetalAllocator
 
+ + + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
namespace  mlx::core::metal
 
+ + + +

+Functions

MetalAllocatormlx::core::metal::allocator ()
 
+

Variable Documentation

+ +

◆ buf

+ +
+
+ + + + +
MTL::Buffer* buf
+
+ +
+
+ +

◆ next

+ +
+
+ + + + +
BufferHolder* next
+
+ +
+
+ +

◆ prev

+ +
+
+ + + + +
BufferHolder* prev
+
+ +
+
+
+ + + + diff --git a/docs/build/html/backend_2metal_2allocator_8h_source.html b/docs/build/html/backend_2metal_2allocator_8h_source.html new file mode 100644 index 000000000..6588aac61 --- /dev/null +++ b/docs/build/html/backend_2metal_2allocator_8h_source.html @@ -0,0 +1,221 @@ + + + + + + + +MLX: mlx/backend/metal/allocator.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
allocator.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <map>
+
6#include <mutex>
+
7#include <vector>
+
8
+
9#include "mlx/allocator.h"
+ +
11
+
+ +
13
+ +
15
+
16namespace {
+
17
+
18class BufferCache {
+
19 public:
+
20 BufferCache(MTL::Device* device);
+
21 ~BufferCache();
+
22
+
23 MTL::Buffer* reuse_from_cache(size_t size);
+
24 void recycle_to_cache(MTL::Buffer* buf);
+
25 void release_cached_buffers(size_t min_bytes_to_free);
+
26 size_t cache_size() {
+
27 return pool_size_;
+
28 }
+
29 void clear();
+
30
+
31 private:
+
32 struct BufferHolder {
+
33 public:
+
34 BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
+
35
+
36 BufferHolder* prev;
+
37 BufferHolder* next;
+
38 MTL::Buffer* buf;
+
39 };
+
40
+
41 void add_at_head(BufferHolder* to_add);
+
42 void remove_from_list(BufferHolder* to_remove);
+
43
+
44 MTL::Device* device_;
+
45
+
46 std::multimap<size_t, BufferHolder*> buffer_pool_;
+
47 BufferHolder* head_;
+
48 BufferHolder* tail_;
+
49 size_t pool_size_;
+
50};
+
51
+
52} // namespace
+
53
+
+ +
56 public:
+
57 virtual Buffer malloc(size_t size, bool allow_swap = false) override;
+
58 virtual void free(Buffer buffer) override;
+
+ +
60 return active_memory_;
+
61 };
+
+
+
62 size_t get_peak_memory() {
+
63 return peak_memory_;
+
64 };
+
+
+ +
66 std::unique_lock lk(mutex_);
+
67 peak_memory_ = 0;
+
68 };
+
+
+ +
70 return buffer_cache_.cache_size();
+
71 };
+
+
72 size_t set_cache_limit(size_t limit);
+
73 size_t set_memory_limit(size_t limit, bool relaxed);
+ +
75
+
76 private:
+
77 MTL::Device* device_;
+ + +
80
+
81 // Caching allocator
+
82 BufferCache buffer_cache_;
+
83
+
84 // Allocation stats
+
85 size_t block_limit_;
+
86 size_t gc_limit_;
+
87 size_t active_memory_{0};
+
88 size_t peak_memory_{0};
+
89 size_t max_pool_size_;
+
90 bool relaxed_{true};
+
91
+
92 std::mutex mutex_;
+
93};
+
+
94
+ +
96
+
97} // namespace mlx::core::metal
+
+ +
MTL::Buffer * buf
Definition allocator.h:38
+
BufferHolder * prev
Definition allocator.h:36
+
BufferHolder * next
Definition allocator.h:37
+ +
Definition allocator.h:39
+
Definition allocator.h:12
+
Definition allocator.h:54
+
virtual void free(Buffer buffer) override
+
size_t set_memory_limit(size_t limit, bool relaxed)
+
void reset_peak_memory()
Definition allocator.h:65
+ +
virtual Buffer malloc(size_t size, bool allow_swap=false) override
Allocator for Metal GPUs.
+
size_t get_active_memory()
Definition allocator.h:59
+
size_t get_peak_memory()
Definition allocator.h:62
+
size_t get_cache_memory()
Definition allocator.h:69
+
size_t set_cache_limit(size_t limit)
+
friend MetalAllocator & allocator()
+
Definition allocator.h:12
+
MetalAllocator & allocator()
+
Device & device(mlx::core::Device)
+
+ + + + diff --git a/docs/build/html/backend_2metal_2device_8h.html b/docs/build/html/backend_2metal_2device_8h.html new file mode 100644 index 000000000..573f61f73 --- /dev/null +++ b/docs/build/html/backend_2metal_2device_8h.html @@ -0,0 +1,135 @@ + + + + + + + +MLX: mlx/backend/metal/device.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Classes | +Namespaces | +Typedefs | +Functions
+
device.h File Reference
+
+
+
#include <Metal/Metal.hpp>
+#include <functional>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <dlfcn.h>
+#include <filesystem>
+#include "mlx/array.h"
+#include "mlx/device.h"
+
+

Go to the source code of this file.

+ + + + + + + + +

+Classes

struct  mlx::core::metal::CommandEncoder
 
struct  mlx::core::metal::CommandEncoder::ConcurrentContext
 
class  mlx::core::metal::Device
 
+ + + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
namespace  mlx::core::metal
 
+ + + +

+Typedefs

using mlx::core::metal::MTLFCList
 
+ + + + + +

+Functions

std::string mlx::core::metal::get_colocated_mtllib_path (const std::string &lib_name)
 
Devicemlx::core::metal::device (mlx::core::Device)
 
+
+ + + + diff --git a/docs/build/html/backend_2metal_2device_8h_source.html b/docs/build/html/backend_2metal_2device_8h_source.html new file mode 100644 index 000000000..cac1cfce8 --- /dev/null +++ b/docs/build/html/backend_2metal_2device_8h_source.html @@ -0,0 +1,389 @@ + + + + + + + +MLX: mlx/backend/metal/device.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
device.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <Metal/Metal.hpp>
+
6#include <functional>
+
7#include <mutex>
+
8#include <string>
+
9#include <unordered_map>
+
10#include <unordered_set>
+
11
+
12#include <dlfcn.h>
+
13#include <filesystem>
+
14
+
15#include "mlx/array.h"
+
16#include "mlx/device.h"
+
17
+
18namespace fs = std::filesystem;
+
19
+
20namespace mlx::core::metal {
+
21
+
+
22inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
+
23 Dl_info info;
+
24 std::string mtllib_path;
+
25 std::string lib_ext = lib_name + ".metallib";
+
26
+
27 int success = dladdr((void*)get_colocated_mtllib_path, &info);
+
28 if (success) {
+
29 auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
+
30 mtllib_path = mtllib.c_str();
+
31 }
+
32
+
33 return mtllib_path;
+
34}
+
+
35
+
36using MTLFCList =
+
37 std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
+
38
+
+ +
+
40 CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
+
41 enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
+
42 enc->retain();
+
43 };
+
+ + +
46
+
+ +
+ +
49 enc.concurrent = true;
+
50 }
+
+
+ +
52 enc.concurrent = false;
+
53 enc.outputs.insert(
+
54 enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
+
55 enc.concurrent_outputs.clear();
+
56 }
+
+
57
+
58 private:
+
59 CommandEncoder& enc;
+
60 };
+
+
61
+
+
62 MTL::ComputeCommandEncoder* operator->() {
+
63 return enc;
+
64 }
+
+
65
+
+
66 void set_input_array(const array& a, int idx, int offset = 0) {
+
67 auto r_buf =
+
68 static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
+
69 if (auto it = outputs.find(r_buf); it != outputs.end()) {
+
70 // Insert a barrier
+
71 enc->memoryBarrier(&r_buf, 1);
+
72
+
73 // Remove the output
+
74 outputs.erase(it);
+
75 }
+
76 auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
+
77 auto base_offset = a.data<char>() -
+
78 static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
+
79 base_offset += offset;
+
80 enc->setBuffer(a_buf, base_offset, idx);
+
81 }
+
+
82
+
+
83 void set_output_array(array& a, int idx, int offset = 0) {
+
84 // Add barriers before adding the output to the output set
+
85 set_input_array(a, idx, offset);
+
86 auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
+
87 if (concurrent) {
+
88 concurrent_outputs.insert(buf);
+
89 } else {
+
90 outputs.insert(buf);
+
91 }
+
92 }
+
+
93
+
94 void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
+
95 void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
+
96
+
+ +
98 return ConcurrentContext(*this);
+
99 }
+
+
100
+
+ +
102 enc->endEncoding();
+
103 enc->release();
+
104 }
+
+
105
+
106 private:
+
107 void maybe_split();
+
108
+
109 int num_dispatches{0};
+
110 MTL::CommandBuffer* cbuf;
+
111 MTL::ComputeCommandEncoder* enc;
+
112 bool concurrent{false};
+
113 std::unordered_set<MTL::Resource*> outputs;
+
114 std::unordered_set<MTL::Resource*> concurrent_outputs;
+
115};
+
+
116
+
+
117class Device {
+
118 public:
+ +
120 Device(const Device&) = delete;
+
121 Device& operator=(const Device&) = delete;
+ +
123
+
+
124 MTL::Device* mtl_device() {
+
125 return device_;
+
126 };
+
+
127
+
128 void new_queue(int index);
+
129 MTL::CommandBuffer* get_command_buffer(int index);
+ + +
132 void commit_command_buffer(int index);
+ +
134 void end_encoding(int index);
+
135
+ +
137 const std::string& lib_name,
+
138 const std::string& lib_path);
+ +
140 const std::string& lib_name,
+
141 const std::function<std::string(const std::string&)>& lib_path_func =
+ +
143
+
144 MTL::Library* get_library(const std::string& name);
+
145
+
146 MTL::Library* get_library(
+
147 const std::string& name,
+
148 const std::string& source_string,
+
149 bool cache = true);
+
150
+
151 MTL::Library* get_library(
+
152 const std::string& name,
+
153 const MTL::StitchedLibraryDescriptor* desc,
+
154 bool cache = true);
+
155
+
156 MTL::Function* get_function(
+
157 const std::string& base_name,
+
158 MTL::Library* mtl_lib,
+
159 const std::string& specialized_name = "",
+
160 const MTLFCList& func_consts = {});
+
161
+
162 MTL::Function* get_function(
+
163 const std::string& base_name,
+
164 const std::string& lib_name = "mlx",
+
165 const std::string& specialized_name = "",
+
166 const MTLFCList& func_consts = {});
+
167
+
168 MTL::ComputePipelineState* get_kernel(
+
169 const std::string& base_name,
+
170 MTL::Library* mtl_lib,
+
171 const std::string& hash_name = "",
+
172 const MTLFCList& func_consts = {},
+
173 const std::vector<MTL::Function*>& linked_functions = {});
+
174
+
175 MTL::ComputePipelineState* get_kernel(
+
176 const std::string& base_name,
+
177 const std::string& lib_name = "mlx",
+
178 const std::string& hash_name = "",
+
179 const MTLFCList& func_consts = {},
+
180 const std::vector<MTL::Function*>& linked_functions = {});
+
181
+
182 MTL::ArgumentEncoder* argument_encoder(
+
183 const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
+
184
+
185 private:
+
186 MTL::Library* get_library_cache_(const std::string& name);
+
187
+
188 MTL::Library* get_library_(const std::string& source_string);
+
189 MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
+
190
+
191 MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
+
192
+
193 MTL::Function* get_function_(
+
194 const std::string& name,
+
195 const std::string& specialized_name,
+
196 const MTLFCList& func_consts,
+
197 MTL::Library* mtl_lib);
+
198
+
199 MTL::LinkedFunctions* get_linked_functions_(
+
200 const std::vector<MTL::Function*>& funcs);
+
201
+
202 MTL::ComputePipelineState* get_kernel_(
+
203 const std::string& name,
+
204 const MTL::Function* mtl_function);
+
205
+
206 MTL::ComputePipelineState* get_kernel_(
+
207 const std::string& name,
+
208 const MTL::Function* mtl_function,
+
209 const MTL::LinkedFunctions* linked_functions);
+
210
+
211 MTL::Device* device_;
+
212 std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
+
213 std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
+
214 std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
+
215 std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
+
216 std::unordered_map<std::string, MTL::Library*> library_map_;
+
217 std::mutex mtx_;
+
218};
+
+
219
+ +
221
+
222} // namespace mlx::core::metal
+ +
MTL::Buffer * buf
Definition allocator.h:38
+
const void * ptr() const
Definition allocator.h:23
+
Definition array.h:20
+
T * data()
Definition array.h:313
+
allocator::Buffer & buffer()
Definition array.h:299
+
Definition device.h:117
+
int get_command_buffer_ops(int index)
+
MTL::Device * mtl_device()
Definition device.h:124
+
void register_library(const std::string &lib_name, const std::string &lib_path)
+ +
MTL::CommandBuffer * get_command_buffer(int index)
+
void end_encoding(int index)
+
MTL::ComputePipelineState * get_kernel(const std::string &base_name, MTL::Library *mtl_lib, const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
+
void register_library(const std::string &lib_name, const std::function< std::string(const std::string &)> &lib_path_func=get_colocated_mtllib_path)
+
MTL::ArgumentEncoder * argument_encoder(const std::vector< MTL::ArgumentDescriptor * > &arg_descs) const
+
void increment_command_buffer_ops(int index)
+
void new_queue(int index)
+
MTL::Library * get_library(const std::string &name)
+
MTL::Library * get_library(const std::string &name, const MTL::StitchedLibraryDescriptor *desc, bool cache=true)
+
void commit_command_buffer(int index)
+
MTL::Library * get_library(const std::string &name, const std::string &source_string, bool cache=true)
+
MTL::Function * get_function(const std::string &base_name, MTL::Library *mtl_lib, const std::string &specialized_name="", const MTLFCList &func_consts={})
+
Device(const Device &)=delete
+
MTL::Function * get_function(const std::string &base_name, const std::string &lib_name="mlx", const std::string &specialized_name="", const MTLFCList &func_consts={})
+
Device & operator=(const Device &)=delete
+ +
MTL::ComputePipelineState * get_kernel(const std::string &base_name, const std::string &lib_name="mlx", const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
+
CommandEncoder & get_command_encoder(int index)
+ +
Definition allocator.h:12
+
std::string get_colocated_mtllib_path(const std::string &lib_name)
Definition device.h:22
+
std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList
Definition device.h:36
+
Device & device(mlx::core::Device)
+
Definition device.h:7
+ + +
ConcurrentContext(CommandEncoder &enc)
Definition device.h:48
+
Definition device.h:39
+
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims)
+
CommandEncoder(MTL::CommandBuffer *cbuf)
Definition device.h:40
+
CommandEncoder & operator=(const CommandEncoder &)=delete
+
ConcurrentContext start_concurrent()
Definition device.h:97
+
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims)
+
void set_input_array(const array &a, int idx, int offset=0)
Definition device.h:66
+
~CommandEncoder()
Definition device.h:101
+
MTL::ComputeCommandEncoder * operator->()
Definition device.h:62
+
CommandEncoder(const CommandEncoder &)=delete
+
void set_output_array(array &a, int idx, int offset=0)
Definition device.h:83
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2bf16_8h.html b/docs/build/html/backend_2metal_2kernels_2bf16_8h.html new file mode 100644 index 000000000..ccea11b13 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2bf16_8h.html @@ -0,0 +1,10952 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/bf16.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Classes | +Namespaces | +Macros | +Typedefs | +Functions | +Variables
+
bf16.h File Reference
+
+
+
#include <metal_stdlib>
+#include "mlx/backend/metal/kernels/bf16_math.h"
+
+

Go to the source code of this file.

+ + + + + + + + +

+Classes

struct  _MLX_BFloat16
 
struct  _MLX_BFloat16::bits_to_bfloat_struct
 
struct  metal::_numeric_limits_impl< bfloat16_t >
 
+ + + +

+Namespaces

namespace  metal
 
+ + + + + + + + + + + + + + + + + + + +

+Macros

#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype)
 
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype)
 
#define bfloat_binop(_op_, _operator_)
 
#define bfloat_compop(__op__, __operator__)
 
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space)
 
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype)
 
#define bfloat_inplace_op(itype)
 
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space)
 
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__)
 
+ + + +

+Typedefs

typedef struct _MLX_BFloat16 bfloat16_t
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Functions

constexpr METAL_FUNC uint16_t float_to_bfloat_bits (float x)
 
constexpr METAL_FUNC float bfloat_bits_to_float (uint16_t x)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 x)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC float operator+ (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC float operator+ (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC float operator+ (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC float operator+ (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC float operator- (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC float operator- (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC float operator- (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC float operator- (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC float operator* (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC float operator* (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC float operator* (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC float operator* (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC float operator/ (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC float operator/ (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC float operator/ (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC float operator/ (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC bool operator> (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC bool operator> (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC bool operator> (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC bool operator> (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC bool operator> (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC bool operator> (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC bool operator< (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC bool operator< (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC bool operator< (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC bool operator< (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC bool operator< (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC bool operator< (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC bool operator>= (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC bool operator>= (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC bool operator>= (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC bool operator>= (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC bool operator>= (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC bool operator>= (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC bool operator<= (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC bool operator<= (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC bool operator<= (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC bool operator<= (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC bool operator<= (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC bool operator<= (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC bool operator== (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC bool operator== (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC bool operator== (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC bool operator== (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC bool operator== (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC bool operator== (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC bool operator!= (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC bool operator!= (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC bool operator!= (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC bool operator!= (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC bool operator!= (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC bool operator!= (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC device float & operator+= (device float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC thread float & operator+= (thread float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC threadgroup float & operator+= (threadgroup float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC device float & operator-= (device float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC thread float & operator-= (thread float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC threadgroup float & operator-= (threadgroup float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC device float & operator*= (device float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC thread float & operator*= (thread float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC threadgroup float & operator*= (threadgroup float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC device float & operator/= (device float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC thread float & operator/= (thread float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC threadgroup float & operator/= (threadgroup float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC device half & operator+= (device half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC thread half & operator+= (thread half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC threadgroup half & operator+= (threadgroup half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC device half & operator-= (device half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC thread half & operator-= (thread half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC threadgroup half & operator-= (threadgroup half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC device half & operator*= (device half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC thread half & operator*= (thread half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC threadgroup half & operator*= (threadgroup half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC device half & operator/= (device half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC thread half & operator/= (thread half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC threadgroup half & operator/= (threadgroup half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC device int16_t & operator+= (device int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC thread int16_t & operator+= (thread int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC threadgroup int16_t & operator+= (threadgroup int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC device int16_t & operator-= (device int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC thread int16_t & operator-= (thread int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC threadgroup int16_t & operator-= (threadgroup int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC device int16_t & operator*= (device int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC thread int16_t & operator*= (thread int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC threadgroup int16_t & operator*= (threadgroup int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC device int16_t & operator/= (device int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC thread int16_t & operator/= (thread int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC threadgroup int16_t & operator/= (threadgroup int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC device int32_t & operator+= (device int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC thread int32_t & operator+= (thread int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC threadgroup int32_t & operator+= (threadgroup int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC device int32_t & operator-= (device int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC thread int32_t & operator-= (thread int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC threadgroup int32_t & operator-= (threadgroup int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC device int32_t & operator*= (device int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC thread int32_t & operator*= (thread int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC threadgroup int32_t & operator*= (threadgroup int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC device int32_t & operator/= (device int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC thread int32_t & operator/= (thread int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC threadgroup int32_t & operator/= (threadgroup int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC device int64_t & operator+= (device int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC thread int64_t & operator+= (thread int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC threadgroup int64_t & operator+= (threadgroup int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC device int64_t & operator-= (device int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC thread int64_t & operator-= (thread int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC threadgroup int64_t & operator-= (threadgroup int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC device int64_t & operator*= (device int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC thread int64_t & operator*= (thread int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC threadgroup int64_t & operator*= (threadgroup int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC device int64_t & operator/= (device int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC thread int64_t & operator/= (thread int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC threadgroup int64_t & operator/= (threadgroup int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC device uint16_t & operator+= (device uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC thread uint16_t & operator+= (thread uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC threadgroup uint16_t & operator+= (threadgroup uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC device uint16_t & operator-= (device uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC thread uint16_t & operator-= (thread uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC threadgroup uint16_t & operator-= (threadgroup uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC device uint16_t & operator*= (device uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC thread uint16_t & operator*= (thread uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC threadgroup uint16_t & operator*= (threadgroup uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC device uint16_t & operator/= (device uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC thread uint16_t & operator/= (thread uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC threadgroup uint16_t & operator/= (threadgroup uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC device uint32_t & operator+= (device uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC thread uint32_t & operator+= (thread uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC threadgroup uint32_t & operator+= (threadgroup uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC device uint32_t & operator-= (device uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC thread uint32_t & operator-= (thread uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC threadgroup uint32_t & operator-= (threadgroup uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC device uint32_t & operator*= (device uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC thread uint32_t & operator*= (thread uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC threadgroup uint32_t & operator*= (threadgroup uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC device uint32_t & operator/= (device uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC thread uint32_t & operator/= (thread uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC threadgroup uint32_t & operator/= (threadgroup uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC device uint64_t & operator+= (device uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC thread uint64_t & operator+= (thread uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC threadgroup uint64_t & operator+= (threadgroup uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC device uint64_t & operator-= (device uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC thread uint64_t & operator-= (thread uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC threadgroup uint64_t & operator-= (threadgroup uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC device uint64_t & operator*= (device uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC thread uint64_t & operator*= (thread uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC threadgroup uint64_t & operator*= (threadgroup uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC device uint64_t & operator/= (device uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC thread uint64_t & operator/= (thread uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC threadgroup uint64_t & operator/= (threadgroup uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
METAL_FUNC bool metal::isnan (_MLX_BFloat16 x)
 
+ + + + + + + +

+Variables

template<typename T >
static constexpr constant bool can_convert_to_bfloat
 
template<typename T >
static constexpr constant bool can_convert_from_bfloat
 
+

Macro Definition Documentation

+ +

◆ bfloat_binop

+ +
+
+ + + + + + + + + + + +
#define bfloat_binop( _op_,
_operator_ )
+
+Value:
+
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
+
bfloat_binop_helper(_op_, _operator_, float, float, float); \
+
bfloat_binop_helper(_op_, _operator_, float, half, float); \
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
+
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype)
Definition bf16.h:141
+
Definition bf16.h:54
+
+
+
+ +

◆ bfloat_binop_base

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
#define bfloat_binop_base( __op__,
__operator__,
otype,
atype,
btype,
ctype )
+
+Value:
constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
+
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+
}
+
+
+
+ +

◆ bfloat_binop_helper

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
#define bfloat_binop_helper( __op__,
__operator__,
otype,
itype,
ctype )
+
+Value:
constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
+
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+
} \
+
constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
+
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+
}
+
+
+
+ +

◆ bfloat_compop

+ +
+
+ + + + + + + + + + + +
#define bfloat_compop( __op__,
__operator__ )
+
+Value:
+
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
+
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
+
bfloat_binop_helper(__op__, __operator__, bool, half, float); \
+
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
+
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
+
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
+
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
+
+
+
+ +

◆ bfloat_inplace_op

+ +
+
+ + + + + + + +
#define bfloat_inplace_op( itype)
+
+Value:
bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
+
bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
+
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
+
bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
+
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype)
Definition bf16.h:209
+
+
+
+ +

◆ bfloat_inplace_op_addr_space_helper [1/2]

+ +
+
+ + + + + + + + + + + +
#define bfloat_inplace_op_addr_space_helper( __op__,
__operator__ )
+
+Value:
bfloat_inplace_op_helper(__op__, __operator__, device); \
+
bfloat_inplace_op_helper(__op__, __operator__, thread); \
+
bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
+
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space)
Definition bf16.h:197
+
+
+
+ +

◆ bfloat_inplace_op_addr_space_helper [2/2]

+ +
+
+ + + + + + + + + + + + + + + + +
#define bfloat_inplace_op_addr_space_helper( __op__,
__operator__,
itype )
+
+Value:
bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
+
bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
+
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
+
+
+
+ +

◆ bfloat_inplace_op_helper [1/2]

+ +
+
+ + + + + + + + + + + + + + + + +
#define bfloat_inplace_op_helper( __op__,
__operator__,
addr_space )
+
+Value:
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
+
addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
+
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+
return lhs; \
+
}
+
+
+
+ +

◆ bfloat_inplace_op_helper [2/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + +
#define bfloat_inplace_op_helper( __op__,
__operator__,
itype,
addr_space )
+
+Value:
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
+
addr_space _MLX_BFloat16& lhs, itype rhs) { \
+
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+
return lhs; \
+
} \
+
constexpr METAL_FUNC addr_space itype& __operator__( \
+
addr_space itype& lhs, _MLX_BFloat16 rhs) { \
+
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+
return lhs; \
+
}
+
+
+
+

Typedef Documentation

+ +

◆ bfloat16_t

+ +
+
+ + + + +
typedef struct _MLX_BFloat16 bfloat16_t
+
+ +
+
+

Function Documentation

+ +

◆ bfloat_bits_to_float()

+ +
+
+ + + + + +
+ + + + + + + +
constexpr METAL_FUNC float bfloat_bits_to_float (uint16_t x)
+
+constexpr
+
+ +
+
+ +

◆ float_to_bfloat_bits()

+ +
+
+ + + + + +
+ + + + + + + +
constexpr METAL_FUNC uint16_t float_to_bfloat_bits (float x)
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator* (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator* (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator* (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator* (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [1/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [2/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [3/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [4/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [5/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [6/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [7/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [8/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [9/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [10/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device float & operator*= (device float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [11/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device half & operator*= (device half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [12/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int16_t & operator*= (device int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [13/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int32_t & operator*= (device int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [14/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int64_t & operator*= (device int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [15/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint16_t & operator*= (device uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [16/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint32_t & operator*= (device uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [17/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint64_t & operator*= (device uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [18/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [19/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [20/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [21/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [22/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [23/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [24/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [25/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [26/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [27/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread float & operator*= (thread float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [28/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread half & operator*= (thread half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [29/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int16_t & operator*= (thread int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [30/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int32_t & operator*= (thread int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [31/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int64_t & operator*= (thread int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [32/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint16_t & operator*= (thread uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [33/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint32_t & operator*= (thread uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [34/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint64_t & operator*= (thread uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [35/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [36/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [37/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [38/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [39/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [40/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [41/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [42/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [43/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [44/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup float & operator*= (threadgroup float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [45/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup half & operator*= (threadgroup half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [46/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int16_t & operator*= (threadgroup int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [47/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int32_t & operator*= (threadgroup int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [48/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int64_t & operator*= (threadgroup int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [49/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint16_t & operator*= (threadgroup uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [50/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint32_t & operator*= (threadgroup uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [51/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint64_t & operator*= (threadgroup uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator+ (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator+ (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator+ (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator+ (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [1/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [2/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [3/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [4/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [5/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [6/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [7/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [8/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [9/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [10/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device float & operator+= (device float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [11/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device half & operator+= (device half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [12/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int16_t & operator+= (device int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [13/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int32_t & operator+= (device int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [14/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int64_t & operator+= (device int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [15/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint16_t & operator+= (device uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [16/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint32_t & operator+= (device uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [17/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint64_t & operator+= (device uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [18/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [19/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [20/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [21/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [22/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [23/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [24/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [25/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [26/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [27/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread float & operator+= (thread float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [28/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread half & operator+= (thread half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [29/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int16_t & operator+= (thread int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [30/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int32_t & operator+= (thread int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [31/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int64_t & operator+= (thread int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [32/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint16_t & operator+= (thread uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [33/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint32_t & operator+= (thread uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [34/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint64_t & operator+= (thread uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [35/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [36/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [37/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [38/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [39/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [40/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [41/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [42/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [43/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [44/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup float & operator+= (threadgroup float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [45/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup half & operator+= (threadgroup half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [46/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int16_t & operator+= (threadgroup int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [47/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int32_t & operator+= (threadgroup int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [48/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int64_t & operator+= (threadgroup int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [49/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint16_t & operator+= (threadgroup uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [50/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint32_t & operator+= (threadgroup uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [51/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint64_t & operator+= (threadgroup uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [1/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [2/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator- (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [3/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator- (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [4/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [5/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [6/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [7/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [8/14]

+ +
+
+ + + + + +
+ + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 x)
+
+constexpr
+
+ +
+
+ +

◆ operator-() [9/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator- (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [10/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator- (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [11/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [12/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [13/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [14/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [1/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [2/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [3/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [4/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [5/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [6/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [7/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [8/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [9/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [10/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device float & operator-= (device float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [11/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device half & operator-= (device half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [12/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int16_t & operator-= (device int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [13/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int32_t & operator-= (device int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [14/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int64_t & operator-= (device int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [15/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint16_t & operator-= (device uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [16/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint32_t & operator-= (device uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [17/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint64_t & operator-= (device uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [18/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [19/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [20/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [21/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [22/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [23/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [24/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [25/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [26/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [27/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread float & operator-= (thread float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [28/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread half & operator-= (thread half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [29/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int16_t & operator-= (thread int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [30/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int32_t & operator-= (thread int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [31/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int64_t & operator-= (thread int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [32/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint16_t & operator-= (thread uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [33/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint32_t & operator-= (thread uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [34/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint64_t & operator-= (thread uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [35/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [36/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [37/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [38/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [39/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [40/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [41/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [42/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [43/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [44/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup float & operator-= (threadgroup float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [45/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup half & operator-= (threadgroup half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [46/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int16_t & operator-= (threadgroup int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [47/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int32_t & operator-= (threadgroup int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [48/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int64_t & operator-= (threadgroup int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [49/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint16_t & operator-= (threadgroup uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [50/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint32_t & operator-= (threadgroup uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [51/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint64_t & operator-= (threadgroup uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator/ (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator/ (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator/ (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator/ (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [1/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [2/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [3/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [4/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [5/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [6/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [7/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [8/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [9/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [10/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device float & operator/= (device float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [11/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device half & operator/= (device half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [12/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int16_t & operator/= (device int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [13/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int32_t & operator/= (device int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [14/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int64_t & operator/= (device int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [15/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint16_t & operator/= (device uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [16/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint32_t & operator/= (device uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [17/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint64_t & operator/= (device uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [18/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [19/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [20/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [21/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [22/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [23/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [24/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [25/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [26/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [27/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread float & operator/= (thread float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [28/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread half & operator/= (thread half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [29/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int16_t & operator/= (thread int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [30/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int32_t & operator/= (thread int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [31/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int64_t & operator/= (thread int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [32/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint16_t & operator/= (thread uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [33/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint32_t & operator/= (thread uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [34/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint64_t & operator/= (thread uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [35/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [36/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [37/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [38/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [39/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [40/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [41/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [42/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [43/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [44/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup float & operator/= (threadgroup float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [45/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup half & operator/= (threadgroup half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [46/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int16_t & operator/= (threadgroup int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [47/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int32_t & operator/= (threadgroup int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [48/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int64_t & operator/= (threadgroup int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [49/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint16_t & operator/= (threadgroup uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [50/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint32_t & operator/= (threadgroup uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [51/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint64_t & operator/= (threadgroup uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+

Variable Documentation

+ +

◆ can_convert_from_bfloat

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + +
constexpr constant bool can_convert_from_bfloat
+
+staticconstexpr
+
+Initial value:
=
+
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>
+
+
+
+ +

◆ can_convert_to_bfloat

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + +
constexpr constant bool can_convert_to_bfloat
+
+staticconstexpr
+
+Initial value:
=
+
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>
+
+
+
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2bf16_8h_source.html b/docs/build/html/backend_2metal_2kernels_2bf16_8h_source.html new file mode 100644 index 000000000..83ebc44f4 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2bf16_8h_source.html @@ -0,0 +1,489 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/bf16.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
bf16.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <metal_stdlib>
+
6
+
7using namespace metal;
+
8
+
9#if defined(__HAVE_BFLOAT__)
+
10
+
11typedef bfloat bfloat16_t;
+
12
+
13#else
+
14
+
16// Helpers
+
18
+
+
19constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
+
20 // Check for nan
+
21 if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
+
22 _fp_encoding_traits<float>::inf_mask) {
+
23 return uint16_t(as_type<uint32_t>(0x7FC0));
+
24 }
+
25 // Take bits
+
26 uint32_t float_bits = as_type<uint32_t>(x);
+
27
+
28 // Round to nearest even
+
29 float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
+
30
+
31 // Take upper 16 bits
+
32 return float_bits >> 16;
+
33}
+
+
34
+
+
35constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
+
36 // Upper 16 bits are the data and lower 16 bits are 0s
+
37 return as_type<float>((uint32_t)x << 16);
+
38}
+
+
39
+
40struct _MLX_BFloat16;
+
41
+
42template <typename T>
+
43static constexpr constant bool can_convert_to_bfloat =
+
44 !is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
+
45
+
46template <typename T>
+
47static constexpr constant bool can_convert_from_bfloat =
+
48 !is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
+
49
+
51// Bfloat struct
+
53
+
+ +
56 // Constructors
+
57 uint16_t bits_;
+
58 _MLX_BFloat16() thread = default;
+
59 _MLX_BFloat16() threadgroup = default;
+
60 _MLX_BFloat16() device = default;
+
61 _MLX_BFloat16() constant = default;
+
62
+ +
+
64 static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
+
65 return bits_to_bfloat_struct();
+
66 }
+
+
+
67 constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
+
68 : bits_(bits) {}
+
+
69
+
71 // Conversions to bfloat
+
72
+
73 template <
+
74 typename T,
+
75 typename = typename enable_if<can_convert_to_bfloat<T>>::type>
+
+
76 constexpr METAL_FUNC _MLX_BFloat16(T x) thread
+
77 : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
+
+
78
+
79 template <
+
80 typename T,
+
81 typename = typename enable_if<can_convert_to_bfloat<T>>::type>
+
+
82 constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
+
83 : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
+
+
84
+
85 template <
+
86 typename T,
+
87 typename = typename enable_if<can_convert_to_bfloat<T>>::type>
+
+
88 constexpr METAL_FUNC _MLX_BFloat16(T x) device
+
89 : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
+
+
90
+
91 template <
+
92 typename T,
+
93 typename = typename enable_if<can_convert_to_bfloat<T>>::type>
+
+
94 constexpr METAL_FUNC _MLX_BFloat16(T x) constant
+
95 : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
+
+
96
+
98 // Conversions from bfloat
+
99
+
100 template <
+
101 typename T,
+
102 typename = typename enable_if<can_convert_from_bfloat<T>>::type>
+
+
103 constexpr METAL_FUNC operator T() const thread {
+
104 return static_cast<T>(bfloat_bits_to_float(bits_));
+
105 }
+
+
106
+
107 template <
+
108 typename T,
+
109 typename = typename enable_if<can_convert_from_bfloat<T>>::type>
+
+
110 constexpr METAL_FUNC operator T() const threadgroup {
+
111 return static_cast<T>(bfloat_bits_to_float(bits_));
+
112 }
+
+
113
+
114 template <
+
115 typename T,
+
116 typename = typename enable_if<can_convert_from_bfloat<T>>::type>
+
+
117 constexpr METAL_FUNC operator T() const device {
+
118 return static_cast<T>(bfloat_bits_to_float(bits_));
+
119 }
+
+
120
+
121 template <
+
122 typename T,
+
123 typename = typename enable_if<can_convert_from_bfloat<T>>::type>
+
+
124 constexpr METAL_FUNC operator T() const constant {
+
125 return static_cast<T>(bfloat_bits_to_float(bits_));
+
126 }
+
+
127};
+
+
128
+
130// Bfloat operators
+
132
+
134// Unary ops
+
+
135constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
+
136 return -static_cast<float>(x);
+
137}
+
+
138
+
140// Binary operators
+
+
141#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
+
142 constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
+
143 return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+
144 }
+
+
145
+
+
146#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
+
147 constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
+
148 return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+
149 } \
+
150 constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
+
151 return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+
152 }
+
+
153
+
155// Arithmetic Operators
+
+
156#define bfloat_binop(_op_, _operator_) \
+
157 bfloat_binop_base( \
+
158 _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
+
159 bfloat_binop_helper(_op_, _operator_, float, float, float); \
+
160 bfloat_binop_helper(_op_, _operator_, float, half, float); \
+
161 bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
+
162 bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
+
163 bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
+
164 bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
+
+
165
+
166bfloat_binop(+, operator+);
+
167bfloat_binop(-, operator-);
+
168bfloat_binop(*, operator*);
+
169bfloat_binop(/, operator/);
+
170
+
172// Comparison ops
+
+
173#define bfloat_compop(__op__, __operator__) \
+
174 bfloat_binop_base( \
+
175 __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
+
176 bfloat_binop_helper(__op__, __operator__, bool, float, float); \
+
177 bfloat_binop_helper(__op__, __operator__, bool, half, float); \
+
178 bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
+
179 bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
+
180 bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
+
181 bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
+
+
182
+
183bfloat_compop(>, operator>);
+
184bfloat_compop(<, operator<);
+
185bfloat_compop(>=, operator>=);
+
186bfloat_compop(<=, operator<=);
+
187bfloat_compop(==, operator==);
+
188bfloat_compop(!=, operator!=);
+
189
+
190#undef bfloat_compop
+
191#undef bfloat_binop_base
+
192#undef bfloat_binop_helper
+
193#undef bfloat_binop
+
194
+
196// Inplace Operators
+
+
197#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
+
198 constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
+
199 addr_space _MLX_BFloat16& lhs, itype rhs) { \
+
200 lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+
201 return lhs; \
+
202 } \
+
203 constexpr METAL_FUNC addr_space itype& __operator__( \
+
204 addr_space itype& lhs, _MLX_BFloat16 rhs) { \
+
205 lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+
206 return lhs; \
+
207 }
+
+
208
+
+
209#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
+
210 bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
+
211 bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
+
212 bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
+
+
213
+
+
214#define bfloat_inplace_op(itype) \
+
215 bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
+
216 bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
+
217 bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
+
218 bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
+
+
219
+ + + + + + + + +
228
+
229#undef bfloat_inplace_op_helper
+
230#undef bfloat_inplace_op_addr_space_helper
+
231#undef bfloat_inplace_op
+
232
+
233#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
+
234 constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
+
235 addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
+
236 lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+
237 return lhs; \
+
238 }
+
239
+
240#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
+
241 bfloat_inplace_op_helper(__op__, __operator__, device); \
+
242 bfloat_inplace_op_helper(__op__, __operator__, thread); \
+
243 bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
+
244
+ + + + +
249
+
250#undef bfloat_inplace_op_helper
+
251#undef bfloat_inplace_op_addr_space_helper
+
252
+
254// Bfloat typedef
+
256
+ +
258
+
260// Bfloat numeric limits
+
262
+
263#pragma METAL internals : enable
+
264
+
+
265namespace metal {
+
266
+
267template <>
+
+
268struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
+
269 static constexpr constant int digits = 8;
+
270 static constexpr constant int digits10 = 2;
+
271 static constexpr constant int max_digits10 = 4;
+
272 static constexpr constant int radix = 2;
+
273 static constexpr constant int min_exponent = -125;
+
274 static constexpr constant int min_exponent10 = -37;
+
275 static constexpr constant int max_exponent = 128;
+
276 static constexpr constant int max_exponent10 = 38;
+
277
+
+
278 static constexpr bfloat16_t min() {
+ +
280 }
+
+
+
281 static constexpr bfloat16_t lowest() {
+ +
283 }
+
+
+
284 static constexpr bfloat16_t max() {
+ +
286 }
+
+
+
287 static constexpr bfloat16_t epsilon() {
+ +
289 }
+
+
+
290 static constexpr bfloat16_t round_error() {
+ +
292 }
+
+
+
293 static constexpr bfloat16_t infinity() {
+ +
295 }
+
+
+
296 static constexpr bfloat16_t quiet_NaN() {
+ +
298 }
+
+
+
299 static constexpr bfloat16_t signaling_NaN() {
+ +
301 }
+
+
+
302 static constexpr bfloat16_t denorm_min() {
+ +
304 }
+
+
305};
+
+
306
+
+
307METAL_FUNC bool isnan(_MLX_BFloat16 x) {
+
308 return x != x;
+
309}
+
+
310
+
311} // namespace metal
+
+
312
+
313#pragma METAL internals : disable
+
314
+
315#endif // defined(__HAVE_BFLOAT__)
+
316
+ +
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x)
Definition bf16.h:19
+
#define bfloat_compop(__op__, __operator__)
Definition bf16.h:173
+
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x)
Definition bf16.h:35
+
#define bfloat_inplace_op(itype)
Definition bf16.h:214
+
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x)
Definition bf16.h:135
+
#define bfloat_binop(_op_, _operator_)
Definition bf16.h:156
+
struct _MLX_BFloat16 bfloat16_t
Definition bf16.h:257
+
static constexpr constant bool can_convert_from_bfloat
Definition bf16.h:47
+
static constexpr constant bool can_convert_to_bfloat
Definition bf16.h:43
+
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype)
Definition bf16.h:209
+ +
Definition bf16.h:265
+
METAL_FUNC bool isnan(_MLX_BFloat16 x)
Definition bf16.h:307
+ +
Definition bf16.h:54
+
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
Definition bf16.h:76
+
uint16_t bits_
Definition bf16.h:57
+
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
Definition bf16.h:67
+
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat()
Definition bf16.h:64
+
_MLX_BFloat16() thread=default
+
constexpr METAL_FUNC _MLX_BFloat16(T x) device
Definition bf16.h:88
+
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
Definition bf16.h:82
+
constexpr METAL_FUNC _MLX_BFloat16(T x) const ant
Definition bf16.h:94
+
static constexpr bfloat16_t infinity()
Definition bf16.h:293
+
static constexpr bfloat16_t denorm_min()
Definition bf16.h:302
+
static constexpr bfloat16_t max()
Definition bf16.h:284
+
static constexpr bfloat16_t epsilon()
Definition bf16.h:287
+
static constexpr bfloat16_t signaling_NaN()
Definition bf16.h:299
+
static constexpr bfloat16_t min()
Definition bf16.h:278
+
static constexpr bfloat16_t lowest()
Definition bf16.h:281
+
static constexpr bfloat16_t quiet_NaN()
Definition bf16.h:296
+
static constexpr bfloat16_t round_error()
Definition bf16.h:290
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2complex_8h.html b/docs/build/html/backend_2metal_2kernels_2complex_8h.html new file mode 100644 index 000000000..f45ed8b0a --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2complex_8h.html @@ -0,0 +1,504 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/complex.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Classes | +Functions | +Variables
+
complex.h File Reference
+
+
+
#include <metal_stdlib>
+
+

Go to the source code of this file.

+ + + + +

+Classes

struct  complex64_t
 
+ + + + + + + + + + + + + + + + + + + + + + + +

+Functions

constexpr complex64_t operator- (complex64_t x)
 
constexpr bool operator>= (complex64_t a, complex64_t b)
 
constexpr bool operator> (complex64_t a, complex64_t b)
 
constexpr bool operator<= (complex64_t a, complex64_t b)
 
constexpr bool operator< (complex64_t a, complex64_t b)
 
constexpr bool operator== (complex64_t a, complex64_t b)
 
constexpr complex64_t operator+ (complex64_t a, complex64_t b)
 
constexpr complex64_t operator- (complex64_t a, complex64_t b)
 
constexpr complex64_t operator* (complex64_t a, complex64_t b)
 
constexpr complex64_t operator/ (complex64_t a, complex64_t b)
 
constexpr complex64_t operator% (complex64_t a, complex64_t b)
 
+ + + + + + + +

+Variables

template<typename T >
static constexpr constant bool can_convert_to_complex64
 
template<typename T >
static constexpr constant bool can_convert_from_complex64
 
+

Function Documentation

+ +

◆ operator%()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr complex64_t operator% (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator*()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr complex64_t operator* (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator+()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr complex64_t operator+ (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [1/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr complex64_t operator- (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
constexpr complex64_t operator- (complex64_t x)
+
+constexpr
+
+ +
+
+ +

◆ operator/()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr complex64_t operator/ (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator<()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr bool operator< (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator<=()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr bool operator<= (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator==()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr bool operator== (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator>()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr bool operator> (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator>=()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr bool operator>= (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+

Variable Documentation

+ +

◆ can_convert_from_complex64

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + +
constexpr constant bool can_convert_from_complex64
+
+staticconstexpr
+
+Initial value:
=
+
!is_same_v<T, complex64_t> &&
+
(is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>)
+
+
+
+ +

◆ can_convert_to_complex64

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + +
constexpr constant bool can_convert_to_complex64
+
+staticconstexpr
+
+Initial value:
=
+
!is_same_v<T, complex64_t> && is_convertible_v<T, float>
+
+
+
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2complex_8h_source.html b/docs/build/html/backend_2metal_2kernels_2complex_8h_source.html new file mode 100644 index 000000000..c47bbd109 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2complex_8h_source.html @@ -0,0 +1,276 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/complex.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
complex.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <metal_stdlib>
+
6
+
7using namespace metal;
+
8
+
9struct complex64_t;
+
10
+
11template <typename T>
+
12static constexpr constant bool can_convert_to_complex64 =
+
13 !is_same_v<T, complex64_t> && is_convertible_v<T, float>;
+
14
+
15template <typename T>
+
16static constexpr constant bool can_convert_from_complex64 =
+
17 !is_same_v<T, complex64_t> &&
+
18 (is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);
+
19
+
+ +
21 float real;
+
22 float imag;
+
23
+
24 // Constructors
+
25 constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
+
26
+
27 // Conversions to complex64_t
+
28 template <
+
29 typename T,
+
30 typename = typename enable_if<can_convert_to_complex64<T>>::type>
+
31 constexpr complex64_t(T x) thread : real(x), imag(0) {}
+
32
+
33 template <
+
34 typename T,
+
35 typename = typename enable_if<can_convert_to_complex64<T>>::type>
+
36 constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}
+
37
+
38 template <
+
39 typename T,
+
40 typename = typename enable_if<can_convert_to_complex64<T>>::type>
+
41 constexpr complex64_t(T x) device : real(x), imag(0) {}
+
42
+
43 template <
+
44 typename T,
+
45 typename = typename enable_if<can_convert_to_complex64<T>>::type>
+
46 constexpr complex64_t(T x) constant : real(x), imag(0) {}
+
47
+
48 // Conversions from complex64_t
+
49 template <
+
50 typename T,
+
51 typename = typename enable_if<can_convert_from_complex64<T>>::type>
+
+
52 constexpr operator T() const thread {
+
53 return static_cast<T>(real);
+
54 }
+
+
55
+
56 template <
+
57 typename T,
+
58 typename = typename enable_if<can_convert_from_complex64<T>>::type>
+
+
59 constexpr operator T() const threadgroup {
+
60 return static_cast<T>(real);
+
61 }
+
+
62
+
63 template <
+
64 typename T,
+
65 typename = typename enable_if<can_convert_from_complex64<T>>::type>
+
+
66 constexpr operator T() const device {
+
67 return static_cast<T>(real);
+
68 }
+
+
69
+
70 template <
+
71 typename T,
+
72 typename = typename enable_if<can_convert_from_complex64<T>>::type>
+
+
73 constexpr operator T() const constant {
+
74 return static_cast<T>(real);
+
75 }
+
+
76};
+
+
77
+
+ +
79 return {-x.real, -x.imag};
+
80}
+
+
81
+
+
82constexpr bool operator>=(complex64_t a, complex64_t b) {
+
83 return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
+
84}
+
+
85
+
+
86constexpr bool operator>(complex64_t a, complex64_t b) {
+
87 return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
+
88}
+
+
89
+
+
90constexpr bool operator<=(complex64_t a, complex64_t b) {
+
91 return operator>=(b, a);
+
92}
+
+
93
+
+
94constexpr bool operator<(complex64_t a, complex64_t b) {
+
95 return operator>(b, a);
+
96}
+
+
97
+
+
98constexpr bool operator==(complex64_t a, complex64_t b) {
+
99 return a.real == b.real && a.imag == b.imag;
+
100}
+
+
101
+
+ +
103 return {a.real + b.real, a.imag + b.imag};
+
104}
+
+
105
+
+ +
107 return {a.real - b.real, a.imag - b.imag};
+
108}
+
+
109
+
+ +
111 return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
+
112}
+
+
113
+
+ +
115 auto denom = b.real * b.real + b.imag * b.imag;
+
116 auto x = a.real * b.real + a.imag * b.imag;
+
117 auto y = a.imag * b.real - a.real * b.imag;
+
118 return {x / denom, y / denom};
+
119}
+
+
120
+
+ +
122 auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
+
123 auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
+
124 if (real != 0 && (real < 0 != b.real < 0)) {
+
125 real += b.real;
+
126 }
+
127 if (imag != 0 && (imag < 0 != b.imag < 0)) {
+
128 imag += b.imag;
+
129 }
+
130 return {real, imag};
+
131}
+
+
constexpr bool operator>(complex64_t a, complex64_t b)
Definition complex.h:86
+
constexpr complex64_t operator-(complex64_t x)
Definition complex.h:78
+
static constexpr constant bool can_convert_to_complex64
Definition complex.h:12
+
constexpr bool operator<(complex64_t a, complex64_t b)
Definition complex.h:94
+
constexpr complex64_t operator*(complex64_t a, complex64_t b)
Definition complex.h:110
+
constexpr complex64_t operator%(complex64_t a, complex64_t b)
Definition complex.h:121
+
constexpr bool operator>=(complex64_t a, complex64_t b)
Definition complex.h:82
+
static constexpr constant bool can_convert_from_complex64
Definition complex.h:16
+
constexpr bool operator==(complex64_t a, complex64_t b)
Definition complex.h:98
+
constexpr complex64_t operator+(complex64_t a, complex64_t b)
Definition complex.h:102
+
constexpr complex64_t operator/(complex64_t a, complex64_t b)
Definition complex.h:114
+
constexpr bool operator<=(complex64_t a, complex64_t b)
Definition complex.h:90
+
Definition bf16.h:265
+
Definition complex.h:20
+
constexpr complex64_t(T x) const ant
Definition complex.h:46
+
constexpr complex64_t(T x) thread
Definition complex.h:31
+
constexpr complex64_t(T x) threadgroup
Definition complex.h:36
+
float imag
Definition complex.h:22
+
float real
Definition complex.h:21
+
constexpr complex64_t(T x) device
Definition complex.h:41
+
constexpr complex64_t(float real, float imag)
Definition complex.h:25
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2reduction_2ops_8h.html b/docs/build/html/backend_2metal_2kernels_2reduction_2ops_8h.html new file mode 100644 index 000000000..d991b9612 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2reduction_2ops_8h.html @@ -0,0 +1,116 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/reduction/ops.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Classes
+
ops.h File Reference
+
+
+
#include <metal_atomic>
+#include <metal_simdgroup>
+#include "mlx/backend/metal/kernels/atomic.h"
+#include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/utils.h"
+
+

Go to the source code of this file.

+ + + + + + + + + + + + + + + + + + +

+Classes

union  bool4_or_uint
 
struct  None
 
struct  And
 
struct  Or
 
struct  Sum< U >
 
struct  Prod< U >
 
struct  Min< U >
 
struct  Max< U >
 
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2reduction_2ops_8h_source.html b/docs/build/html/backend_2metal_2kernels_2reduction_2ops_8h_source.html new file mode 100644 index 000000000..58bc01ea0 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2reduction_2ops_8h_source.html @@ -0,0 +1,385 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/reduction/ops.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
ops.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <metal_atomic>
+
6#include <metal_simdgroup>
+
7
+ + + +
11
+
+ +
13 bool4 b;
+
14 unsigned int i;
+
15};
+
+
16
+
+
17struct None {
+
18 template <typename T>
+
+
19 void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
+
20 mlx_atomic_store_explicit(out, val, offset);
+
21 }
+
+
22};
+
+
23
+
+
24struct And {
+
+
25 bool simd_reduce(bool val) {
+
26 return simd_all(val);
+
27 };
+
+
28
+
29 static constexpr constant bool init = true;
+
30
+
+ +
32 device mlx_atomic<unsigned int>* out,
+
33 bool val,
+
34 int elem_idx,
+
35 int offset = 0) {
+
36 if (!val) {
+ +
38 update.b = {true, true, true, true};
+
39 update.b[elem_idx] = false;
+ +
41 }
+
42 }
+
+
43
+
+
44 void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
+
45 if (!val) {
+
46 mlx_atomic_store_explicit(out, val, offset);
+
47 }
+
48 }
+
+
49
+
50 // Non atomic update
+
+
51 void update(device bool* out, bool val) {
+
52 *out &= val;
+
53 }
+
+
54
+
55 // Operator
+
+
56 bool operator()(bool a, bool b) {
+
57 return a && b;
+
58 }
+
+
59};
+
+
60
+
+
61struct Or {
+
+
62 bool simd_reduce(bool val) {
+
63 return simd_any(val);
+
64 };
+
+
65
+
66 static constexpr constant bool init = false;
+
67
+
+ +
69 device mlx_atomic<unsigned int>* out,
+
70 bool val,
+
71 uint elem_idx,
+
72 uint offset = 0) {
+
73 if (val) {
+ +
75 update.b = {false, false, false, false};
+
76 update.b[elem_idx] = true;
+ +
78 }
+
79 }
+
+
80
+
+
81 void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
+
82 if (val) {
+
83 mlx_atomic_store_explicit(out, val, offset);
+
84 }
+
85 }
+
+
86
+
87 // Non atomic update
+
+
88 void update(device bool* out, bool val) {
+
89 *out |= val;
+
90 }
+
+
91
+
92 // Operator
+
+
93 bool operator()(bool a, bool b) {
+
94 return a || b;
+
95 }
+
+
96};
+
+
97
+
98template <typename U>
+
+
99struct Sum {
+
100 template <typename T>
+
+
101 T simd_reduce(T val) {
+
102 return simd_sum(val);
+
103 };
+
+
104
+
105 static constexpr constant U init = U(0);
+
106
+
107 template <typename T>
+
+
108 void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
+
109 mlx_atomic_fetch_add_explicit(out, val, offset);
+
110 }
+
+
111
+
112 // Operator
+
+
113 U operator()(U a, U b) {
+
114 return a + b;
+
115 }
+
+
116};
+
+
117
+
118template <typename U>
+
+
119struct Prod {
+
120 template <typename T>
+
+
121 T simd_reduce(T val) {
+
122 return simd_product(val);
+
123 };
+
+
124
+
125 static constexpr constant U init = U(1);
+
126
+
127 template <typename T>
+
+
128 void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
+
129 mlx_atomic_fetch_mul_explicit(out, val, offset);
+
130 }
+
+
131
+
132 // Operator
+
+
133 U operator()(U a, U b) {
+
134 return a * b;
+
135 }
+
+
136};
+
+
137
+
138template <typename U>
+
+
139struct Min {
+
140 template <typename T>
+
+
141 T simd_reduce(T val) {
+
142 return simd_min(val);
+
143 };
+
+
144
+
145 static constexpr constant U init = Limits<U>::max;
+
146
+
147 template <typename T>
+
+
148 void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
+
149 mlx_atomic_fetch_min_explicit(out, val, offset);
+
150 }
+
+
151
+
152 // Operator
+
+
153 U operator()(U a, U b) {
+
154 return a < b ? a : b;
+
155 }
+
+
156};
+
+
157
+
158template <typename U>
+
+
159struct Max {
+
160 template <typename T>
+
+
161 T simd_reduce(T val) {
+
162 return simd_max(val);
+
163 };
+
+
164
+
165 static constexpr constant U init = Limits<U>::min;
+
166
+
167 template <typename T>
+
+
168 void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
+
169 mlx_atomic_fetch_max_explicit(out, val, offset);
+
170 }
+
+
171
+
172 // Operator
+
+
173 U operator()(U a, U b) {
+
174 return a > b ? a : b;
+
175 }
+
+
176};
+
+ +
METAL_FUNC void mlx_atomic_fetch_add_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:82
+
METAL_FUNC void mlx_atomic_fetch_and_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:52
+
METAL_FUNC void mlx_atomic_store_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:47
+
METAL_FUNC void mlx_atomic_fetch_or_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:61
+
METAL_FUNC void mlx_atomic_fetch_max_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:74
+
METAL_FUNC void mlx_atomic_fetch_min_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:66
+
METAL_FUNC void mlx_atomic_fetch_mul_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:90
+ + +
METAL_FUNC bfloat16_t simd_max(bfloat16_t data)
Definition bf16_math.h:392
+
METAL_FUNC bfloat16_t simd_sum(bfloat16_t data)
Definition bf16_math.h:392
+
METAL_FUNC bfloat16_t simd_product(bfloat16_t data)
Definition bf16_math.h:392
+
METAL_FUNC bfloat16_t simd_min(bfloat16_t data)
Definition bf16_math.h:392
+
Definition ops.h:24
+
void atomic_update(device mlx_atomic< bool > *out, bool val, uint offset=0)
Definition ops.h:44
+
bool operator()(bool a, bool b)
Definition ops.h:56
+
bool simd_reduce(bool val)
Definition ops.h:25
+
static constexpr constant bool init
Definition ops.h:29
+
void atomic_update(device mlx_atomic< unsigned int > *out, bool val, int elem_idx, int offset=0)
Definition ops.h:31
+
void update(device bool *out, bool val)
Definition ops.h:51
+
Definition utils.h:14
+
Definition ops.h:159
+
T simd_reduce(T val)
Definition ops.h:161
+
U operator()(U a, U b)
Definition ops.h:173
+
static constexpr constant U init
Definition ops.h:165
+
void atomic_update(device mlx_atomic< T > *out, T val, uint offset=0)
Definition ops.h:168
+
Definition ops.h:139
+
void atomic_update(device mlx_atomic< T > *out, T val, uint offset=0)
Definition ops.h:148
+
U operator()(U a, U b)
Definition ops.h:153
+
static constexpr constant U init
Definition ops.h:145
+
T simd_reduce(T val)
Definition ops.h:141
+
Definition ops.h:17
+
void atomic_update(device mlx_atomic< T > *out, T val, uint offset=0)
Definition ops.h:19
+
Definition ops.h:61
+
void atomic_update(device mlx_atomic< bool > *out, bool val, uint offset=0)
Definition ops.h:81
+
bool operator()(bool a, bool b)
Definition ops.h:93
+
void atomic_update(device mlx_atomic< unsigned int > *out, bool val, uint elem_idx, uint offset=0)
Definition ops.h:68
+
void update(device bool *out, bool val)
Definition ops.h:88
+
bool simd_reduce(bool val)
Definition ops.h:62
+
static constexpr constant bool init
Definition ops.h:66
+
Definition ops.h:119
+
U operator()(U a, U b)
Definition ops.h:133
+
void atomic_update(device mlx_atomic< T > *out, T val, uint offset=0)
Definition ops.h:128
+
T simd_reduce(T val)
Definition ops.h:121
+
static constexpr constant U init
Definition ops.h:125
+
Definition ops.h:99
+
void atomic_update(device mlx_atomic< T > *out, T val, uint offset=0)
Definition ops.h:108
+
static constexpr constant U init
Definition ops.h:105
+
T simd_reduce(T val)
Definition ops.h:101
+
U operator()(U a, U b)
Definition ops.h:113
+
Definition atomic.h:26
+
Definition ops.h:12
+
bool4 b
Definition ops.h:13
+
unsigned int i
Definition ops.h:14
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2reduction_2utils_8h.html b/docs/build/html/backend_2metal_2kernels_2reduction_2utils_8h.html new file mode 100644 index 000000000..2c79e0cea --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2reduction_2utils_8h.html @@ -0,0 +1,126 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/reduction/utils.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Variables
+
utils.h File Reference
+
+
+
#include <metal_atomic>
+#include <metal_simdgroup>
+#include "mlx/backend/metal/kernels/defines.h"
+#include "mlx/backend/metal/kernels/steel/utils.h"
+#include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/reduction/ops.h"
+
+

Go to the source code of this file.

+ + + + +

+Variables

static constant constexpr const uint8_t simd_size = 32
 
+

Variable Documentation

+ +

◆ simd_size

+ +
+
+ + + + + +
+ + + + +
constant constexpr const uint8_t simd_size = 32
+
+staticconstexpr
+
+ +
+
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2reduction_2utils_8h_source.html b/docs/build/html/backend_2metal_2kernels_2reduction_2utils_8h_source.html new file mode 100644 index 000000000..c20c0da98 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2reduction_2utils_8h_source.html @@ -0,0 +1,111 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/reduction/utils.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
utils.h
+
+
+Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <metal_atomic>
+
6#include <metal_simdgroup>
+
7
+ + + +
11
+ +
13
+
14static constant constexpr const uint8_t simd_size = 32;
+ +
static constant constexpr const uint8_t simd_size
Definition utils.h:14
+ + + +
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2steel_2gemm_2transforms_8h.html b/docs/build/html/backend_2metal_2kernels_2steel_2gemm_2transforms_8h.html new file mode 100644 index 000000000..2f5a4eaa4 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2steel_2gemm_2transforms_8h.html @@ -0,0 +1,114 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/gemm/transforms.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Classes | +Namespaces
+
transforms.h File Reference
+
+
+
#include "mlx/backend/metal/kernels/steel/utils.h"
+
+

Go to the source code of this file.

+ + + + + + + + + + + + +

+Classes

struct  mlx::steel::TransformNone< OutT, InT >
 
struct  mlx::steel::TransformAdd< OutT, InT >
 
struct  mlx::steel::TransformAxpby< OutT, InT >
 
struct  mlx::steel::AccumHelper< T >
 
struct  mlx::steel::BlockSwizzle
 
+ + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::steel
 
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2steel_2gemm_2transforms_8h_source.html b/docs/build/html/backend_2metal_2kernels_2steel_2gemm_2transforms_8h_source.html new file mode 100644 index 000000000..f9ed6a5f8 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2steel_2gemm_2transforms_8h_source.html @@ -0,0 +1,192 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/gemm/transforms.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
transforms.h
+
+
+Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
+
2
+
3#pragma once
+
4
+ +
6
+
8// Transforms and Epilogues
+
10
+
11namespace mlx {
+
12namespace steel {
+
13
+
14template <typename OutT, typename InT>
+
+ +
+
16 static METAL_FUNC OutT apply(InT x) {
+
17 return static_cast<OutT>(x);
+
18 }
+
+
19
+
+
20 static METAL_FUNC OutT apply(InT x, OutT) {
+
21 return static_cast<OutT>(x);
+
22 }
+
+
23};
+
+
24
+
25template <typename OutT, typename InT>
+
+ +
27 TransformAdd(const float, const float) {}
+
28
+
+
29 static METAL_FUNC OutT apply(InT x, OutT c) {
+
30 return static_cast<OutT>(x) + c;
+
31 }
+
+
32};
+
+
33
+
34template <typename OutT, typename InT>
+
+ +
36 const float alpha;
+
37 const float beta;
+
38
+
+
39 TransformAxpby(const float alpha_, const float beta_)
+
40 : alpha(alpha_), beta(beta_) {}
+
+
41
+
+
42 METAL_FUNC OutT apply(InT x, OutT c) const {
+
43 return static_cast<OutT>(x * alpha + (beta * c));
+
44 }
+
+
45};
+
+
46
+
47template <typename T>
+
+ +
49 typedef float accum_type;
+
50};
+
+
51
+
+ +
53 static METAL_FUNC int2
+
+
54 swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
+
55 const int tid_x = (tid.x) >> swizzle_log;
+
56 const int tid_y =
+
57 ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
+
58 return int2(tid_x, tid_y);
+
59 }
+
+
60};
+
+
61
+
62} // namespace steel
+
63} // namespace mlx
+ +
Definition allocator.h:7
+
Definition transforms.h:48
+
float accum_type
Definition transforms.h:49
+
Definition transforms.h:52
+
static METAL_FUNC int2 swizzle(uint3 tid, const int swizzle_log)
Definition transforms.h:54
+
Definition transforms.h:26
+
static METAL_FUNC OutT apply(InT x, OutT c)
Definition transforms.h:29
+
TransformAdd(const float, const float)
Definition transforms.h:27
+
Definition transforms.h:35
+
const float beta
Definition transforms.h:37
+
METAL_FUNC OutT apply(InT x, OutT c) const
Definition transforms.h:42
+
const float alpha
Definition transforms.h:36
+
TransformAxpby(const float alpha_, const float beta_)
Definition transforms.h:39
+
Definition transforms.h:15
+
static METAL_FUNC OutT apply(InT x)
Definition transforms.h:16
+
static METAL_FUNC OutT apply(InT x, OutT)
Definition transforms.h:20
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2steel_2utils_8h.html b/docs/build/html/backend_2metal_2kernels_2steel_2utils_8h.html new file mode 100644 index 000000000..31eb63250 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2steel_2utils_8h.html @@ -0,0 +1,215 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/utils.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Macros | +Functions
+
utils.h File Reference
+
+
+
#include <metal_stdlib>
+
+

Go to the source code of this file.

+ + + + + + +

+Macros

#define STEEL_CONST   static constant constexpr const
 
#define STEEL_PRAGMA_UNROLL   _Pragma("clang loop unroll(full)")
 
+ + + + + +

+Functions

METAL_FUNC ulong2 elem_to_loc_broadcast (uint elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim)
 
METAL_FUNC ulong3 elem_to_loc_broadcast (uint elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, int ndim)
 
+

Macro Definition Documentation

+ +

◆ STEEL_CONST

+ +
+
+ + + + +
#define STEEL_CONST   static constant constexpr const
+
+ +
+
+ +

◆ STEEL_PRAGMA_UNROLL

+ +
+
+ + + + +
#define STEEL_PRAGMA_UNROLL   _Pragma("clang loop unroll(full)")
+
+ +
+
+

Function Documentation

+ +

◆ elem_to_loc_broadcast() [1/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC ulong3 elem_to_loc_broadcast (uint elem,
constant const int * shape,
constant const size_t * a_strides,
constant const size_t * b_strides,
constant const size_t * c_strides,
int ndim )
+
+ +
+
+ +

◆ elem_to_loc_broadcast() [2/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC ulong2 elem_to_loc_broadcast (uint elem,
constant const int * shape,
constant const size_t * a_strides,
constant const size_t * b_strides,
int ndim )
+
+ +
+
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2steel_2utils_8h_source.html b/docs/build/html/backend_2metal_2kernels_2steel_2utils_8h_source.html new file mode 100644 index 000000000..1621ac55e --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2steel_2utils_8h_source.html @@ -0,0 +1,142 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/utils.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
utils.h
+
+
+Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <metal_stdlib>
+
6
+
7#define STEEL_CONST static constant constexpr const
+
8#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
+
9
+
+
10METAL_FUNC ulong2 elem_to_loc_broadcast(
+
11 uint elem,
+
12 constant const int* shape,
+
13 constant const size_t* a_strides,
+
14 constant const size_t* b_strides,
+
15 int ndim) {
+
16 ulong loc_a{0};
+
17 ulong loc_b{0};
+
18 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+
19 int pos_in_dim = (elem % shape[i]);
+
20 elem /= shape[i];
+
21 loc_a += pos_in_dim * a_strides[i];
+
22 loc_b += pos_in_dim * b_strides[i];
+
23 }
+
24 return ulong2(loc_a, loc_b);
+
25}
+
+
26
+
+
27METAL_FUNC ulong3 elem_to_loc_broadcast(
+
28 uint elem,
+
29 constant const int* shape,
+
30 constant const size_t* a_strides,
+
31 constant const size_t* b_strides,
+
32 constant const size_t* c_strides,
+
33 int ndim) {
+
34 ulong loc_a{0};
+
35 ulong loc_b{0};
+
36 ulong loc_c{0};
+
37 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+
38 int pos_in_dim = (elem % shape[i]);
+
39 elem /= shape[i];
+
40 loc_a += pos_in_dim * a_strides[i];
+
41 loc_b += pos_in_dim * b_strides[i];
+
42 loc_c += pos_in_dim * c_strides[i];
+
43 }
+
44 return ulong3(loc_a, loc_b, loc_c);
+
45}
+
+
METAL_FUNC ulong2 elem_to_loc_broadcast(uint elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim)
Definition utils.h:10
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2utils_8h.html b/docs/build/html/backend_2metal_2kernels_2utils_8h.html new file mode 100644 index 000000000..bce4b00aa --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2utils_8h.html @@ -0,0 +1,862 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/utils.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Classes | +Macros | +Functions
+
utils.h File Reference
+
+
+
#include <metal_math>
+#include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/complex.h"
+
+

Go to the source code of this file.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Classes

struct  Limits< U >
 
struct  Limits< uint8_t >
 
struct  Limits< uint16_t >
 
struct  Limits< uint32_t >
 
struct  Limits< uint64_t >
 
struct  Limits< int8_t >
 
struct  Limits< int16_t >
 
struct  Limits< int32_t >
 
struct  Limits< int64_t >
 
struct  Limits< half >
 
struct  Limits< float >
 
struct  Limits< bfloat16_t >
 
struct  Limits< bool >
 
+ + + + + + + +

+Macros

#define instantiate_default_limit(type)
 
#define instantiate_float_limit(type)
 
#define MLX_MTL_PRAGMA_UNROLL   _Pragma("clang loop unroll(full)")
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Functions

template<typename stride_t >
METAL_FUNC stride_t elem_to_loc (uint elem, device const int *shape, device const stride_t *strides, int ndim)
 
template<typename stride_t >
METAL_FUNC stride_t elem_to_loc (uint elem, constant const int *shape, constant const stride_t *strides, int ndim)
 
template<typename stride_t >
METAL_FUNC stride_t elem_to_loc (uint3 elem, constant const int *shape, constant const stride_t *strides, int ndim)
 
template<typename stride_t >
METAL_FUNC stride_t elem_to_loc_1 (uint elem, constant const stride_t &stride)
 
template<typename stride_t >
METAL_FUNC stride_t elem_to_loc_2 (uint2 elem, constant const stride_t strides[2])
 
template<typename stride_t >
METAL_FUNC stride_t elem_to_loc_3 (uint3 elem, constant const stride_t strides[3])
 
template<int NDIM>
METAL_FUNC size_t elem_to_loc_nd (uint elem, device const int *shape, device const size_t *strides)
 
template<int NDIM>
METAL_FUNC size_t elem_to_loc_nd (uint3 elem, constant const int shape[NDIM], constant const size_t strides[NDIM])
 
template<int NDIM>
METAL_FUNC int64_t elem_to_loc_nd (uint elem, constant const int shape[NDIM], constant const int64_t strides[NDIM])
 
template<int NDIM>
METAL_FUNC int64_t elem_to_loc_nd (uint3 elem, constant const int shape[NDIM], constant const int64_t strides[NDIM])
 
METAL_FUNC uint2 elem_to_loc_2_nd (uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim)
 
METAL_FUNC uint3 elem_to_loc_3_nd (uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, int ndim)
 
template<int NDIM>
METAL_FUNC uint2 elem_to_loc_2_nd (uint3 elem, constant const int shape[NDIM], constant const size_t a_strides[NDIM], constant const size_t b_strides[NDIM])
 
template<int NDIM>
METAL_FUNC uint3 elem_to_loc_3_nd (uint3 elem, constant const int shape[NDIM], constant const size_t a_strides[NDIM], constant const size_t b_strides[NDIM], constant const size_t c_strides[NDIM])
 
size_t ceildiv (size_t N, size_t M)
 Compute ceil((float)N/(float)M)
 
float log1p (float x)
 
bfloat16_t log1p (bfloat16_t x)
 
uint64_t simd_shuffle_down (uint64_t data, uint16_t delta)
 
int64_t simd_shuffle_down (int64_t data, uint16_t delta)
 
bool simd_shuffle_down (bool data, uint16_t delta)
 
+

Macro Definition Documentation

+ +

◆ instantiate_default_limit

+ +
+
+ + + + + + + +
#define instantiate_default_limit( type)
+
+Value:
template <> \
+
struct Limits<type> { \
+
static constexpr constant type max = metal::numeric_limits<type>::max(); \
+
static constexpr constant type min = metal::numeric_limits<type>::min(); \
+
static constexpr constant type finite_max = \
+
metal::numeric_limits<type>::max(); \
+
static constexpr constant type finite_min = \
+
metal::numeric_limits<type>::min(); \
+
};
+
Definition utils.h:14
+
static const constant U max
Definition utils.h:15
+
static const constant U finite_max
Definition utils.h:17
+
static const constant U min
Definition utils.h:16
+
static const constant U finite_min
Definition utils.h:18
+
+
+
+ +

◆ instantiate_float_limit

+ +
+
+ + + + + + + +
#define instantiate_float_limit( type)
+
+Value:
template <> \
+
struct Limits<type> { \
+
static constexpr constant type max = \
+
metal::numeric_limits<type>::infinity(); \
+
static constexpr constant type min = \
+
-metal::numeric_limits<type>::infinity(); \
+
static constexpr constant type finite_max = \
+
metal::numeric_limits<type>::max(); \
+
static constexpr constant type finite_min = \
+
-metal::numeric_limits<type>::max(); \
+
};
+
+
+
+ +

◆ MLX_MTL_PRAGMA_UNROLL

+ +
+
+ + + + +
#define MLX_MTL_PRAGMA_UNROLL   _Pragma("clang loop unroll(full)")
+
+ +
+
+

Function Documentation

+ +

◆ ceildiv()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
size_t ceildiv (size_t N,
size_t M )
+
+inline
+
+ +

Compute ceil((float)N/(float)M)

+ +
+
+ +

◆ elem_to_loc() [1/3]

+ +
+
+
+template<typename stride_t >
+ + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC stride_t elem_to_loc (uint elem,
constant const int * shape,
constant const stride_t * strides,
int ndim )
+
+ +
+
+ +

◆ elem_to_loc() [2/3]

+ +
+
+
+template<typename stride_t >
+ + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC stride_t elem_to_loc (uint elem,
device const int * shape,
device const stride_t * strides,
int ndim )
+
+ +
+
+ +

◆ elem_to_loc() [3/3]

+ +
+
+
+template<typename stride_t >
+ + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC stride_t elem_to_loc (uint3 elem,
constant const int * shape,
constant const stride_t * strides,
int ndim )
+
+ +
+
+ +

◆ elem_to_loc_1()

+ +
+
+
+template<typename stride_t >
+ + + + + + + + + + + +
METAL_FUNC stride_t elem_to_loc_1 (uint elem,
constant const stride_t & stride )
+
+ +
+
+ +

◆ elem_to_loc_2()

+ +
+
+
+template<typename stride_t >
+ + + + + + + + + + + +
METAL_FUNC stride_t elem_to_loc_2 (uint2 elem,
constant const stride_t strides[2] )
+
+ +
+
+ +

◆ elem_to_loc_2_nd() [1/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC uint2 elem_to_loc_2_nd (uint3 elem,
constant const int * shape,
constant const size_t * a_strides,
constant const size_t * b_strides,
int ndim )
+
+ +
+
+ +

◆ elem_to_loc_2_nd() [2/2]

+ +
+
+
+template<int NDIM>
+ + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC uint2 elem_to_loc_2_nd (uint3 elem,
constant const int shape[NDIM],
constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM] )
+
+ +
+
+ +

◆ elem_to_loc_3()

+ +
+
+
+template<typename stride_t >
+ + + + + + + + + + + +
METAL_FUNC stride_t elem_to_loc_3 (uint3 elem,
constant const stride_t strides[3] )
+
+ +
+
+ +

◆ elem_to_loc_3_nd() [1/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC uint3 elem_to_loc_3_nd (uint3 elem,
constant const int * shape,
constant const size_t * a_strides,
constant const size_t * b_strides,
constant const size_t * c_strides,
int ndim )
+
+ +
+
+ +

◆ elem_to_loc_3_nd() [2/2]

+ +
+
+
+template<int NDIM>
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC uint3 elem_to_loc_3_nd (uint3 elem,
constant const int shape[NDIM],
constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM],
constant const size_t c_strides[NDIM] )
+
+ +
+
+ +

◆ elem_to_loc_nd() [1/4]

+ +
+
+
+template<int NDIM>
+ + + + + + + + + + + + + + + + +
METAL_FUNC int64_t elem_to_loc_nd (uint elem,
constant const int shape[NDIM],
constant const int64_t strides[NDIM] )
+
+ +
+
+ +

◆ elem_to_loc_nd() [2/4]

+ +
+
+
+template<int NDIM>
+ + + + + + + + + + + + + + + + +
METAL_FUNC size_t elem_to_loc_nd (uint elem,
device const int * shape,
device const size_t * strides )
+
+ +
+
+ +

◆ elem_to_loc_nd() [3/4]

+ +
+
+
+template<int NDIM>
+ + + + + + + + + + + + + + + + +
METAL_FUNC int64_t elem_to_loc_nd (uint3 elem,
constant const int shape[NDIM],
constant const int64_t strides[NDIM] )
+
+ +
+
+ +

◆ elem_to_loc_nd() [4/4]

+ +
+
+
+template<int NDIM>
+ + + + + + + + + + + + + + + + +
METAL_FUNC size_t elem_to_loc_nd (uint3 elem,
constant const int shape[NDIM],
constant const size_t strides[NDIM] )
+
+ +
+
+ +

◆ log1p() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
bfloat16_t log1p (bfloat16_t x)
+
+inline
+
+ +
+
+ +

◆ log1p() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
float log1p (float x)
+
+inline
+
+ +
+
+ +

◆ simd_shuffle_down() [1/3]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
bool simd_shuffle_down (bool data,
uint16_t delta )
+
+inline
+
+ +
+
+ +

◆ simd_shuffle_down() [2/3]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
int64_t simd_shuffle_down (int64_t data,
uint16_t delta )
+
+inline
+
+ +
+
+ +

◆ simd_shuffle_down() [3/3]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
uint64_t simd_shuffle_down (uint64_t data,
uint16_t delta )
+
+inline
+
+ +
+
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2utils_8h_source.html b/docs/build/html/backend_2metal_2kernels_2utils_8h_source.html new file mode 100644 index 000000000..32aeb0750 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2utils_8h_source.html @@ -0,0 +1,488 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/utils.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
utils.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <metal_math>
+ + +
8
+
10// Type limits utils
+
12
+
13template <typename U>
+
+
14struct Limits {
+
15 static const constant U max = metal::numeric_limits<U>::max();
+
16 static const constant U min = metal::numeric_limits<U>::min();
+
17 static const constant U finite_max = metal::numeric_limits<U>::max();
+
18 static const constant U finite_min = metal::numeric_limits<U>::min();
+
19};
+
+
20
+
+
21#define instantiate_default_limit(type) \
+
22 template <> \
+
23 struct Limits<type> { \
+
24 static constexpr constant type max = metal::numeric_limits<type>::max(); \
+
25 static constexpr constant type min = metal::numeric_limits<type>::min(); \
+
26 static constexpr constant type finite_max = \
+
27 metal::numeric_limits<type>::max(); \
+
28 static constexpr constant type finite_min = \
+
29 metal::numeric_limits<type>::min(); \
+
30 };
+
+
31
+ + + + + + + + +
40
+
+
41#define instantiate_float_limit(type) \
+
42 template <> \
+
43 struct Limits<type> { \
+
44 static constexpr constant type max = \
+
45 metal::numeric_limits<type>::infinity(); \
+
46 static constexpr constant type min = \
+
47 -metal::numeric_limits<type>::infinity(); \
+
48 static constexpr constant type finite_max = \
+
49 metal::numeric_limits<type>::max(); \
+
50 static constexpr constant type finite_min = \
+
51 -metal::numeric_limits<type>::max(); \
+
52 };
+
+
53
+ + + +
57
+
58template <>
+
+
59struct Limits<bool> {
+
60 static constexpr constant bool max = true;
+
61 static constexpr constant bool min = false;
+
62};
+
+
63
+
65// Indexing utils
+
67
+
68#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
+
69
+
71// Single Array with generic dims
+
72
+
73template <typename stride_t>
+
+
74METAL_FUNC stride_t elem_to_loc(
+
75 uint elem,
+
76 device const int* shape,
+
77 device const stride_t* strides,
+
78 int ndim) {
+
79 stride_t loc = 0;
+
80 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+
81 loc += (elem % shape[i]) * strides[i];
+
82 elem /= shape[i];
+
83 }
+
84 return loc;
+
85}
+
+
86
+
87template <typename stride_t>
+
+
88METAL_FUNC stride_t elem_to_loc(
+
89 uint elem,
+
90 constant const int* shape,
+
91 constant const stride_t* strides,
+
92 int ndim) {
+
93 stride_t loc = 0;
+
94 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+
95 loc += (elem % shape[i]) * strides[i];
+
96 elem /= shape[i];
+
97 }
+
98 return loc;
+
99}
+
+
100
+
101// Non templated version to handle arbitrary dims
+
102template <typename stride_t>
+
+
103METAL_FUNC stride_t elem_to_loc(
+
104 uint3 elem,
+
105 constant const int* shape,
+
106 constant const stride_t* strides,
+
107 int ndim) {
+
108 stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
+
109 for (int d = ndim - 3; d >= 0; --d) {
+
110 loc += (elem.z % shape[d]) * strides[d];
+
111 elem.z /= shape[d];
+
112 }
+
113 return loc;
+
114}
+
+
115
+
117// Single Array with fixed N dims
+
118
+
119template <typename stride_t>
+
+
120METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
+
121 return elem * stride;
+
122}
+
+
123
+
124template <typename stride_t>
+
125METAL_FUNC stride_t
+
+
126elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
+
127 return elem.x * strides[1] + elem.y * strides[0];
+
128}
+
+
129
+
130template <typename stride_t>
+
131METAL_FUNC stride_t
+
+
132elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
+
133 return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
+
134}
+
+
135
+
136template <int NDIM>
+
+
137METAL_FUNC size_t elem_to_loc_nd(
+
138 uint elem,
+
139 device const int* shape,
+
140 device const size_t* strides) {
+
141 size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
+
142
+ +
144 for (int d = NDIM - 2; d >= 0; --d) {
+
145 elem /= shape[d + 1];
+
146 loc += (elem % shape[d]) * strides[d];
+
147 }
+
148
+
149 return loc;
+
150}
+
+
151
+
152template <int NDIM>
+
+
153METAL_FUNC size_t elem_to_loc_nd(
+
154 uint3 elem,
+
155 constant const int shape[NDIM],
+
156 constant const size_t strides[NDIM]) {
+
157 size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
+
158 for (int d = NDIM - 3; d >= 0; --d) {
+
159 loc += (elem.z % shape[d]) * strides[d];
+
160 elem.z /= shape[d];
+
161 }
+
162 return loc;
+
163}
+
+
164
+
165template <int NDIM>
+
+
166METAL_FUNC int64_t elem_to_loc_nd(
+
167 uint elem,
+
168 constant const int shape[NDIM],
+
169 constant const int64_t strides[NDIM]) {
+
170 int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
+
171
+ +
173 for (int d = NDIM - 2; d >= 0; --d) {
+
174 elem /= shape[d + 1];
+
175 loc += (elem % shape[d]) * strides[d];
+
176 }
+
177
+
178 return loc;
+
179}
+
+
180
+
181template <int NDIM>
+
+
182METAL_FUNC int64_t elem_to_loc_nd(
+
183 uint3 elem,
+
184 constant const int shape[NDIM],
+
185 constant const int64_t strides[NDIM]) {
+
186 int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
+
187 for (int d = NDIM - 3; d >= 0; --d) {
+
188 loc += (elem.z % shape[d]) * strides[d];
+
189 elem.z /= shape[d];
+
190 }
+
191 return loc;
+
192}
+
+
193
+
195// Multiple Arrays with generic dims
+
196
+
+
197METAL_FUNC uint2 elem_to_loc_2_nd(
+
198 uint3 elem,
+
199 constant const int* shape,
+
200 constant const size_t* a_strides,
+
201 constant const size_t* b_strides,
+
202 int ndim) {
+
203 uint2 loc = {
+
204 static_cast<uint>(
+
205 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
+
206 static_cast<uint>(
+
207 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
+
208 for (int d = ndim - 3; d >= 0; --d) {
+
209 uint l = elem.z % shape[d];
+
210 loc.x += l * a_strides[d];
+
211 loc.y += l * b_strides[d];
+
212 elem.z /= shape[d];
+
213 }
+
214 return loc;
+
215}
+
+
216
+
+
217METAL_FUNC uint3 elem_to_loc_3_nd(
+
218 uint3 elem,
+
219 constant const int* shape,
+
220 constant const size_t* a_strides,
+
221 constant const size_t* b_strides,
+
222 constant const size_t* c_strides,
+
223 int ndim) {
+
224 uint3 loc = {
+
225 static_cast<uint>(
+
226 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
+
227 static_cast<uint>(
+
228 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
+
229 static_cast<uint>(
+
230 elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
+
231 for (int d = ndim - 3; d >= 0; --d) {
+
232 uint l = elem.z % shape[d];
+
233 loc.x += l * a_strides[d];
+
234 loc.y += l * b_strides[d];
+
235 loc.z += l * c_strides[d];
+
236 elem.z /= shape[d];
+
237 }
+
238 return loc;
+
239}
+
+
240
+
242// Multiple Arrays with fixed N dims
+
243
+
244template <int NDIM>
+
+
245METAL_FUNC uint2 elem_to_loc_2_nd(
+
246 uint3 elem,
+
247 constant const int shape[NDIM],
+
248 constant const size_t a_strides[NDIM],
+
249 constant const size_t b_strides[NDIM]) {
+
250 uint2 loc = {
+
251 static_cast<uint>(
+
252 elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
+
253 static_cast<uint>(
+
254 elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
+
255 for (int d = NDIM - 3; d >= 0; --d) {
+
256 uint l = elem.z % shape[d];
+
257 loc.x += l * a_strides[d];
+
258 loc.y += l * b_strides[d];
+
259 elem.z /= shape[d];
+
260 }
+
261 return loc;
+
262}
+
+
263
+
264template <int NDIM>
+
+
265METAL_FUNC uint3 elem_to_loc_3_nd(
+
266 uint3 elem,
+
267 constant const int shape[NDIM],
+
268 constant const size_t a_strides[NDIM],
+
269 constant const size_t b_strides[NDIM],
+
270 constant const size_t c_strides[NDIM]) {
+
271 uint3 loc = {
+
272 static_cast<uint>(
+
273 elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
+
274 static_cast<uint>(
+
275 elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
+
276 static_cast<uint>(
+
277 elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
+
278 for (int d = NDIM - 3; d >= 0; --d) {
+
279 uint l = elem.z % shape[d];
+
280 loc.x += l * a_strides[d];
+
281 loc.y += l * b_strides[d];
+
282 loc.z += l * c_strides[d];
+
283 elem.z /= shape[d];
+
284 }
+
285 return loc;
+
286}
+
+
287
+
289// Calculation utils
+
291
+
+
293inline size_t ceildiv(size_t N, size_t M) {
+
294 return (N + M - 1) / M;
+
295}
+
+
296
+
297// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
+
+
298inline float log1p(float x) {
+
299 float xp1 = 1.0f + x;
+
300 if (xp1 == Limits<float>::max) {
+
301 return Limits<float>::max;
+
302 }
+
303 if (xp1 == 1.0f) {
+
304 return x;
+
305 }
+
306
+
307 return x * (metal::log(xp1) / (xp1 - 1.0f));
+
308}
+
+
309
+
+ +
311 float xp1 = 1.0f + static_cast<float>(x);
+
312 if (xp1 == Limits<float>::max) {
+ +
314 }
+
315 if (xp1 == 1.0f) {
+
316 return x;
+
317 }
+
318
+
319 return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
+
320}
+
+
321
+
323// SIMD shuffle ops
+
325
+
+
326inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
+
327 return as_type<uint64_t>(
+
328 metal::simd_shuffle_down(as_type<uint2>(data), delta));
+
329}
+
+
330
+
+
331inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
+
332 return as_type<int64_t>(
+
333 metal::simd_shuffle_down(as_type<uint2>(data), delta));
+
334}
+
+
335
+
+
336inline bool simd_shuffle_down(bool data, uint16_t delta) {
+
337 return simd_shuffle_down(static_cast<uint32_t>(data), delta);
+
338}
+
+ +
struct _MLX_BFloat16 bfloat16_t
Definition bf16.h:257
+ +
#define MLX_MTL_PRAGMA_UNROLL
Definition utils.h:68
+
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t &stride)
Definition utils.h:120
+
#define instantiate_float_limit(type)
Definition utils.h:41
+
float log1p(float x)
Definition utils.h:298
+
METAL_FUNC stride_t elem_to_loc_3(uint3 elem, constant const stride_t strides[3])
Definition utils.h:132
+
METAL_FUNC stride_t elem_to_loc(uint elem, device const int *shape, device const stride_t *strides, int ndim)
Definition utils.h:74
+
METAL_FUNC uint2 elem_to_loc_2_nd(uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim)
Definition utils.h:197
+
size_t ceildiv(size_t N, size_t M)
Compute ceil((float)N/(float)M)
Definition utils.h:293
+
METAL_FUNC uint3 elem_to_loc_3_nd(uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, int ndim)
Definition utils.h:217
+
METAL_FUNC size_t elem_to_loc_nd(uint elem, device const int *shape, device const size_t *strides)
Definition utils.h:137
+
#define instantiate_default_limit(type)
Definition utils.h:21
+
METAL_FUNC stride_t elem_to_loc_2(uint2 elem, constant const stride_t strides[2])
Definition utils.h:126
+
METAL_FUNC bfloat16_t log(bfloat16_t x)
Definition bf16_math.h:234
+
METAL_FUNC bfloat16_t simd_shuffle_down(bfloat16_t data, ushort delta)
Definition bf16_math.h:391
+
Definition bf16.h:54
+
Definition utils.h:14
+
static const constant U max
Definition utils.h:15
+
static const constant U finite_max
Definition utils.h:17
+
static const constant U min
Definition utils.h:16
+
static const constant U finite_min
Definition utils.h:18
+
+ + + + diff --git a/docs/build/html/backend_2metal_2utils_8h.html b/docs/build/html/backend_2metal_2utils_8h.html new file mode 100644 index 000000000..9a2da99d3 --- /dev/null +++ b/docs/build/html/backend_2metal_2utils_8h.html @@ -0,0 +1,102 @@ + + + + + + + +MLX: mlx/backend/metal/utils.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Namespaces
+
utils.h File Reference
+
+
+
#include "mlx/array.h"
+#include "mlx/backend/metal/device.h"
+#include "mlx/primitives.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+
+ + + + diff --git a/docs/build/html/backend_2metal_2utils_8h_source.html b/docs/build/html/backend_2metal_2utils_8h_source.html new file mode 100644 index 000000000..5bfd2bb38 --- /dev/null +++ b/docs/build/html/backend_2metal_2utils_8h_source.html @@ -0,0 +1,249 @@ + + + + + + + +MLX: mlx/backend/metal/utils.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
utils.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include "mlx/array.h"
+ +
7#include "mlx/primitives.h"
+
8
+
9namespace mlx::core {
+
10
+
11namespace {
+
12
+
13using metal::CommandEncoder;
+
14
+
15template <typename T>
+
16inline void set_vector_bytes(
+
17 CommandEncoder& enc,
+
18 const std::vector<T>& vec,
+
19 size_t nelems,
+
20 int idx) {
+
21 enc->setBytes(vec.data(), nelems * sizeof(T), idx);
+
22}
+
23
+
24template <typename T>
+
25inline void
+
26set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
+
27 return set_vector_bytes(enc, vec, vec.size(), idx);
+
28}
+
29
+
30std::string type_to_name(const array& a) {
+
31 std::string tname;
+
32 switch (a.dtype()) {
+
33 case bool_:
+
34 tname = "bool_";
+
35 break;
+
36 case uint8:
+
37 tname = "uint8";
+
38 break;
+
39 case uint16:
+
40 tname = "uint16";
+
41 break;
+
42 case uint32:
+
43 tname = "uint32";
+
44 break;
+
45 case uint64:
+
46 tname = "uint64";
+
47 break;
+
48 case int8:
+
49 tname = "int8";
+
50 break;
+
51 case int16:
+
52 tname = "int16";
+
53 break;
+
54 case int32:
+
55 tname = "int32";
+
56 break;
+
57 case int64:
+
58 tname = "int64";
+
59 break;
+
60 case float16:
+
61 tname = "float16";
+
62 break;
+
63 case float32:
+
64 tname = "float32";
+
65 break;
+
66 case bfloat16:
+
67 tname = "bfloat16";
+
68 break;
+
69 case complex64:
+
70 tname = "complex64";
+
71 break;
+
72 }
+
73 return tname;
+
74}
+
75
+
76MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
+
77 int pows[3] = {0, 0, 0};
+
78 int sum = 0;
+
79 while (true) {
+
80 int presum = sum;
+
81 // Check all the pows
+
82 if (dim0 >= (1 << (pows[0] + 1))) {
+
83 pows[0]++;
+
84 sum++;
+
85 }
+
86 if (sum == 10) {
+
87 break;
+
88 }
+
89 if (dim1 >= (1 << (pows[1] + 1))) {
+
90 pows[1]++;
+
91 sum++;
+
92 }
+
93 if (sum == 10) {
+
94 break;
+
95 }
+
96 if (dim2 >= (1 << (pows[2] + 1))) {
+
97 pows[2]++;
+
98 sum++;
+
99 }
+
100 if (sum == presum || sum == 10) {
+
101 break;
+
102 }
+
103 }
+
104 return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
+
105}
+
106
+
107inline NS::String* make_string(std::ostringstream& os) {
+
108 std::string string = os.str();
+
109 return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
+
110}
+
111
+
112inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
+
113#ifdef MLX_METAL_DEBUG
+
114 std::ostringstream label;
+
115 label << "Stream " << index;
+
116 queue->setLabel(make_string(label));
+
117#endif
+
118}
+
119
+
120inline void debug_set_primitive_buffer_label(
+
121 MTL::CommandBuffer* command_buffer,
+
122 Primitive& primitive) {
+
123#ifdef MLX_METAL_DEBUG
+
124 std::ostringstream label;
+
125 if (auto cbuf_label = command_buffer->label(); cbuf_label) {
+
126 label << cbuf_label->utf8String();
+
127 }
+
128 primitive.print(label);
+
129 command_buffer->setLabel(make_string(label));
+
130#endif
+
131}
+
132
+
133bool is_power_of_2(int n) {
+
134 return ((n & (n - 1)) == 0) && n != 0;
+
135}
+
136
+
137} // namespace
+
138
+
139} // namespace mlx::core
+ + +
array sum(const array &a, bool keepdims, StreamOrDevice s={})
Sums the elements of an array.
+
Definition allocator.h:7
+
constexpr Dtype bool_
Definition dtype.h:60
+
constexpr Dtype uint64
Definition dtype.h:65
+
constexpr Dtype uint16
Definition dtype.h:63
+
constexpr Dtype bfloat16
Definition dtype.h:74
+
constexpr Dtype int32
Definition dtype.h:69
+
constexpr Dtype float32
Definition dtype.h:73
+
constexpr Dtype int16
Definition dtype.h:68
+
constexpr Dtype int8
Definition dtype.h:67
+
constexpr Dtype int64
Definition dtype.h:70
+
constexpr Dtype uint8
Definition dtype.h:62
+
constexpr Dtype float16
Definition dtype.h:72
+
constexpr Dtype uint32
Definition dtype.h:64
+
constexpr Dtype complex64
Definition dtype.h:75
+ +
+ + + + diff --git a/docs/build/html/bc_s.png b/docs/build/html/bc_s.png new file mode 100644 index 000000000..224b29aa9 Binary files /dev/null and b/docs/build/html/bc_s.png differ diff --git a/docs/build/html/bc_sd.png b/docs/build/html/bc_sd.png new file mode 100644 index 000000000..31ca888dc Binary files /dev/null and b/docs/build/html/bc_sd.png differ diff --git a/docs/build/html/bf16__math_8h.html b/docs/build/html/bf16__math_8h.html new file mode 100644 index 000000000..78e4d59e8 --- /dev/null +++ b/docs/build/html/bf16__math_8h.html @@ -0,0 +1,594 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/bf16_math.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Namespaces | +Macros | +Functions
+
bf16_math.h File Reference
+
+
+
#include "mlx/backend/metal/kernels/bf16.h"
+
+

Go to the source code of this file.

+ + + + + + + + +

+Namespaces

namespace  metal
 
namespace  metal::fast
 
namespace  metal::precise
 
+ + + + + + + + + + + +

+Macros

#define instantiate_metal_math_funcs(itype, otype, ctype, mfast)
 
#define instantiate_metal_simd_comm_funcs( itype, otype, ctype, itype_to_ctype, ctype_to_otype)
 
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype)
 
#define bfloat16_to_uint16(x)   x.bits_
 
#define uint16_to_bfloat16(x)   _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Functions

METAL_FUNC bfloat16_t metal::abs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::acos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::acosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::asin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::asinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::atan (bfloat16_t y_over_x)
 
METAL_FUNC bfloat16_t metal::atan2 (bfloat16_t y, bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::atanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::ceil (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::cos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::cosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::cospi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::divide (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::exp (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::exp10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::exp2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fabs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fdim (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::floor (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fma (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fmax (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fmax3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fmedian3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fmin (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fmin3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fmod (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fract (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::frexp (bfloat16_t x, thread int &exp)
 
METAL_FUNC bfloat16_t metal::ldexp (bfloat16_t x, int k)
 
METAL_FUNC bfloat16_t metal::log (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::log10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::log2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::max (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::max3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::median3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::min (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::min3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::nextafter (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::pow (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::powr (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::rint (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::round (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::rsqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::sin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::sinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::sinpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::sqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::tan (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::tanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::tanpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::trunc (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::abs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::acos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::acosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::asin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::asinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::atan (bfloat16_t y_over_x)
 
METAL_FUNC bfloat16_t metal::fast::atan2 (bfloat16_t y, bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::atanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::ceil (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::cos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::cosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::cospi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::divide (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::exp (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::exp10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::exp2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::fabs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::fdim (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::floor (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::fma (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::fmax (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::fmax3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::fmedian3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::fmin (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::fmin3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::fmod (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::fract (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::frexp (bfloat16_t x, thread int &exp)
 
METAL_FUNC bfloat16_t metal::fast::ldexp (bfloat16_t x, int k)
 
METAL_FUNC bfloat16_t metal::fast::log (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::log10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::log2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::max (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::max3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::median3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::min (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::min3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::nextafter (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::pow (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::powr (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::rint (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::round (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::rsqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::sin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::sinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::sinpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::sqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::tan (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::tanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::tanpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::trunc (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::abs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::acos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::acosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::asin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::asinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::atan (bfloat16_t y_over_x)
 
METAL_FUNC bfloat16_t metal::precise::atan2 (bfloat16_t y, bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::atanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::ceil (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::cos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::cosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::cospi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::divide (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::exp (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::exp10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::exp2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::fabs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::fdim (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::floor (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::fma (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::fmax (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::fmax3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::fmedian3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::fmin (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::fmin3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::fmod (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::fract (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::frexp (bfloat16_t x, thread int &exp)
 
METAL_FUNC bfloat16_t metal::precise::ldexp (bfloat16_t x, int k)
 
METAL_FUNC bfloat16_t metal::precise::log (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::log10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::log2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::max (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::max3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::median3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::min (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::min3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::nextafter (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::pow (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::powr (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::rint (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::round (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::rsqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::sin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::sinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::sinpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::sqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::tan (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::tanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::tanpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::trunc (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::simd_broadcast (bfloat16_t data, ushort broadcast_lane_id)
 
METAL_FUNC bfloat16_t metal::simd_shuffle (bfloat16_t data, ushort simd_lane_id)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_down (bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_down (bfloat16_t data, bfloat16_t filling_data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_up (bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_up (bfloat16_t data, bfloat16_t filling_data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_down (bfloat16_t data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_rotate_down (bfloat16_t data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_rotate_up (bfloat16_t data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_up (bfloat16_t data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_xor (bfloat16_t data, ushort mask)
 
METAL_FUNC bfloat16_t metal::simd_max (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_min (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_prefix_exclusive_product (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_prefix_exclusive_sum (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_prefix_inclusive_product (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_prefix_inclusive_sum (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_product (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_sum (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_xor (bfloat16_t data)
 
+

Macro Definition Documentation

+ +

◆ bfloat16_to_uint16

+ +
+
+ + + + + + + +
#define bfloat16_to_uint16( x)   x.bits_
+
+ +
+
+ +

◆ instantiate_metal_math_funcs

+ +
+
+ + + + + + + + + + + + + + + + + + + + + +
#define instantiate_metal_math_funcs( itype,
otype,
ctype,
mfast )
+
+ +
+
+ +

◆ instantiate_metal_simd_comm_funcs

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
#define instantiate_metal_simd_comm_funcs( itype,
otype,
ctype,
itype_to_ctype,
ctype_to_otype )
+
+ +
+
+ +

◆ instantiate_metal_simd_reduction_funcs

+ +
+
+ + + + + + + + + + + + + + + + +
#define instantiate_metal_simd_reduction_funcs( itype,
otype,
ctype )
+
+ +
+
+ +

◆ uint16_to_bfloat16

+ +
+
+ + + + + + + +
#define uint16_to_bfloat16( x)   _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
+
+ +
+
+
+ + + + diff --git a/docs/build/html/bf16__math_8h_source.html b/docs/build/html/bf16__math_8h_source.html new file mode 100644 index 000000000..217403ab0 --- /dev/null +++ b/docs/build/html/bf16__math_8h_source.html @@ -0,0 +1,498 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/bf16_math.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
bf16_math.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+ +
6
+
8// Metal math for bfloat16
+
10
+
11/*
+
12
+
13Following the Metal Shading Language Specification (Metal 3.1)
+
14
+
15"bfloat is an extended itypeing point type that only allows implicit conversion
+
16 to a type of greater itypeing point rank. While bfloat can be implicitly
+
17 converted to itype, it cannot be implicitly converted to half, and neither
+
18 itype nor half can be implicitly converted to bfloat."
+
19
+
20Further, as far as I can tell, the stdlib math/simd functions are not defined
+
21for bfloat and calling with an argument of type bfloat will result in that
+
22argument getting implicitly converted to itype which then returns an output
+
23that is (likely) a itype which cannot be implicitly converted into a bfloat
+
24
+
25This leads to situations where
+
26bfloat a = 5.0bf;
+
27bfloat b = metal::abs(a); // this will throw an error since abs return itype
+
28bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine
+
29
+
30For the moment, I will be adding overloaded instantiations of the math
+
31functions to accordingly automatically handle the casting
+
32
+
33*/
+
34
+
+
35#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
+
36 \
+
37 METAL_FUNC otype abs(itype x) { \
+
38 return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
+
39 } \
+
40 METAL_FUNC otype acos(itype x) { \
+
41 return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
+
42 } \
+
43 METAL_FUNC otype acosh(itype x) { \
+
44 return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
+
45 } \
+
46 METAL_FUNC otype asin(itype x) { \
+
47 return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
+
48 } \
+
49 METAL_FUNC otype asinh(itype x) { \
+
50 return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
+
51 } \
+
52 METAL_FUNC otype atan(itype y_over_x) { \
+
53 return static_cast<otype>( \
+
54 __metal_atan(static_cast<ctype>(y_over_x), mfast)); \
+
55 } \
+
56 METAL_FUNC otype atan2(itype y, itype x) { \
+
57 return static_cast<otype>( \
+
58 __metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
+
59 } \
+
60 METAL_FUNC otype atanh(itype x) { \
+
61 return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
+
62 } \
+
63 METAL_FUNC otype ceil(itype x) { \
+
64 return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
+
65 } \
+
66 METAL_FUNC otype cos(itype x) { \
+
67 return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
+
68 } \
+
69 METAL_FUNC otype cosh(itype x) { \
+
70 return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
+
71 } \
+
72 METAL_FUNC otype cospi(itype x) { \
+
73 return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
+
74 } \
+
75 METAL_FUNC otype divide(itype x, itype y) { \
+
76 return static_cast<otype>( \
+
77 __metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
78 } \
+
79 METAL_FUNC otype exp(itype x) { \
+
80 return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
+
81 } \
+
82 METAL_FUNC otype exp10(itype x) { \
+
83 return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
+
84 } \
+
85 METAL_FUNC otype exp2(itype x) { \
+
86 return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
+
87 } \
+
88 METAL_FUNC otype fabs(itype x) { \
+
89 return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
+
90 } \
+
91 METAL_FUNC otype fdim(itype x, itype y) { \
+
92 ctype t = static_cast<ctype>(x - y); \
+
93 return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
+
94 } \
+
95 METAL_FUNC otype floor(itype x) { \
+
96 return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
+
97 } \
+
98 METAL_FUNC otype fma(itype x, itype y, itype z) { \
+
99 return static_cast<otype>(__metal_fma( \
+
100 static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
+
101 } \
+
102 METAL_FUNC otype fmax(itype x, itype y) { \
+
103 return static_cast<otype>( \
+
104 __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
105 } \
+
106 METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
+
107 return static_cast<otype>(__metal_fmax3( \
+
108 static_cast<ctype>(x), \
+
109 static_cast<ctype>(y), \
+
110 static_cast<ctype>(z), \
+
111 mfast)); \
+
112 } \
+
113 METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
+
114 return static_cast<otype>(__metal_fmedian3( \
+
115 static_cast<ctype>(x), \
+
116 static_cast<ctype>(y), \
+
117 static_cast<ctype>(z), \
+
118 mfast)); \
+
119 } \
+
120 METAL_FUNC otype fmin(itype x, itype y) { \
+
121 return static_cast<otype>( \
+
122 __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
123 } \
+
124 METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
+
125 return static_cast<otype>(__metal_fmin3( \
+
126 static_cast<ctype>(x), \
+
127 static_cast<ctype>(y), \
+
128 static_cast<ctype>(z), \
+
129 mfast)); \
+
130 } \
+
131 METAL_FUNC otype fmod(itype x, itype y) { \
+
132 return static_cast<otype>( \
+
133 __metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
134 } \
+
135 METAL_FUNC otype fract(itype x) { \
+
136 return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
+
137 } \
+
138 METAL_FUNC otype frexp(itype x, thread int& exp) { \
+
139 return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
+
140 } \
+
141 METAL_FUNC otype ldexp(itype x, int k) { \
+
142 return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
+
143 } \
+
144 METAL_FUNC otype log(itype x) { \
+
145 return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
+
146 } \
+
147 METAL_FUNC otype log10(itype x) { \
+
148 return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
+
149 } \
+
150 METAL_FUNC otype log2(itype x) { \
+
151 return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
+
152 } \
+
153 METAL_FUNC otype max(itype x, itype y) { \
+
154 return static_cast<otype>( \
+
155 __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
156 } \
+
157 METAL_FUNC otype max3(itype x, itype y, itype z) { \
+
158 return static_cast<otype>(__metal_fmax3( \
+
159 static_cast<ctype>(x), \
+
160 static_cast<ctype>(y), \
+
161 static_cast<ctype>(z), \
+
162 mfast)); \
+
163 } \
+
164 METAL_FUNC otype median3(itype x, itype y, itype z) { \
+
165 return static_cast<otype>(__metal_fmedian3( \
+
166 static_cast<ctype>(x), \
+
167 static_cast<ctype>(y), \
+
168 static_cast<ctype>(z), \
+
169 mfast)); \
+
170 } \
+
171 METAL_FUNC otype min(itype x, itype y) { \
+
172 return static_cast<otype>( \
+
173 __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
174 } \
+
175 METAL_FUNC otype min3(itype x, itype y, itype z) { \
+
176 return static_cast<otype>(__metal_fmin3( \
+
177 static_cast<ctype>(x), \
+
178 static_cast<ctype>(y), \
+
179 static_cast<ctype>(z), \
+
180 mfast)); \
+
181 } \
+
182 METAL_FUNC otype nextafter(itype x, itype y) { \
+
183 return static_cast<otype>( \
+
184 __metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
+
185 } \
+
186 METAL_FUNC otype pow(itype x, itype y) { \
+
187 return static_cast<otype>( \
+
188 __metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
189 } \
+
190 METAL_FUNC otype powr(itype x, itype y) { \
+
191 return static_cast<otype>( \
+
192 __metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
193 } \
+
194 METAL_FUNC otype rint(itype x) { \
+
195 return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
+
196 } \
+
197 METAL_FUNC otype round(itype x) { \
+
198 return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
+
199 } \
+
200 METAL_FUNC otype rsqrt(itype x) { \
+
201 return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
+
202 } \
+
203 METAL_FUNC otype sin(itype x) { \
+
204 return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
+
205 } \
+
206 METAL_FUNC otype sinh(itype x) { \
+
207 return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
+
208 } \
+
209 METAL_FUNC otype sinpi(itype x) { \
+
210 return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
+
211 } \
+
212 METAL_FUNC otype sqrt(itype x) { \
+
213 return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
+
214 } \
+
215 METAL_FUNC otype tan(itype x) { \
+
216 return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
+
217 } \
+
218 METAL_FUNC otype tanh(itype x) { \
+
219 return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
+
220 } \
+
221 METAL_FUNC otype tanpi(itype x) { \
+
222 return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
+
223 } \
+
224 METAL_FUNC otype trunc(itype x) { \
+
225 return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
+
226 }
+
+
227
+
228namespace metal {
+
229
+ + + +
233 float,
+
234 __METAL_MAYBE_FAST_MATH__);
+
235
+
+
236namespace fast {
+
237
+ + + +
241 float,
+
242 __METAL_FAST_MATH__);
+
243
+
244} // namespace fast
+
+
245
+
+
246namespace precise {
+
247
+ + + +
251 float,
+
252 __METAL_PRECISE_MATH__);
+
253
+
254} // namespace precise
+
+
255
+
256} // namespace metal
+
257
+
259// Metal simd for bfloat16
+
261
+
262#define instantiate_metal_simd_comm_funcs( \
+
263 itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
+
264 \
+
265 METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
+
266 return ctype_to_otype( \
+
267 __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
+
268 } \
+
269 \
+
270 METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
+
271 return ctype_to_otype( \
+
272 __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
+
273 } \
+
274 \
+
275 METAL_FUNC otype simd_shuffle_and_fill_down( \
+
276 itype data, itype filling_data, ushort delta, ushort modulo) { \
+
277 return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
+
278 itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
+
279 } \
+
280 \
+
281 METAL_FUNC otype simd_shuffle_and_fill_down( \
+
282 itype data, itype filling_data, ushort delta) { \
+
283 return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
+
284 itype_to_ctype(data), \
+
285 itype_to_ctype(filling_data), \
+
286 delta, \
+
287 __metal_get_simdgroup_size(ushort()))); \
+
288 } \
+
289 \
+
290 METAL_FUNC otype simd_shuffle_and_fill_up( \
+
291 itype data, itype filling_data, ushort delta, ushort modulo) { \
+
292 return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
+
293 itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
+
294 } \
+
295 \
+
296 METAL_FUNC otype simd_shuffle_and_fill_up( \
+
297 itype data, itype filling_data, ushort delta) { \
+
298 return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
+
299 itype_to_ctype(data), \
+
300 itype_to_ctype(filling_data), \
+
301 delta, \
+
302 __metal_get_simdgroup_size(ushort()))); \
+
303 } \
+
304 \
+
305 METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
+
306 return ctype_to_otype( \
+
307 __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
+
308 } \
+
309 \
+
310 METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
+
311 return ctype_to_otype( \
+
312 __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
+
313 } \
+
314 \
+
315 METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
+
316 return ctype_to_otype( \
+
317 __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
+
318 } \
+
319 \
+
320 METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
+
321 return ctype_to_otype( \
+
322 __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
+
323 } \
+
324 \
+
325 METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
+
326 return ctype_to_otype( \
+
327 __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
+
328 }
+
329
+
+
330#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
+
331 \
+
332 METAL_FUNC otype simd_max(itype data) { \
+
333 return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
+
334 } \
+
335 \
+
336 METAL_FUNC otype simd_min(itype data) { \
+
337 return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
+
338 } \
+
339 \
+
340 METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
+
341 return static_cast<otype>( \
+
342 __metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
+
343 } \
+
344 \
+
345 METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
+
346 return static_cast<otype>( \
+
347 __metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
+
348 } \
+
349 \
+
350 METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
+
351 return static_cast<otype>( \
+
352 __metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
+
353 } \
+
354 \
+
355 METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
+
356 return static_cast<otype>( \
+
357 __metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
+
358 } \
+
359 \
+
360 METAL_FUNC otype simd_product(itype data) { \
+
361 return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
+
362 } \
+
363 \
+
364 METAL_FUNC otype simd_sum(itype data) { \
+
365 return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
+
366 } \
+
367 \
+
368 METAL_FUNC otype simd_xor(itype data) { \
+
369 return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
+
370 }
+
+
371
+
372#if defined(__HAVE_BFLOAT__)
+
373
+
374#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
+
375#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
+
376
+
377#else
+
378
+
379#define bfloat16_to_uint16(x) x.bits_
+
380#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
+
381
+
382#endif
+
383
+
384namespace metal {
+
385
+ + + +
389 uint16_t,
+ + + +
393
+
394} // namespace metal
+ +
#define uint16_to_bfloat16(x)
Definition bf16_math.h:380
+
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype)
Definition bf16_math.h:330
+
#define bfloat16_to_uint16(x)
Definition bf16_math.h:379
+
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast)
Definition bf16_math.h:35
+
#define instantiate_metal_simd_comm_funcs( itype, otype, ctype, itype_to_ctype, ctype_to_otype)
Definition bf16_math.h:262
+
Definition bf16.h:265
+
Definition bf16.h:54
+
+ + + + diff --git a/docs/build/html/binary__two_8h.html b/docs/build/html/binary__two_8h.html new file mode 100644 index 000000000..9e8fa3dee --- /dev/null +++ b/docs/build/html/binary__two_8h.html @@ -0,0 +1,101 @@ + + + + + + + +MLX: mlx/backend/common/binary_two.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Namespaces
+
binary_two.h File Reference
+
+
+
#include "mlx/backend/common/binary.h"
+#include "mlx/backend/common/utils.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+
+ + + + diff --git a/docs/build/html/binary__two_8h_source.html b/docs/build/html/binary__two_8h_source.html new file mode 100644 index 000000000..f6d72e878 --- /dev/null +++ b/docs/build/html/binary__two_8h_source.html @@ -0,0 +1,646 @@ + + + + + + + +MLX: mlx/backend/common/binary_two.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
binary_two.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+ + +
7
+
8namespace mlx::core {
+
9
+
10namespace {
+
11
+
12template <typename T, typename U, typename Op>
+
13void binary_op_dims1(
+
14 const array& a,
+
15 const array& b,
+
16 array& out_a,
+
17 array& out_b,
+
18 Op op) {
+
19 const T* a_ptr = a.data<T>();
+
20 const T* b_ptr = b.data<T>();
+
21 U* dst_a = out_a.data<U>();
+
22 U* dst_b = out_b.data<U>();
+
23 size_t a_idx = 0;
+
24 size_t b_idx = 0;
+
25 for (size_t i = 0; i < out_a.size(); ++i) {
+
26 auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
+
27 dst_a[i] = dst.first;
+
28 dst_b[i] = dst.second;
+
29 a_idx += a.strides()[0];
+
30 b_idx += b.strides()[0];
+
31 }
+
32}
+
33
+
34template <typename T, typename U, typename Op>
+
35void binary_op_dims1(
+
36 const array& a,
+
37 const array& b,
+
38 array& out_a,
+
39 array& out_b,
+
40 Op op,
+
41 int stride) {
+
42 const T* a_ptr = a.data<T>();
+
43 const T* b_ptr = b.data<T>();
+
44 U* dst_a = out_a.data<U>();
+
45 U* dst_b = out_b.data<U>();
+
46 size_t a_idx = 0;
+
47 size_t b_idx = 0;
+
48 for (size_t i = 0; i < a.shape()[0]; i++) {
+
49 op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
+
50 a_idx += a.strides()[0];
+
51 b_idx += b.strides()[0];
+
52 dst_a += stride;
+
53 dst_b += stride;
+
54 }
+
55}
+
56
+
57template <typename T, typename U, typename Op>
+
58void binary_op_dims2(
+
59 const array& a,
+
60 const array& b,
+
61 array& out_a,
+
62 array& out_b,
+
63 Op op) {
+
64 const T* a_ptr = a.data<T>();
+
65 const T* b_ptr = b.data<T>();
+
66 U* dst_a = out_a.data<U>();
+
67 U* dst_b = out_b.data<U>();
+
68 size_t a_idx = 0;
+
69 size_t b_idx = 0;
+
70 size_t out_idx = 0;
+
71 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
72 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
73 auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
+
74 dst_a[out_idx] = dst.first;
+
75 dst_b[out_idx++] = dst.second;
+
76 a_idx += a.strides()[1];
+
77 b_idx += b.strides()[1];
+
78 }
+
79 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
80 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
81 }
+
82}
+
83
+
84template <typename T, typename U, typename Op>
+
85void binary_op_dims2(
+
86 const array& a,
+
87 const array& b,
+
88 array& out_a,
+
89 array& out_b,
+
90 Op op,
+
91 int stride) {
+
92 const T* a_ptr = a.data<T>();
+
93 const T* b_ptr = b.data<T>();
+
94 U* dst_a = out_a.data<U>();
+
95 U* dst_b = out_b.data<U>();
+
96 size_t a_idx = 0;
+
97 size_t b_idx = 0;
+
98 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
99 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
100 op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
+
101 a_idx += a.strides()[1];
+
102 b_idx += b.strides()[1];
+
103 dst_a += stride;
+
104 dst_b += stride;
+
105 }
+
106 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
107 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
108 }
+
109}
+
110
+
111template <typename T, typename U, typename Op>
+
112void binary_op_dims3(
+
113 const array& a,
+
114 const array& b,
+
115 array& out_a,
+
116 array& out_b,
+
117 Op op) {
+
118 const T* a_ptr = a.data<T>();
+
119 const T* b_ptr = b.data<T>();
+
120 U* dst_a = out_a.data<U>();
+
121 U* dst_b = out_b.data<U>();
+
122 size_t a_idx = 0;
+
123 size_t b_idx = 0;
+
124 size_t out_idx = 0;
+
125 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
126 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
127 for (size_t k = 0; k < a.shape()[2]; ++k) {
+
128 auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
+
129 dst_a[out_idx] = dst.first;
+
130 dst_b[out_idx++] = dst.second;
+
131 a_idx += a.strides()[2];
+
132 b_idx += b.strides()[2];
+
133 }
+
134 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+
135 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+
136 }
+
137 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
138 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
139 }
+
140}
+
141
+
142template <typename T, typename U, typename Op>
+
143void binary_op_dims4(
+
144 const array& a,
+
145 const array& b,
+
146 array& out_a,
+
147 array& out_b,
+
148 Op op) {
+
149 const T* a_ptr = a.data<T>();
+
150 const T* b_ptr = b.data<T>();
+
151 U* dst_a = out_a.data<U>();
+
152 U* dst_b = out_b.data<U>();
+
153 size_t a_idx = 0;
+
154 size_t b_idx = 0;
+
155 size_t out_idx = 0;
+
156 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
157 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
158 for (size_t k = 0; k < a.shape()[2]; ++k) {
+
159 for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
+
160 auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
+
161 dst_a[out_idx] = dst.first;
+
162 dst_b[out_idx++] = dst.second;
+
163 a_idx += a.strides()[3];
+
164 b_idx += b.strides()[3];
+
165 }
+
166 a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
+
167 b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
+
168 }
+
169 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+
170 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+
171 }
+
172 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
173 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
174 }
+
175}
+
176
+
177template <typename T, typename U, typename Op>
+
178void binary_op_dispatch_dims(
+
179 const array& a,
+
180 const array& b,
+
181 array& out_a,
+
182 array& out_b,
+
183 Op op) {
+
184 switch (out_a.ndim()) {
+
185 case 1:
+
186 binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op);
+
187 return;
+
188 case 2:
+
189 binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op);
+
190 return;
+
191 case 3:
+
192 binary_op_dims3<T, U, Op>(a, b, out_a, out_b, op);
+
193 return;
+
194 case 4:
+
195 binary_op_dims4<T, U, Op>(a, b, out_a, out_b, op);
+
196 return;
+
197 }
+
198
+
199 const T* a_ptr = a.data<T>();
+
200 const T* b_ptr = b.data<T>();
+
201 U* dst_a = out_a.data<U>();
+
202 U* dst_b = out_b.data<U>();
+
203 for (size_t i = 0; i < out_a.size(); i++) {
+
204 int a_idx = elem_to_loc(i, a.shape(), a.strides());
+
205 int b_idx = elem_to_loc(i, b.shape(), b.strides());
+
206 std::tie(dst_a[i], dst_b[i]) = op(a_ptr[a_idx], b_ptr[b_idx]);
+
207 }
+
208}
+
209
+
210template <typename T, typename U, typename Op>
+
211void binary_op_dispatch_dims(
+
212 const array& a,
+
213 const array& b,
+
214 array& out_a,
+
215 array& out_b,
+
216 Op op,
+
217 int dim,
+
218 int stride) {
+
219 // Number of dimensions to loop over for vectorized ops
+
220 switch (dim) {
+
221 case 1:
+
222 binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op, stride);
+
223 return;
+
224 case 2:
+
225 binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op, stride);
+
226 return;
+
227 }
+
228
+
229 const T* a_ptr = a.data<T>();
+
230 const T* b_ptr = b.data<T>();
+
231 U* dst_a = out_a.data<U>();
+
232 U* dst_b = out_b.data<U>();
+
233 for (size_t i = 0; i < out_a.size(); i += stride) {
+
234 int a_idx = elem_to_loc(i, a.shape(), a.strides());
+
235 int b_idx = elem_to_loc(i, b.shape(), b.strides());
+
236 op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
+
237 dst_a += stride;
+
238 dst_b += stride;
+
239 }
+
240}
+
241
+
242template <
+
243 typename T,
+
244 typename U,
+
245 typename Op,
+
246 typename OpSV,
+
247 typename OpVS,
+
248 typename OpVV>
+
249void binary_op(
+
250 const array& a,
+
251 const array& b,
+
252 array& out_a,
+
253 array& out_b,
+
254 Op op,
+
255 OpSV opsv,
+
256 OpVS opvs,
+
257 OpVV opvv) {
+
258 auto bopt = get_binary_op_type(a, b);
+
259 set_binary_op_output_data(a, b, out_a, bopt);
+
260 set_binary_op_output_data(a, b, out_b, bopt);
+
261
+
262 // The full computation is scalar scalar so call the base op once
+
263 if (bopt == BinaryOpType::ScalarScalar) {
+
264 std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
+
265 op(*a.data<T>(), *b.data<T>());
+
266 return;
+
267 }
+
268
+
269 // The full computation is scalar vector so delegate to the op
+
270 if (bopt == BinaryOpType::ScalarVector) {
+
271 opsv(
+
272 a.data<T>(),
+
273 b.data<T>(),
+
274 out_a.data<U>(),
+
275 out_b.data<U>(),
+
276 b.data_size());
+
277 return;
+
278 }
+
279
+
280 // The full computation is vector scalar so delegate to the op
+
281 if (bopt == BinaryOpType::VectorScalar) {
+
282 opvs(
+
283 a.data<T>(),
+
284 b.data<T>(),
+
285 out_a.data<U>(),
+
286 out_b.data<U>(),
+
287 a.data_size());
+
288 return;
+
289 }
+
290
+
291 // The full computation is vector vector so delegate to the op
+
292 if (bopt == BinaryOpType::VectorVector) {
+
293 opvv(
+
294 a.data<T>(),
+
295 b.data<T>(),
+
296 out_a.data<U>(),
+
297 out_b.data<U>(),
+
298 out_a.size());
+
299 return;
+
300 }
+
301
+
302 // General computation so let's try to optimize
+
303
+
304 // Get the left-most dim such that the array is row contiguous after
+
305 auto& strides = out_a.strides();
+
306 auto leftmost_rc_dim = [&strides](const array& arr) {
+
307 int d = arr.ndim() - 1;
+
308 for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
+
309 }
+
310 return d + 1;
+
311 };
+
312 auto a_rc_dim = leftmost_rc_dim(a);
+
313 auto b_rc_dim = leftmost_rc_dim(b);
+
314
+
315 // Get the left-most dim such that the array is a broadcasted "scalar" after
+
316 auto leftmost_s_dim = [](const array& arr) {
+
317 int d = arr.ndim() - 1;
+
318 for (; d >= 0 && arr.strides()[d] == 0; d--) {
+
319 }
+
320 return d + 1;
+
321 };
+
322 auto a_s_dim = leftmost_s_dim(a);
+
323 auto b_s_dim = leftmost_s_dim(b);
+
324
+
325 auto ndim = out_a.ndim();
+
326
+
327 // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
+
328 int dim = ndim;
+
329 if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
+
330 bopt = BinaryOpType::VectorVector;
+
331 dim = d;
+
332 // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
+
333 // contiguous
+
334 } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
+
335 bopt = BinaryOpType::VectorScalar;
+
336 dim = d;
+
337 // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
+
338 // contiguous
+
339 } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
+
340 bopt = BinaryOpType::ScalarVector;
+
341 dim = d;
+
342 }
+
343
+
344 // Can be sure dim > 0 since otherwise we would have used one of the fully
+
345 // contiguous methods above. Except for the case that the flags do not
+
346 // correspond to the underlying contiguity.
+
347 size_t stride;
+
348 if (dim == 0 || strides[dim - 1] < 16) {
+
349 stride = 1;
+
350 bopt = BinaryOpType::General;
+
351 dim = ndim;
+
352 } else {
+
353 stride = strides[dim - 1];
+
354 }
+
355
+
356 switch (bopt) {
+
357 case BinaryOpType::VectorVector:
+
358 binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
+
359 break;
+
360 case BinaryOpType::VectorScalar:
+
361 binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
+
362 break;
+
363 case BinaryOpType::ScalarVector:
+
364 binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
+
365 break;
+
366 default:
+
367 binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, op);
+
368 break;
+
369 }
+
370}
+
371
+
372template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
+
373void binary_op(
+
374 const array& a,
+
375 const array& b,
+
376 std::vector<array>& outputs,
+
377 Op op,
+
378 OpSV opsv,
+
379 OpVS opvs,
+
380 OpVV opvv) {
+
381 // TODO: The following mess of constexpr evaluations can probably be achieved
+
382 // with template specializations and overloading. Would it be simpler?
+
383
+
384 if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
+
385 if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
+
386 if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+
387 // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
+
388 binary_op<T, T>(
+
389 a,
+
390 b,
+
391 outputs[0],
+
392 outputs[1],
+
393 op,
+
394 DefaultScalarVector<T, T, Op>(op),
+
395 DefaultVectorScalar<T, T, Op>(op),
+
396 DefaultVectorVector<T, T, Op>(op));
+
397 } else {
+
398 // opsv and opvs were UseDefaultBinaryOp
+
399 binary_op<T, T>(
+
400 a,
+
401 b,
+
402 outputs[0],
+
403 outputs[1],
+
404 op,
+
405 DefaultScalarVector<T, T, Op>(op),
+
406 DefaultVectorScalar<T, T, Op>(op),
+
407 opvv);
+
408 }
+
409 } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+
410 // opsv and opvv were UseDefaultBinaryOp
+
411 binary_op<T, T>(
+
412 a,
+
413 b,
+
414 outputs[0],
+
415 outputs[1],
+
416 op,
+
417 DefaultScalarVector<T, T, Op>(op),
+
418 opvs,
+
419 DefaultVectorVector<T, T, Op>(op));
+
420 } else {
+
421 // opsv was UseDefaultBinaryOp
+
422 binary_op<T, T>(
+
423 a,
+
424 b,
+
425 outputs[0],
+
426 outputs[1],
+
427 op,
+
428 DefaultScalarVector<T, T, Op>(op),
+
429 opvs,
+
430 opvv);
+
431 }
+
432 } else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
+
433 if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+
434 // opvs and opvv were UseDefaultBinaryOp
+
435 binary_op<T, T>(
+
436 a,
+
437 b,
+
438 outputs[0],
+
439 outputs[1],
+
440 op,
+
441 opsv,
+
442 DefaultVectorScalar<T, T, Op>(op),
+
443 DefaultVectorVector<T, T, Op>(op));
+
444 } else {
+
445 // opvs was UseDefaultBinaryOp
+
446 binary_op<T, T>(
+
447 a,
+
448 b,
+
449 outputs[0],
+
450 outputs[1],
+
451 op,
+
452 opsv,
+
453 DefaultVectorScalar<T, T, Op>(op),
+
454 opvv);
+
455 }
+
456 } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+
457 // opvv was UseDefaultBinaryOp
+
458 binary_op<T, T>(
+
459 a,
+
460 b,
+
461 outputs[0],
+
462 outputs[1],
+
463 op,
+
464 opsv,
+
465 opvs,
+
466 DefaultVectorVector<T, T, Op>(op));
+
467 } else {
+
468 // All ops provided
+
469 binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
+
470 }
+
471}
+
472
+
473template <typename T, typename Op>
+
474void binary_op(
+
475 const array& a,
+
476 const array& b,
+
477 std::vector<array>& outputs,
+
478 Op op) {
+
479 DefaultScalarVector<T, T, Op> opsv(op);
+
480 DefaultVectorScalar<T, T, Op> opvs(op);
+
481 DefaultVectorVector<T, T, Op> opvv(op);
+
482 binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
+
483}
+
484
+
485template <typename... Ops>
+
486void binary(
+
487 const array& a,
+
488 const array& b,
+
489 std::vector<array>& outputs,
+
490 Ops... ops) {
+
491 switch (outputs[0].dtype()) {
+
492 case bool_:
+
493 binary_op<bool>(a, b, outputs, ops...);
+
494 break;
+
495 case uint8:
+
496 binary_op<uint8_t>(a, b, outputs, ops...);
+
497 break;
+
498 case uint16:
+
499 binary_op<uint16_t>(a, b, outputs, ops...);
+
500 break;
+
501 case uint32:
+
502 binary_op<uint32_t>(a, b, outputs, ops...);
+
503 break;
+
504 case uint64:
+
505 binary_op<uint64_t>(a, b, outputs, ops...);
+
506 break;
+
507 case int8:
+
508 binary_op<int8_t>(a, b, outputs, ops...);
+
509 break;
+
510 case int16:
+
511 binary_op<int16_t>(a, b, outputs, ops...);
+
512 break;
+
513 case int32:
+
514 binary_op<int32_t>(a, b, outputs, ops...);
+
515 break;
+
516 case int64:
+
517 binary_op<int64_t>(a, b, outputs, ops...);
+
518 break;
+
519 case float16:
+
520 binary_op<float16_t>(a, b, outputs, ops...);
+
521 break;
+
522 case float32:
+
523 binary_op<float>(a, b, outputs, ops...);
+
524 break;
+
525 case bfloat16:
+
526 binary_op<bfloat16_t>(a, b, outputs, ops...);
+
527 break;
+
528 case complex64:
+
529 binary_op<complex64_t>(a, b, outputs, ops...);
+
530 break;
+
531 }
+
532}
+
533
+
534} // namespace
+
535
+
536} // namespace mlx::core
+ + +
Op op
Definition binary.h:139
+
Definition allocator.h:7
+
constexpr Dtype bool_
Definition dtype.h:60
+
constexpr Dtype uint64
Definition dtype.h:65
+
constexpr Dtype uint16
Definition dtype.h:63
+
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
+
constexpr Dtype bfloat16
Definition dtype.h:74
+
constexpr Dtype int32
Definition dtype.h:69
+
constexpr Dtype float32
Definition dtype.h:73
+
constexpr Dtype int16
Definition dtype.h:68
+
constexpr Dtype int8
Definition dtype.h:67
+
constexpr Dtype int64
Definition dtype.h:70
+
constexpr Dtype uint8
Definition dtype.h:62
+
constexpr Dtype float16
Definition dtype.h:72
+
constexpr Dtype uint32
Definition dtype.h:64
+
constexpr Dtype complex64
Definition dtype.h:75
+
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_kernel-members.html b/docs/build/html/class_m_p_s_1_1_kernel-members.html new file mode 100644 index 000000000..f12b0a050 --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_kernel-members.html @@ -0,0 +1,92 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
MPS::Kernel Member List
+
+
+ +

This is the complete list of members for MPS::Kernel, including all inherited members.

+ + + +
device() constMPS::Kernel
label() constMPS::Kernel
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_kernel.html b/docs/build/html/class_m_p_s_1_1_kernel.html new file mode 100644 index 000000000..de2a0c584 --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_kernel.html @@ -0,0 +1,144 @@ + + + + + + + +MLX: MPS::Kernel Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
MPS::Kernel Class Reference
+
+
+ +

#include <gemm.h>

+
+Inheritance diagram for MPS::Kernel:
+
+
+ +
+ + + + + + +

+Public Member Functions

NS::String * label () const
 
MTL::Device * device () const
 
+

Member Function Documentation

+ +

◆ device()

+ +
+
+ + + + + + + +
_MTL_INLINE MTL::Device * MPS::Kernel::device () const
+
+ +
+
+ +

◆ label()

+ +
+
+ + + + + + + +
_MTL_INLINE NS::String * MPS::Kernel::label () const
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_kernel.png b/docs/build/html/class_m_p_s_1_1_kernel.png new file mode 100644 index 000000000..0ed62c014 Binary files /dev/null and b/docs/build/html/class_m_p_s_1_1_kernel.png differ diff --git a/docs/build/html/class_m_p_s_1_1_matrix-members.html b/docs/build/html/class_m_p_s_1_1_matrix-members.html new file mode 100644 index 000000000..48ab345d3 --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_matrix-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
MPS::Matrix Member List
+
+
+ +

This is the complete list of members for MPS::Matrix, including all inherited members.

+ + + + +
alloc()MPS::Matrixstatic
init(MTL::Buffer *buffer, MatrixDescriptor *descriptor)MPS::Matrix
init(const MTL::Buffer *buffer, MatrixDescriptor *descriptor)MPS::Matrix
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_matrix.html b/docs/build/html/class_m_p_s_1_1_matrix.html new file mode 100644 index 000000000..8fc4607bd --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_matrix.html @@ -0,0 +1,183 @@ + + + + + + + +MLX: MPS::Matrix Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +Static Public Member Functions | +List of all members
+
MPS::Matrix Class Reference
+
+
+ +

#include <gemm.h>

+
+Inheritance diagram for MPS::Matrix:
+
+
+ +
+ + + + + + +

+Public Member Functions

Matrixinit (MTL::Buffer *buffer, MatrixDescriptor *descriptor)
 
Matrixinit (const MTL::Buffer *buffer, MatrixDescriptor *descriptor)
 
+ + + +

+Static Public Member Functions

static class Matrixalloc ()
 
+

Member Function Documentation

+ +

◆ alloc()

+ +
+
+ + + + + +
+ + + + + + + +
_MTL_INLINE Matrix * MPS::Matrix::alloc ()
+
+static
+
+ +
+
+ +

◆ init() [1/2]

+ +
+
+ + + + + + + + + + + +
_MTL_INLINE Matrix * MPS::Matrix::init (const MTL::Buffer * buffer,
MatrixDescriptor * descriptor )
+
+ +
+
+ +

◆ init() [2/2]

+ +
+
+ + + + + + + + + + + +
_MTL_INLINE Matrix * MPS::Matrix::init (MTL::Buffer * buffer,
MatrixDescriptor * descriptor )
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_matrix.png b/docs/build/html/class_m_p_s_1_1_matrix.png new file mode 100644 index 000000000..a14d0bf98 Binary files /dev/null and b/docs/build/html/class_m_p_s_1_1_matrix.png differ diff --git a/docs/build/html/class_m_p_s_1_1_matrix_descriptor-members.html b/docs/build/html/class_m_p_s_1_1_matrix_descriptor-members.html new file mode 100644 index 000000000..dee3102ef --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_matrix_descriptor-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
MPS::MatrixDescriptor Member List
+
+
+ +

This is the complete list of members for MPS::MatrixDescriptor, including all inherited members.

+ + + + +
matrixDescriptor(NS::UInteger rows, NS::UInteger columns, NS::UInteger rowBytes, NS::UInteger dataType)MPS::MatrixDescriptorstatic
matrixDescriptor(NS::UInteger rows, NS::UInteger columns, NS::UInteger matrices, NS::UInteger rowBytes, NS::UInteger matrixBytes, NS::UInteger dataType)MPS::MatrixDescriptorstatic
rows() constMPS::MatrixDescriptor
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_matrix_descriptor.html b/docs/build/html/class_m_p_s_1_1_matrix_descriptor.html new file mode 100644 index 000000000..1b88da36a --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_matrix_descriptor.html @@ -0,0 +1,221 @@ + + + + + + + +MLX: MPS::MatrixDescriptor Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +Static Public Member Functions | +List of all members
+
MPS::MatrixDescriptor Class Reference
+
+
+ +

#include <gemm.h>

+
+Inheritance diagram for MPS::MatrixDescriptor:
+
+
+ +
+ + + + +

+Public Member Functions

NS::UInteger rows () const
 
+ + + + + +

+Static Public Member Functions

static class MatrixDescriptormatrixDescriptor (NS::UInteger rows, NS::UInteger columns, NS::UInteger rowBytes, NS::UInteger dataType)
 
static class MatrixDescriptormatrixDescriptor (NS::UInteger rows, NS::UInteger columns, NS::UInteger matrices, NS::UInteger rowBytes, NS::UInteger matrixBytes, NS::UInteger dataType)
 
+

Member Function Documentation

+ +

◆ matrixDescriptor() [1/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
_MTL_INLINE MatrixDescriptor * MPS::MatrixDescriptor::matrixDescriptor (NS::UInteger rows,
NS::UInteger columns,
NS::UInteger matrices,
NS::UInteger rowBytes,
NS::UInteger matrixBytes,
NS::UInteger dataType )
+
+static
+
+ +
+
+ +

◆ matrixDescriptor() [2/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
_MTL_INLINE MatrixDescriptor * MPS::MatrixDescriptor::matrixDescriptor (NS::UInteger rows,
NS::UInteger columns,
NS::UInteger rowBytes,
NS::UInteger dataType )
+
+static
+
+ +
+
+ +

◆ rows()

+ +
+
+ + + + + + + +
_MTL_INLINE NS::UInteger MPS::MatrixDescriptor::rows () const
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_matrix_descriptor.png b/docs/build/html/class_m_p_s_1_1_matrix_descriptor.png new file mode 100644 index 000000000..661b7f470 Binary files /dev/null and b/docs/build/html/class_m_p_s_1_1_matrix_descriptor.png differ diff --git a/docs/build/html/class_m_p_s_1_1_matrix_multiplication-members.html b/docs/build/html/class_m_p_s_1_1_matrix_multiplication-members.html new file mode 100644 index 000000000..37ae4b084 --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_matrix_multiplication-members.html @@ -0,0 +1,98 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
MPS::MatrixMultiplication Member List
+
+
+ +

This is the complete list of members for MPS::MatrixMultiplication, including all inherited members.

+ + + + + + + + + +
alloc()MPS::MatrixMultiplicationstatic
encodeToCommandBuffer(MTL::CommandBuffer *commandBuffer, Matrix *leftMatrix, Matrix *rightMatrix, Matrix *resultMatrix)MPS::MatrixMultiplication
init(MTL::Device *device, bool transposeLeft, bool transposeRight, NS::UInteger resultRows, NS::UInteger resultColumns, NS::UInteger interiorColumns, double alpha, double beta)MPS::MatrixMultiplication
setBatchSize(NS::UInteger batchSize)MPS::MatrixMultiplication
setBatchStart(NS::UInteger batchStart)MPS::MatrixMultiplication
setLeftMatrixOrigin(MTL::Origin origin)MPS::MatrixMultiplication
setResultMatrixOrigin(MTL::Origin origin)MPS::MatrixMultiplication
setRightMatrixOrigin(MTL::Origin origin)MPS::MatrixMultiplication
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_matrix_multiplication.html b/docs/build/html/class_m_p_s_1_1_matrix_multiplication.html new file mode 100644 index 000000000..f751a2bd4 --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_matrix_multiplication.html @@ -0,0 +1,318 @@ + + + + + + + +MLX: MPS::MatrixMultiplication Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +Static Public Member Functions | +List of all members
+
MPS::MatrixMultiplication Class Reference
+
+
+ +

#include <gemm.h>

+
+Inheritance diagram for MPS::MatrixMultiplication:
+
+
+ +
+ + + + + + + + + + + + + + + + +

+Public Member Functions

MatrixMultiplicationinit (MTL::Device *device, bool transposeLeft, bool transposeRight, NS::UInteger resultRows, NS::UInteger resultColumns, NS::UInteger interiorColumns, double alpha, double beta)
 
void encodeToCommandBuffer (MTL::CommandBuffer *commandBuffer, Matrix *leftMatrix, Matrix *rightMatrix, Matrix *resultMatrix)
 
void setLeftMatrixOrigin (MTL::Origin origin)
 
void setRightMatrixOrigin (MTL::Origin origin)
 
void setResultMatrixOrigin (MTL::Origin origin)
 
void setBatchStart (NS::UInteger batchStart)
 
void setBatchSize (NS::UInteger batchSize)
 
+ + + +

+Static Public Member Functions

static class MatrixMultiplicationalloc ()
 
+

Member Function Documentation

+ +

◆ alloc()

+ +
+
+ + + + + +
+ + + + + + + +
_MTL_INLINE MatrixMultiplication * MPS::MatrixMultiplication::alloc ()
+
+static
+
+ +
+
+ +

◆ encodeToCommandBuffer()

+ +
+
+ + + + + + + + + + + + + + + + + + + + + +
_MTL_INLINE void MPS::MatrixMultiplication::encodeToCommandBuffer (MTL::CommandBuffer * commandBuffer,
Matrix * leftMatrix,
Matrix * rightMatrix,
Matrix * resultMatrix )
+
+ +
+
+ +

◆ init()

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
_MTL_INLINE MatrixMultiplication * MPS::MatrixMultiplication::init (MTL::Device * device,
bool transposeLeft,
bool transposeRight,
NS::UInteger resultRows,
NS::UInteger resultColumns,
NS::UInteger interiorColumns,
double alpha,
double beta )
+
+ +
+
+ +

◆ setBatchSize()

+ +
+
+ + + + + + + +
_MTL_INLINE void MPS::MatrixMultiplication::setBatchSize (NS::UInteger batchSize)
+
+ +
+
+ +

◆ setBatchStart()

+ +
+
+ + + + + + + +
_MTL_INLINE void MPS::MatrixMultiplication::setBatchStart (NS::UInteger batchStart)
+
+ +
+
+ +

◆ setLeftMatrixOrigin()

+ +
+
+ + + + + + + +
_MTL_INLINE void MPS::MatrixMultiplication::setLeftMatrixOrigin (MTL::Origin origin)
+
+ +
+
+ +

◆ setResultMatrixOrigin()

+ +
+
+ + + + + + + +
_MTL_INLINE void MPS::MatrixMultiplication::setResultMatrixOrigin (MTL::Origin origin)
+
+ +
+
+ +

◆ setRightMatrixOrigin()

+ +
+
+ + + + + + + +
_MTL_INLINE void MPS::MatrixMultiplication::setRightMatrixOrigin (MTL::Origin origin)
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_matrix_multiplication.png b/docs/build/html/class_m_p_s_1_1_matrix_multiplication.png new file mode 100644 index 000000000..7da3f69a9 Binary files /dev/null and b/docs/build/html/class_m_p_s_1_1_matrix_multiplication.png differ diff --git a/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication-members.html b/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication-members.html new file mode 100644 index 000000000..abb569f03 --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
MPS::MatrixVectorMultiplication Member List
+
+
+ +

This is the complete list of members for MPS::MatrixVectorMultiplication, including all inherited members.

+ + + + +
alloc()MPS::MatrixVectorMultiplicationstatic
encodeToCommandBuffer(MTL::CommandBuffer *commandBuffer, Matrix *inputMatrix, Vector *inputVector, Vector *resultVector)MPS::MatrixVectorMultiplication
init(MTL::Device *device, bool transpose, NS::UInteger rows, NS::UInteger columns, double alpha, double beta)MPS::MatrixVectorMultiplication
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication.html b/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication.html new file mode 100644 index 000000000..074f379d7 --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication.html @@ -0,0 +1,213 @@ + + + + + + + +MLX: MPS::MatrixVectorMultiplication Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +Static Public Member Functions | +List of all members
+
MPS::MatrixVectorMultiplication Class Reference
+
+
+ +

#include <gemm.h>

+
+Inheritance diagram for MPS::MatrixVectorMultiplication:
+
+
+ +
+ + + + + + +

+Public Member Functions

MatrixVectorMultiplicationinit (MTL::Device *device, bool transpose, NS::UInteger rows, NS::UInteger columns, double alpha, double beta)
 
void encodeToCommandBuffer (MTL::CommandBuffer *commandBuffer, Matrix *inputMatrix, Vector *inputVector, Vector *resultVector)
 
+ + + +

+Static Public Member Functions

static class MatrixVectorMultiplicationalloc ()
 
+

Member Function Documentation

+ +

◆ alloc()

+ +
+
+ + + + + +
+ + + + + + + +
_MTL_INLINE MatrixVectorMultiplication * MPS::MatrixVectorMultiplication::alloc ()
+
+static
+
+ +
+
+ +

◆ encodeToCommandBuffer()

+ +
+
+ + + + + + + + + + + + + + + + + + + + + +
_MTL_INLINE void MPS::MatrixVectorMultiplication::encodeToCommandBuffer (MTL::CommandBuffer * commandBuffer,
Matrix * inputMatrix,
Vector * inputVector,
Vector * resultVector )
+
+ +
+
+ +

◆ init()

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
_MTL_INLINE MatrixVectorMultiplication * MPS::MatrixVectorMultiplication::init (MTL::Device * device,
bool transpose,
NS::UInteger rows,
NS::UInteger columns,
double alpha,
double beta )
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication.png b/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication.png new file mode 100644 index 000000000..10e54c8bd Binary files /dev/null and b/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication.png differ diff --git a/docs/build/html/class_m_p_s_1_1_vector-members.html b/docs/build/html/class_m_p_s_1_1_vector-members.html new file mode 100644 index 000000000..31e25f93d --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_vector-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
MPS::Vector Member List
+
+
+ +

This is the complete list of members for MPS::Vector, including all inherited members.

+ + + + +
alloc()MPS::Vectorstatic
init(MTL::Buffer *buffer, VectorDescriptor *descriptor)MPS::Vector
init(const MTL::Buffer *buffer, VectorDescriptor *descriptor)MPS::Vector
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_vector.html b/docs/build/html/class_m_p_s_1_1_vector.html new file mode 100644 index 000000000..9e1039ead --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_vector.html @@ -0,0 +1,183 @@ + + + + + + + +MLX: MPS::Vector Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +Static Public Member Functions | +List of all members
+
MPS::Vector Class Reference
+
+
+ +

#include <gemm.h>

+
+Inheritance diagram for MPS::Vector:
+
+
+ +
+ + + + + + +

+Public Member Functions

Vectorinit (MTL::Buffer *buffer, VectorDescriptor *descriptor)
 
Vectorinit (const MTL::Buffer *buffer, VectorDescriptor *descriptor)
 
+ + + +

+Static Public Member Functions

static class Vectoralloc ()
 
+

Member Function Documentation

+ +

◆ alloc()

+ +
+
+ + + + + +
+ + + + + + + +
_MTL_INLINE Vector * MPS::Vector::alloc ()
+
+static
+
+ +
+
+ +

◆ init() [1/2]

+ +
+
+ + + + + + + + + + + +
_MTL_INLINE Vector * MPS::Vector::init (const MTL::Buffer * buffer,
VectorDescriptor * descriptor )
+
+ +
+
+ +

◆ init() [2/2]

+ +
+
+ + + + + + + + + + + +
_MTL_INLINE Vector * MPS::Vector::init (MTL::Buffer * buffer,
VectorDescriptor * descriptor )
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_vector.png b/docs/build/html/class_m_p_s_1_1_vector.png new file mode 100644 index 000000000..e7adea362 Binary files /dev/null and b/docs/build/html/class_m_p_s_1_1_vector.png differ diff --git a/docs/build/html/class_m_p_s_1_1_vector_descriptor-members.html b/docs/build/html/class_m_p_s_1_1_vector_descriptor-members.html new file mode 100644 index 000000000..607989e29 --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_vector_descriptor-members.html @@ -0,0 +1,92 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
MPS::VectorDescriptor Member List
+
+
+ +

This is the complete list of members for MPS::VectorDescriptor, including all inherited members.

+ + + +
vectorDescriptor(NS::UInteger length, NS::UInteger dataType)MPS::VectorDescriptorstatic
vectorDescriptor(NS::UInteger length, NS::UInteger vectors, NS::UInteger vectorBytes, NS::UInteger dataType)MPS::VectorDescriptorstatic
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_vector_descriptor.html b/docs/build/html/class_m_p_s_1_1_vector_descriptor.html new file mode 100644 index 000000000..681f3428e --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_vector_descriptor.html @@ -0,0 +1,178 @@ + + + + + + + +MLX: MPS::VectorDescriptor Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Static Public Member Functions | +List of all members
+
MPS::VectorDescriptor Class Reference
+
+
+ +

#include <gemm.h>

+
+Inheritance diagram for MPS::VectorDescriptor:
+
+
+ +
+ + + + + + +

+Static Public Member Functions

static class VectorDescriptorvectorDescriptor (NS::UInteger length, NS::UInteger dataType)
 
static class VectorDescriptorvectorDescriptor (NS::UInteger length, NS::UInteger vectors, NS::UInteger vectorBytes, NS::UInteger dataType)
 
+

Member Function Documentation

+ +

◆ vectorDescriptor() [1/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
_MTL_INLINE VectorDescriptor * MPS::VectorDescriptor::vectorDescriptor (NS::UInteger length,
NS::UInteger dataType )
+
+static
+
+ +
+
+ +

◆ vectorDescriptor() [2/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
_MTL_INLINE VectorDescriptor * MPS::VectorDescriptor::vectorDescriptor (NS::UInteger length,
NS::UInteger vectors,
NS::UInteger vectorBytes,
NS::UInteger dataType )
+
+static
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_vector_descriptor.png b/docs/build/html/class_m_p_s_1_1_vector_descriptor.png new file mode 100644 index 000000000..00b2efd0e Binary files /dev/null and b/docs/build/html/class_m_p_s_1_1_vector_descriptor.png differ diff --git a/docs/build/html/classes.html b/docs/build/html/classes.html new file mode 100644 index 000000000..e2a481f31 --- /dev/null +++ b/docs/build/html/classes.html @@ -0,0 +1,152 @@ + + + + + + + +MLX: Class Index + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + +
+ +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ +
+
Class Index
+
+
+
A | B | C | D | E | F | G | I | K | L | M | N | O | P | Q | R | S | T | U | V | W | _
+
+
+
A
+
Abs
Abs (mlx::core)
Abs (mlx::core::detail)
AccumHelper (mlx::steel)
Add
Add (mlx::core)
Add (mlx::core::detail)
add_vec (pocketfft::detail)
add_vec< cmplx< T > > (pocketfft::detail)
AddMM (mlx::core)
aligned_allocator (pocketfft::detail::threading)
Allocator (mlx::core::allocator)
And
Arange (mlx::core)
ArcCos
ArcCos (mlx::core)
ArcCos (mlx::core::detail)
ArcCosh
ArcCosh (mlx::core)
ArcCosh (mlx::core::detail)
ArcSin
ArcSin (mlx::core)
ArcSin (mlx::core::detail)
ArcSinh
ArcSinh (mlx::core)
ArcSinh (mlx::core::detail)
ArcTan
ArcTan (mlx::core)
ArcTan (mlx::core::detail)
ArcTan2
ArcTan2 (mlx::core)
ArcTan2 (mlx::core::detail)
ArcTanh
ArcTanh (mlx::core)
ArcTanh (mlx::core::detail)
ArgPartition (mlx::core)
ArgReduce (mlx::core)
ArgSort (mlx::core)
arr (pocketfft::detail)
arr_info (pocketfft::detail)
array (mlx::core)
array::ArrayIterator (mlx::core)
AsStrided (mlx::core)
AsType (mlx::core)
+
+
B
+
_MLX_BFloat16::bits_to_bfloat_struct
BitwiseAnd
BitwiseAnd (mlx::core::detail)
BitwiseBinary (mlx::core)
BitwiseOr
BitwiseOr (mlx::core::detail)
BitwiseXor
BitwiseXor (mlx::core::detail)
BlockLoader (mlx::steel)
BlockMaskedMM (mlx::core)
BlockMMA (mlx::steel)
BlockSparseMM (mlx::core)
BlockSwizzle (mlx::steel)
bool4_or_uint
Broadcast (mlx::core)
Buffer (mlx::core::allocator)
+
+
C
+
Ceil
Ceil (mlx::core)
Ceil (mlx::core::detail)
cfftp (pocketfft::detail)
ChannelHelper (mlx::steel)
ChannelHelper< 1 > (mlx::steel)
ChannelHelper< 2 > (mlx::steel)
ChannelHelper< 3 > (mlx::steel)
ChannelHelper< 4 > (mlx::steel)
cmplx (pocketfft::detail)
cndarr (pocketfft::detail)
CommandEncoder (mlx::core::metal)
CommonAllocator (mlx::core::allocator)
Compiled (mlx::core)
complex128_t (mlx::core)
complex64_t
complex64_t (mlx::core)
Concatenate (mlx::core)
concurrent_queue (pocketfft::detail::threading)
CommandEncoder::ConcurrentContext (mlx::core::metal)
Conjugate
Conjugate (mlx::core)
Conjugate (mlx::core::detail)
Conv2DGeneralBaseInfo (mlx::steel)
Conv2DGeneralJumpParams (mlx::steel)
Conv2DInputBlockLoaderGeneral (mlx::steel)
Conv2DInputBlockLoaderLargeFilter (mlx::steel)
Conv2DInputBlockLoaderSmallChannels (mlx::steel)
Conv2DInputBlockLoaderSmallFilter (mlx::steel)
Conv2DWeightBlockLoader (mlx::steel)
Conv2DWeightBlockLoaderGeneral (mlx::steel)
Conv2DWeightBlockLoaderSmallChannels (mlx::steel)
Convolution (mlx::core)
Copy (mlx::core)
Cos
Cos (mlx::core)
Cos (mlx::core::detail)
Cosh
Cosh (mlx::core)
Cosh (mlx::core::detail)
Custom (mlx::core::fast)
CustomVJP (mlx::core)
+
+
D
+
array::Data (mlx::core)
Depends (mlx::core)
Device (mlx::core)
Device (mlx::core::metal)
Divide
Divide (mlx::core::detail)
Divide (mlx::core)
DivMod (mlx::core)
Dtype (mlx::core)
+
+
E
+
Equal
Equal (mlx::core::detail)
Equal (mlx::core)
Erf
Erf (mlx::core::detail)
Erf (mlx::core)
ErfInv
ErfInv (mlx::core::detail)
ErfInv (mlx::core)
Event (mlx::core)
ExecC2C (pocketfft::detail)
ExecDcst (pocketfft::detail)
ExecHartley (pocketfft::detail)
ExecR2R (pocketfft::detail)
Exp
Exp (mlx::core::detail)
Exp (mlx::core)
Expm1
Expm1 (mlx::core::detail)
Expm1 (mlx::core)
+
+
F
+
FFT (mlx::core)
fftblue (pocketfft::detail)
FileReader (mlx::core::io)
FileWriter (mlx::core::io)
array::Flags (mlx::core)
Floor
Floor (mlx::core::detail)
Floor (mlx::core)
Full (mlx::core)
+
+
G
+
Gather (mlx::core)
GEMMAddMMParams (mlx::steel)
GEMMKernel (mlx::steel)
GEMMParams (mlx::steel)
GEMMSpiltKParams (mlx::steel)
Greater
Greater (mlx::core::detail)
Greater (mlx::core)
GreaterEqual
GreaterEqual (mlx::core::detail)
GreaterEqual (mlx::core)
+
+
I
+
ImplicitGemmConv2DParams (mlx::steel)
Indices
IntOrFloat (mlx::core::detail)
InTracing (mlx::core::detail)
Inverse (mlx::core)
+
+
K
+
Kernel (MPS)
KeySequence (mlx::core::random)
+
+
L
+
latch (pocketfft::detail::threading)
LayerNorm (mlx::core::fast)
LayerNormVJP (mlx::core::fast)
LeftShift
LeftShift (mlx::core::detail)
Less
Less (mlx::core::detail)
Less (mlx::core)
LessEqual
LessEqual (mlx::core::detail)
LessEqual (mlx::core)
Limits
Limits< bfloat16_t >
Limits< bool >
Limits< float >
Limits< half >
Limits< int16_t >
Limits< int32_t >
Limits< int64_t >
Limits< int8_t >
Limits< uint16_t >
Limits< uint32_t >
Limits< uint64_t >
Limits< uint8_t >
Load (mlx::core)
Log
Log (mlx::core::detail)
Log (mlx::core)
Log10
Log10 (mlx::core::detail)
Log1p
Log1p (mlx::core::detail)
Log1p (mlx::core)
Log2
Log2 (mlx::core::detail)
LogAddExp
LogAddExp (mlx::core::detail)
LogAddExp (mlx::core)
LogicalAnd
LogicalAnd (mlx::core::detail)
LogicalAnd (mlx::core)
LogicalNot
LogicalNot (mlx::core::detail)
LogicalNot (mlx::core)
LogicalOr
LogicalOr (mlx::core::detail)
LogicalOr (mlx::core)
LoopAlignment (mlx::steel)
+
+
M
+
Matmul (mlx::core)
Matrix (MPS)
MatrixDescriptor (MPS)
MatrixMultiplication (MPS)
MatrixVectorMultiplication (MPS)
Max
Maximum
Maximum (mlx::core::detail)
Maximum (mlx::core)
MetalAllocator (mlx::core::metal)
Min
Minimum
Minimum (mlx::core::detail)
Minimum (mlx::core)
mlx_atomic
mlx_atomic< T, enable_if_t< is_metal_atomic< T > > >
MLXConvParams
MLXScaledDotProductAttentionParams
multi_iter (pocketfft::detail)
Multiply (mlx::core::detail)
Multiply (mlx::core)
Multiply
+
+
N
+
NaNEqual (mlx::core::detail)
NaNEqual
ndarr (pocketfft::detail)
Negative (mlx::core::detail)
Negative (mlx::core)
Negative
NodeNamer (mlx::core)
None
NotEqual (mlx::core::detail)
NotEqual (mlx::core)
NotEqual
NumberOfElements (mlx::core)
+
+
O
+
Or
+
+
P
+
Pad (mlx::core)
Partition (mlx::core)
pocketfft_c (pocketfft::detail)
pocketfft_r (pocketfft::detail)
Power (mlx::core::detail)
Power (mlx::core)
Power
Primitive (mlx::core)
PrintFormatter (mlx::core)
Prod
+
+
Q
+
QRF (mlx::core)
QuantizedMatmul (mlx::core)
+
+
R
+
RandomBits (mlx::core)
Reader (mlx::core::io)
BlockLoader::ReadVector (mlx::steel)
Reduce (mlx::core)
ReductionPlan (mlx::core)
Remainder (mlx::core::detail)
Remainder (mlx::core)
Remainder
Reshape (mlx::core)
rev_iter (pocketfft::detail)
rfftp (pocketfft::detail)
RightShift (mlx::core::detail)
RightShift
RMSNorm (mlx::core::fast)
RMSNormVJP (mlx::core::fast)
RoPE (mlx::core::fast)
Round (mlx::core::detail)
Round (mlx::core)
Round
Rsqrt (mlx::core::detail)
Rsqrt
+
+
S
+
ScaledDotProductAttention (mlx::core::fast)
Scan (mlx::core)
Scatter (mlx::core)
Scheduler (mlx::core::scheduler)
Select (mlx::core::detail)
Select (mlx::core)
Select
Sigmoid (mlx::core::detail)
Sigmoid (mlx::core)
Sigmoid
Sign (mlx::core::detail)
Sign (mlx::core)
Sign
simple_iter (pocketfft::detail)
Sin (mlx::core::detail)
Sin (mlx::core)
Sin
sincos_2pibyn (pocketfft::detail)
Sinh (mlx::core::detail)
Sinh (mlx::core)
Sinh
Slice (mlx::core)
SliceUpdate (mlx::core)
Softmax (mlx::core)
Sort (mlx::core)
Split (mlx::core)
Sqrt (mlx::core::detail)
Sqrt (mlx::core)
Sqrt
Square (mlx::core::detail)
Square (mlx::core)
Square
StopGradient (mlx::core)
Stream (mlx::core)
StreamContext (mlx::core)
StreamThread (mlx::core::scheduler)
Subtract (mlx::core::detail)
Subtract (mlx::core)
Subtract
Sum
SVD (mlx::core)
+
+
T
+
T_dcst23 (pocketfft::detail)
T_dcst4 (pocketfft::detail)
T_dct1 (pocketfft::detail)
T_dst1 (pocketfft::detail)
Tan (mlx::core::detail)
Tan (mlx::core)
Tan
Tanh (mlx::core::detail)
Tanh (mlx::core)
Tanh
thread_pool (pocketfft::detail::threading)
TransformAdd (mlx::steel)
TransformAxpby (mlx::steel)
TransformNone (mlx::steel)
Transpose (mlx::core)
TypeToDtype (mlx::core)
+
+
U
+
UnaryPrimitive (mlx::core)
Uniform (mlx::core)
util (pocketfft::detail)
+
+
V
+
Vector (MPS)
VectorDescriptor (MPS)
VLEN (pocketfft::detail)
VTYPE (pocketfft::detail)
+
+
W
+
Writer (mlx::core::io)
+
+
_
+
_MLX_BFloat16
_MLX_BFloat16 (mlx::core)
_MLX_Float16 (mlx::core)
_numeric_limits_impl< bfloat16_t > (metal)
+
+
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_abs-members.html b/docs/build/html/classmlx_1_1core_1_1_abs-members.html new file mode 100644 index 000000000..c4d6a962b --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_abs-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Abs Member List
+
+
+ +

This is the complete list of members for mlx::core::Abs, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Abs(Stream stream)mlx::core::Absinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Absvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Absvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Absinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Absvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Absinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Absinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Absvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Absvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_abs.html b/docs/build/html/classmlx_1_1core_1_1_abs.html new file mode 100644 index 000000000..a78741814 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_abs.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Abs Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Abs Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Abs:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Abs (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Abs()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Abs::Abs (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Abs::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Abs::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Abs::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Abs::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Abs::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Abs::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Abs::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Abs::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_abs.png b/docs/build/html/classmlx_1_1core_1_1_abs.png new file mode 100644 index 000000000..ee6584fe9 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_abs.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_add-members.html b/docs/build/html/classmlx_1_1core_1_1_add-members.html new file mode 100644 index 000000000..ef563c9dd --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_add-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Add Member List
+
+
+ +

This is the complete list of members for mlx::core::Add, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Add(Stream stream)mlx::core::Addinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Addvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Addvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Addinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Addvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Addinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Addinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Addvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Addvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_add.html b/docs/build/html/classmlx_1_1core_1_1_add.html new file mode 100644 index 000000000..a5a7ad12b --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_add.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Add Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Add Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Add:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Add (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Add()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Add::Add (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Add::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Add::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Add::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Add::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Add::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Add::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Add::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Add::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_add.png b/docs/build/html/classmlx_1_1core_1_1_add.png new file mode 100644 index 000000000..39bba292a Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_add.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_add_m_m-members.html b/docs/build/html/classmlx_1_1core_1_1_add_m_m-members.html new file mode 100644 index 000000000..f1b7abc47 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_add_m_m-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::AddMM Member List
+
+
+ +

This is the complete list of members for mlx::core::AddMM, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
AddMM(Stream stream, float alpha, float beta)mlx::core::AddMMinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::AddMMvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::AddMMvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::AddMMvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::AddMMinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::AddMMvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::AddMMvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_add_m_m.html b/docs/build/html/classmlx_1_1core_1_1_add_m_m.html new file mode 100644 index 000000000..16d98170e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_add_m_m.html @@ -0,0 +1,404 @@ + + + + + + + +MLX: mlx::core::AddMM Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::AddMM Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::AddMM:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 AddMM (Stream stream, float alpha, float beta)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ AddMM()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::AddMM::AddMM (Stream stream,
float alpha,
float beta )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::AddMM::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::AddMM::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::AddMM::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::AddMM::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::AddMM::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::AddMM::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_add_m_m.png b/docs/build/html/classmlx_1_1core_1_1_add_m_m.png new file mode 100644 index 000000000..5e054780b Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_add_m_m.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_arange-members.html b/docs/build/html/classmlx_1_1core_1_1_arange-members.html new file mode 100644 index 000000000..3d8271ae2 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arange-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Arange Member List
+
+
+ +

This is the complete list of members for mlx::core::Arange, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Arange(Stream stream, double start, double stop, double step)mlx::core::Arangeinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Arangevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Arangevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Arangevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Arangeinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arange.html b/docs/build/html/classmlx_1_1core_1_1_arange.html new file mode 100644 index 000000000..936235fe3 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arange.html @@ -0,0 +1,332 @@ + + + + + + + +MLX: mlx::core::Arange Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Arange Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Arange:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Arange (Stream stream, double start, double stop, double step)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Arange()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::Arange::Arange (Stream stream,
double start,
double stop,
double step )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Arange::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Arange::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Arange::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Arange::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arange.png b/docs/build/html/classmlx_1_1core_1_1_arange.png new file mode 100644 index 000000000..b5f5fd908 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_arange.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_cos-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_cos-members.html new file mode 100644 index 000000000..233f33000 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_cos-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArcCos Member List
+
+
+ +

This is the complete list of members for mlx::core::ArcCos, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArcCos(Stream stream)mlx::core::ArcCosinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcCosvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcCosvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArcCosinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::ArcCosvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArcCosinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArcCosinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::ArcCosvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArcCosvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_cos.html b/docs/build/html/classmlx_1_1core_1_1_arc_cos.html new file mode 100644 index 000000000..b465b208b --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_cos.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::ArcCos Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::ArcCos Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArcCos:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArcCos (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArcCos()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::ArcCos::ArcCos (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcCos::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcCos::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArcCos::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcCos::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArcCos::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArcCos::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcCos::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArcCos::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_cos.png b/docs/build/html/classmlx_1_1core_1_1_arc_cos.png new file mode 100644 index 000000000..2daeb8d48 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_arc_cos.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_cosh-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_cosh-members.html new file mode 100644 index 000000000..29d5b9722 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_cosh-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArcCosh Member List
+
+
+ +

This is the complete list of members for mlx::core::ArcCosh, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArcCosh(Stream stream)mlx::core::ArcCoshinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcCoshvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcCoshvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArcCoshinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::ArcCoshvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArcCoshinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArcCoshinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::ArcCoshvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArcCoshvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_cosh.html b/docs/build/html/classmlx_1_1core_1_1_arc_cosh.html new file mode 100644 index 000000000..54a745725 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_cosh.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::ArcCosh Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::ArcCosh Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArcCosh:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArcCosh (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArcCosh()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::ArcCosh::ArcCosh (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcCosh::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcCosh::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArcCosh::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcCosh::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArcCosh::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArcCosh::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcCosh::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArcCosh::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_cosh.png b/docs/build/html/classmlx_1_1core_1_1_arc_cosh.png new file mode 100644 index 000000000..2242caeb7 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_arc_cosh.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_sin-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_sin-members.html new file mode 100644 index 000000000..a728c7141 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_sin-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArcSin Member List
+
+
+ +

This is the complete list of members for mlx::core::ArcSin, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArcSin(Stream stream)mlx::core::ArcSininlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcSinvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcSinvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArcSininlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::ArcSinvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArcSininlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArcSininlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::ArcSinvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArcSinvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_sin.html b/docs/build/html/classmlx_1_1core_1_1_arc_sin.html new file mode 100644 index 000000000..8f35ffa70 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_sin.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::ArcSin Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::ArcSin Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArcSin:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArcSin (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArcSin()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::ArcSin::ArcSin (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcSin::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcSin::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArcSin::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcSin::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArcSin::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArcSin::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcSin::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArcSin::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_sin.png b/docs/build/html/classmlx_1_1core_1_1_arc_sin.png new file mode 100644 index 000000000..644ab73d9 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_arc_sin.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_sinh-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_sinh-members.html new file mode 100644 index 000000000..70e7e2c9a --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_sinh-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArcSinh Member List
+
+
+ +

This is the complete list of members for mlx::core::ArcSinh, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArcSinh(Stream stream)mlx::core::ArcSinhinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcSinhvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcSinhvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArcSinhinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::ArcSinhvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArcSinhinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArcSinhinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::ArcSinhvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArcSinhvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_sinh.html b/docs/build/html/classmlx_1_1core_1_1_arc_sinh.html new file mode 100644 index 000000000..3c5272e30 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_sinh.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::ArcSinh Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::ArcSinh Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArcSinh:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArcSinh (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArcSinh()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::ArcSinh::ArcSinh (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcSinh::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcSinh::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArcSinh::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcSinh::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArcSinh::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArcSinh::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcSinh::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArcSinh::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_sinh.png b/docs/build/html/classmlx_1_1core_1_1_arc_sinh.png new file mode 100644 index 000000000..728cb98d3 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_arc_sinh.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tan-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_tan-members.html new file mode 100644 index 000000000..2ccd4c936 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_tan-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArcTan Member List
+
+
+ +

This is the complete list of members for mlx::core::ArcTan, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArcTan(Stream stream)mlx::core::ArcTaninlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcTanvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcTanvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArcTaninlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::ArcTanvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArcTaninlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArcTaninlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::ArcTanvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArcTanvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tan.html b/docs/build/html/classmlx_1_1core_1_1_arc_tan.html new file mode 100644 index 000000000..52bf53085 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_tan.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::ArcTan Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::ArcTan Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArcTan:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArcTan (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArcTan()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::ArcTan::ArcTan (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcTan::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcTan::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArcTan::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcTan::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArcTan::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArcTan::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcTan::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArcTan::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tan.png b/docs/build/html/classmlx_1_1core_1_1_arc_tan.png new file mode 100644 index 000000000..61bf8d991 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_arc_tan.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tan2-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_tan2-members.html new file mode 100644 index 000000000..0fc8af254 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_tan2-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArcTan2 Member List
+
+
+ +

This is the complete list of members for mlx::core::ArcTan2, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArcTan2(Stream stream)mlx::core::ArcTan2inlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcTan2virtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcTan2virtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArcTan2inlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::ArcTan2virtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArcTan2inlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArcTan2inlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::ArcTan2virtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArcTan2virtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tan2.html b/docs/build/html/classmlx_1_1core_1_1_arc_tan2.html new file mode 100644 index 000000000..2e293234e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_tan2.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::ArcTan2 Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::ArcTan2 Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArcTan2:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArcTan2 (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArcTan2()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::ArcTan2::ArcTan2 (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcTan2::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcTan2::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArcTan2::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcTan2::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArcTan2::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArcTan2::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcTan2::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArcTan2::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tan2.png b/docs/build/html/classmlx_1_1core_1_1_arc_tan2.png new file mode 100644 index 000000000..ff2449809 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_arc_tan2.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tanh-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_tanh-members.html new file mode 100644 index 000000000..64caf5fc1 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_tanh-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArcTanh Member List
+
+
+ +

This is the complete list of members for mlx::core::ArcTanh, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArcTanh(Stream stream)mlx::core::ArcTanhinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcTanhvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcTanhvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArcTanhinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::ArcTanhvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArcTanhinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArcTanhinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::ArcTanhvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArcTanhvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tanh.html b/docs/build/html/classmlx_1_1core_1_1_arc_tanh.html new file mode 100644 index 000000000..9b55347cc --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_tanh.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::ArcTanh Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::ArcTanh Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArcTanh:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArcTanh (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArcTanh()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::ArcTanh::ArcTanh (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcTanh::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcTanh::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArcTanh::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcTanh::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArcTanh::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArcTanh::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcTanh::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArcTanh::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tanh.png b/docs/build/html/classmlx_1_1core_1_1_arc_tanh.png new file mode 100644 index 000000000..59f4ba4f4 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_arc_tanh.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_partition-members.html b/docs/build/html/classmlx_1_1core_1_1_arg_partition-members.html new file mode 100644 index 000000000..b3ae98b73 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arg_partition-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArgPartition Member List
+
+
+ +

This is the complete list of members for mlx::core::ArgPartition, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArgPartition(Stream stream, int kth, int axis)mlx::core::ArgPartitioninlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArgPartitionvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArgPartitionvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArgPartitionvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArgPartitioninlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArgPartitioninlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArgPartitionvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_partition.html b/docs/build/html/classmlx_1_1core_1_1_arg_partition.html new file mode 100644 index 000000000..ecdce6668 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arg_partition.html @@ -0,0 +1,391 @@ + + + + + + + +MLX: mlx::core::ArgPartition Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::ArgPartition Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArgPartition:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArgPartition (Stream stream, int kth, int axis)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArgPartition()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::ArgPartition::ArgPartition (Stream stream,
int kth,
int axis )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArgPartition::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArgPartition::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArgPartition::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArgPartition::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArgPartition::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArgPartition::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_partition.png b/docs/build/html/classmlx_1_1core_1_1_arg_partition.png new file mode 100644 index 000000000..8dcfb003d Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_arg_partition.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_reduce-members.html b/docs/build/html/classmlx_1_1core_1_1_arg_reduce-members.html new file mode 100644 index 000000000..1be224e76 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arg_reduce-members.html @@ -0,0 +1,118 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArgReduce Member List
+
+
+ +

This is the complete list of members for mlx::core::ArgReduce, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ArgMax enum valuemlx::core::ArgReduce
ArgMin enum valuemlx::core::ArgReduce
ArgReduce(Stream stream, ReduceType reduce_type, int axis)mlx::core::ArgReduceinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArgReducevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArgReducevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArgReducevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArgReducevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArgReduceinlinevirtual
ReduceType enum namemlx::core::ArgReduce
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArgReducevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_reduce.html b/docs/build/html/classmlx_1_1core_1_1_arg_reduce.html new file mode 100644 index 000000000..1600965aa --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arg_reduce.html @@ -0,0 +1,418 @@ + + + + + + + +MLX: mlx::core::ArgReduce Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Types | +Public Member Functions | +List of all members
+
mlx::core::ArgReduce Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArgReduce:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + +

+Public Types

enum  ReduceType { ArgMin +, ArgMax + }
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArgReduce (Stream stream, ReduceType reduce_type, int axis)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Member Enumeration Documentation

+ +

◆ ReduceType

+ +
+
+ + + +
Enumerator
ArgMin 
ArgMax 
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ ArgReduce()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::ArgReduce::ArgReduce (Stream stream,
ReduceType reduce_type,
int axis )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArgReduce::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArgReduce::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArgReduce::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArgReduce::output_shapes (const std::vector< array > & inputs)
+
+overridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArgReduce::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArgReduce::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_reduce.png b/docs/build/html/classmlx_1_1core_1_1_arg_reduce.png new file mode 100644 index 000000000..ac897a69d Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_arg_reduce.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_sort-members.html b/docs/build/html/classmlx_1_1core_1_1_arg_sort-members.html new file mode 100644 index 000000000..b80b22dac --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arg_sort-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArgSort Member List
+
+
+ +

This is the complete list of members for mlx::core::ArgSort, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArgSort(Stream stream, int axis)mlx::core::ArgSortinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArgSortvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArgSortvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArgSortvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArgSortinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArgSortinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArgSortvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_sort.html b/docs/build/html/classmlx_1_1core_1_1_arg_sort.html new file mode 100644 index 000000000..510709ed7 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arg_sort.html @@ -0,0 +1,386 @@ + + + + + + + +MLX: mlx::core::ArgSort Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::ArgSort Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArgSort:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArgSort (Stream stream, int axis)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArgSort()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::ArgSort::ArgSort (Stream stream,
int axis )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArgSort::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArgSort::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArgSort::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArgSort::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArgSort::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArgSort::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_sort.png b/docs/build/html/classmlx_1_1core_1_1_arg_sort.png new file mode 100644 index 000000000..523bf16a1 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_arg_sort.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_as_strided-members.html b/docs/build/html/classmlx_1_1core_1_1_as_strided-members.html new file mode 100644 index 000000000..63e99012a --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_as_strided-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::AsStrided Member List
+
+
+ +

This is the complete list of members for mlx::core::AsStrided, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
AsStrided(Stream stream, std::vector< int > shape, std::vector< size_t > strides, size_t offset)mlx::core::AsStridedinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::AsStridedvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::AsStridedvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::AsStridedvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::AsStridedvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::AsStridedinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::AsStridedvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_as_strided.html b/docs/build/html/classmlx_1_1core_1_1_as_strided.html new file mode 100644 index 000000000..d9002ed44 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_as_strided.html @@ -0,0 +1,413 @@ + + + + + + + +MLX: mlx::core::AsStrided Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::AsStrided Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::AsStrided:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 AsStrided (Stream stream, std::vector< int > shape, std::vector< size_t > strides, size_t offset)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ AsStrided()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::AsStrided::AsStrided (Stream stream,
std::vector< int > shape,
std::vector< size_t > strides,
size_t offset )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::AsStrided::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::AsStrided::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::AsStrided::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::AsStrided::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::AsStrided::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::AsStrided::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_as_strided.png b/docs/build/html/classmlx_1_1core_1_1_as_strided.png new file mode 100644 index 000000000..7224d5d45 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_as_strided.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_as_type-members.html b/docs/build/html/classmlx_1_1core_1_1_as_type-members.html new file mode 100644 index 000000000..82abf852f --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_as_type-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::AsType Member List
+
+
+ +

This is the complete list of members for mlx::core::AsType, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
AsType(Stream stream, Dtype dtype)mlx::core::AsTypeinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::AsTypevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::AsTypevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::AsTypevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::AsTypevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::AsTypeinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::AsTypeinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::AsTypevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::AsTypevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_as_type.html b/docs/build/html/classmlx_1_1core_1_1_as_type.html new file mode 100644 index 000000000..00bb61d5a --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_as_type.html @@ -0,0 +1,467 @@ + + + + + + + +MLX: mlx::core::AsType Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::AsType Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::AsType:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 AsType (Stream stream, Dtype dtype)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ AsType()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::AsType::AsType (Stream stream,
Dtype dtype )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::AsType::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::AsType::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::AsType::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::AsType::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::AsType::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::AsType::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::AsType::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::AsType::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_as_type.png b/docs/build/html/classmlx_1_1core_1_1_as_type.png new file mode 100644 index 000000000..4b919c285 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_as_type.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_bitwise_binary-members.html b/docs/build/html/classmlx_1_1core_1_1_bitwise_binary-members.html new file mode 100644 index 000000000..f6c72e0c0 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_bitwise_binary-members.html @@ -0,0 +1,121 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::BitwiseBinary Member List
+
+
+ +

This is the complete list of members for mlx::core::BitwiseBinary, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
And enum valuemlx::core::BitwiseBinary
BitwiseBinary(Stream stream, Op op)mlx::core::BitwiseBinaryinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::BitwiseBinaryvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::BitwiseBinaryvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::BitwiseBinaryvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
LeftShift enum valuemlx::core::BitwiseBinary
Op enum namemlx::core::BitwiseBinary
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
Or enum valuemlx::core::BitwiseBinary
output_shapes(const std::vector< array > &inputs) overridemlx::core::BitwiseBinaryinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::BitwiseBinaryvirtual
RightShift enum valuemlx::core::BitwiseBinary
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::BitwiseBinaryvirtual
Xor enum valuemlx::core::BitwiseBinary
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_bitwise_binary.html b/docs/build/html/classmlx_1_1core_1_1_bitwise_binary.html new file mode 100644 index 000000000..d0a522c28 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_bitwise_binary.html @@ -0,0 +1,422 @@ + + + + + + + +MLX: mlx::core::BitwiseBinary Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Types | +Public Member Functions | +List of all members
+
mlx::core::BitwiseBinary Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::BitwiseBinary:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + +

+Public Types

enum  Op {
+  And +, Or +, Xor +, LeftShift +,
+  RightShift +
+ }
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 BitwiseBinary (Stream stream, Op op)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Member Enumeration Documentation

+ +

◆ Op

+ +
+
+ + + + + + +
Enumerator
And 
Or 
Xor 
LeftShift 
RightShift 
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ BitwiseBinary()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::BitwiseBinary::BitwiseBinary (Stream stream,
Op op )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::BitwiseBinary::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::BitwiseBinary::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::BitwiseBinary::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::BitwiseBinary::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::BitwiseBinary::print (std::ostream & os)
+
+overridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::BitwiseBinary::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_bitwise_binary.png b/docs/build/html/classmlx_1_1core_1_1_bitwise_binary.png new file mode 100644 index 000000000..0e74367fc Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_bitwise_binary.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m-members.html b/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m-members.html new file mode 100644 index 000000000..a2891b770 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::BlockMaskedMM Member List
+
+
+ +

This is the complete list of members for mlx::core::BlockMaskedMM, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
BlockMaskedMM(Stream stream, int block_size)mlx::core::BlockMaskedMMinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::BlockMaskedMMvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::BlockMaskedMMvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::BlockMaskedMMvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::BlockMaskedMMinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::BlockMaskedMMvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.html b/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.html new file mode 100644 index 000000000..f447b19b8 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.html @@ -0,0 +1,365 @@ + + + + + + + +MLX: mlx::core::BlockMaskedMM Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::BlockMaskedMM Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::BlockMaskedMM:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 BlockMaskedMM (Stream stream, int block_size)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ BlockMaskedMM()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::BlockMaskedMM::BlockMaskedMM (Stream stream,
int block_size )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::BlockMaskedMM::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::BlockMaskedMM::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::BlockMaskedMM::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::BlockMaskedMM::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::BlockMaskedMM::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.png b/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.png new file mode 100644 index 000000000..8e5e7a8a4 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m-members.html b/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m-members.html new file mode 100644 index 000000000..2cee30d11 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::BlockSparseMM Member List
+
+
+ +

This is the complete list of members for mlx::core::BlockSparseMM, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
BlockSparseMM(Stream stream)mlx::core::BlockSparseMMinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::BlockSparseMMvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::BlockSparseMMvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::BlockSparseMMinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::BlockSparseMMinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::BlockSparseMMvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m.html b/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m.html new file mode 100644 index 000000000..3941e4915 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m.html @@ -0,0 +1,361 @@ + + + + + + + +MLX: mlx::core::BlockSparseMM Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::BlockSparseMM Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::BlockSparseMM:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 BlockSparseMM (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ BlockSparseMM()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::BlockSparseMM::BlockSparseMM (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::BlockSparseMM::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::BlockSparseMM::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::BlockSparseMM::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::BlockSparseMM::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::BlockSparseMM::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m.png b/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m.png new file mode 100644 index 000000000..a9ca56c7b Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_broadcast-members.html b/docs/build/html/classmlx_1_1core_1_1_broadcast-members.html new file mode 100644 index 000000000..0f8b3cce0 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_broadcast-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Broadcast Member List
+
+
+ +

This is the complete list of members for mlx::core::Broadcast, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Broadcast(Stream stream, const std::vector< int > &shape)mlx::core::Broadcastinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Broadcastvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Broadcastvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Broadcastvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Broadcastvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Broadcastinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Broadcastvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Broadcastvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_broadcast.html b/docs/build/html/classmlx_1_1core_1_1_broadcast.html new file mode 100644 index 000000000..4b455d2c2 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_broadcast.html @@ -0,0 +1,437 @@ + + + + + + + +MLX: mlx::core::Broadcast Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Broadcast Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Broadcast:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Broadcast (Stream stream, const std::vector< int > &shape)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Broadcast()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Broadcast::Broadcast (Stream stream,
const std::vector< int > & shape )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Broadcast::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Broadcast::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Broadcast::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Broadcast::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Broadcast::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Broadcast::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Broadcast::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_broadcast.png b/docs/build/html/classmlx_1_1core_1_1_broadcast.png new file mode 100644 index 000000000..080f3c455 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_broadcast.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_ceil-members.html b/docs/build/html/classmlx_1_1core_1_1_ceil-members.html new file mode 100644 index 000000000..2cce4f65c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_ceil-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Ceil Member List
+
+
+ +

This is the complete list of members for mlx::core::Ceil, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Ceil(Stream stream)mlx::core::Ceilinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Ceilvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Ceilvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Ceilinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Ceilvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Ceilinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Ceilinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Ceilvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Ceilvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_ceil.html b/docs/build/html/classmlx_1_1core_1_1_ceil.html new file mode 100644 index 000000000..4d11251f0 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_ceil.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Ceil Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Ceil Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Ceil:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Ceil (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Ceil()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Ceil::Ceil (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Ceil::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Ceil::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Ceil::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Ceil::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Ceil::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Ceil::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Ceil::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Ceil::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_ceil.png b/docs/build/html/classmlx_1_1core_1_1_ceil.png new file mode 100644 index 000000000..7894fb3ec Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_ceil.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_compiled-members.html b/docs/build/html/classmlx_1_1core_1_1_compiled-members.html new file mode 100644 index 000000000..bf0ad0373 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_compiled-members.html @@ -0,0 +1,108 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Compiled Member List
+
+
+ +

This is the complete list of members for mlx::core::Compiled, including all inherited members.

+ + + + + + + + + + + + + + + + + + + +
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)mlx::core::Compiledexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::Compiledvirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::Compiledvirtual
is_equivalent(const Primitive &other) const overridemlx::core::Compiledvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Compiledvirtual
lib_name() constmlx::core::Compiledinline
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Compiledvirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Compiledvirtual
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Compiledvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Compiledvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_compiled.html b/docs/build/html/classmlx_1_1core_1_1_compiled.html new file mode 100644 index 000000000..78b508b6e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_compiled.html @@ -0,0 +1,493 @@ + + + + + + + +MLX: mlx::core::Compiled Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Compiled Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Compiled:
+
+
+ + +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Compiled (Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::string lib_name () const
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Compiled()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
mlx::core::Compiled::Compiled (Stream stream,
std::vector< array > inputs,
std::vector< array > outputs,
std::vector< array > tape,
std::unordered_set< uintptr_t > constant_ids )
+
+explicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Compiled::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Compiled::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Compiled::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Compiled::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ lib_name()

+ +
+
+ + + + + +
+ + + + + + + +
std::string mlx::core::Compiled::lib_name () const
+
+inline
+
+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Compiled::output_shapes (const std::vector< array > & inputs)
+
+overridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Compiled::print (std::ostream & os)
+
+overridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Compiled::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Compiled::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_compiled.png b/docs/build/html/classmlx_1_1core_1_1_compiled.png new file mode 100644 index 000000000..4f12eb20e Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_compiled.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_concatenate-members.html b/docs/build/html/classmlx_1_1core_1_1_concatenate-members.html new file mode 100644 index 000000000..9847df5c1 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_concatenate-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Concatenate Member List
+
+
+ +

This is the complete list of members for mlx::core::Concatenate, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Concatenate(Stream stream, int axis)mlx::core::Concatenateinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Concatenatevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Concatenatevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Concatenatevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Concatenatevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Concatenateinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Concatenatevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Concatenatevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_concatenate.html b/docs/build/html/classmlx_1_1core_1_1_concatenate.html new file mode 100644 index 000000000..b96f54f64 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_concatenate.html @@ -0,0 +1,437 @@ + + + + + + + +MLX: mlx::core::Concatenate Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Concatenate Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Concatenate:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Concatenate (Stream stream, int axis)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Concatenate()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Concatenate::Concatenate (Stream stream,
int axis )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Concatenate::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Concatenate::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Concatenate::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Concatenate::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Concatenate::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Concatenate::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Concatenate::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_concatenate.png b/docs/build/html/classmlx_1_1core_1_1_concatenate.png new file mode 100644 index 000000000..340462145 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_concatenate.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_conjugate-members.html b/docs/build/html/classmlx_1_1core_1_1_conjugate-members.html new file mode 100644 index 000000000..3439e670e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_conjugate-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Conjugate Member List
+
+
+ +

This is the complete list of members for mlx::core::Conjugate, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Conjugate(Stream stream)mlx::core::Conjugateinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Conjugatevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Conjugatevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Conjugateinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Conjugateinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Conjugateinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Conjugatevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_conjugate.html b/docs/build/html/classmlx_1_1core_1_1_conjugate.html new file mode 100644 index 000000000..f7f0de1cf --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_conjugate.html @@ -0,0 +1,382 @@ + + + + + + + +MLX: mlx::core::Conjugate Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Conjugate Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Conjugate:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Conjugate (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Conjugate()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Conjugate::Conjugate (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Conjugate::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Conjugate::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Conjugate::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Conjugate::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Conjugate::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Conjugate::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_conjugate.png b/docs/build/html/classmlx_1_1core_1_1_conjugate.png new file mode 100644 index 000000000..08be44bd1 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_conjugate.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_convolution-members.html b/docs/build/html/classmlx_1_1core_1_1_convolution-members.html new file mode 100644 index 000000000..c09690551 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_convolution-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Convolution Member List
+
+
+ +

This is the complete list of members for mlx::core::Convolution, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)mlx::core::Convolutioninlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Convolutionvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Convolutionvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Convolutionvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Convolutioninlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Convolutionvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_convolution.html b/docs/build/html/classmlx_1_1core_1_1_convolution.html new file mode 100644 index 000000000..ac64adff4 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_convolution.html @@ -0,0 +1,390 @@ + + + + + + + +MLX: mlx::core::Convolution Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Convolution Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Convolution:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Convolution (Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Convolution()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
mlx::core::Convolution::Convolution (Stream stream,
const std::vector< int > & kernel_strides,
const std::vector< int > & padding,
const std::vector< int > & kernel_dilation,
const std::vector< int > & input_dilation,
const int groups = 1,
const bool flip = false )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Convolution::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Convolution::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Convolution::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Convolution::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Convolution::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_convolution.png b/docs/build/html/classmlx_1_1core_1_1_convolution.png new file mode 100644 index 000000000..853ab7ab4 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_convolution.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_copy-members.html b/docs/build/html/classmlx_1_1core_1_1_copy-members.html new file mode 100644 index 000000000..3cb46c818 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_copy-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Copy Member List
+
+
+ +

This is the complete list of members for mlx::core::Copy, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Copy(Stream stream)mlx::core::Copyinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Copyvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Copyvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Copyinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Copyvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Copyinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Copyinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Copyvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Copyvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_copy.html b/docs/build/html/classmlx_1_1core_1_1_copy.html new file mode 100644 index 000000000..f233d09c2 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_copy.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Copy Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Copy Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Copy:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Copy (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Copy()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Copy::Copy (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Copy::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Copy::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Copy::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Copy::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Copy::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Copy::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Copy::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Copy::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_copy.png b/docs/build/html/classmlx_1_1core_1_1_copy.png new file mode 100644 index 000000000..2f4f36d04 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_copy.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_cos-members.html b/docs/build/html/classmlx_1_1core_1_1_cos-members.html new file mode 100644 index 000000000..f07b0309e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_cos-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Cos Member List
+
+
+ +

This is the complete list of members for mlx::core::Cos, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Cos(Stream stream)mlx::core::Cosinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Cosvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Cosvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Cosinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Cosvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Cosinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Cosinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Cosvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Cosvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_cos.html b/docs/build/html/classmlx_1_1core_1_1_cos.html new file mode 100644 index 000000000..1d2e6907d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_cos.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Cos Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Cos Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Cos:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Cos (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Cos()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Cos::Cos (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Cos::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Cos::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Cos::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Cos::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Cos::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Cos::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Cos::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Cos::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_cos.png b/docs/build/html/classmlx_1_1core_1_1_cos.png new file mode 100644 index 000000000..4724c19a7 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_cos.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_cosh-members.html b/docs/build/html/classmlx_1_1core_1_1_cosh-members.html new file mode 100644 index 000000000..4669e2c0f --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_cosh-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Cosh Member List
+
+
+ +

This is the complete list of members for mlx::core::Cosh, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Cosh(Stream stream)mlx::core::Coshinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Coshvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Coshvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Coshinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Coshvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Coshinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Coshinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Coshvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Coshvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_cosh.html b/docs/build/html/classmlx_1_1core_1_1_cosh.html new file mode 100644 index 000000000..385f8d6d8 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_cosh.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Cosh Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Cosh Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Cosh:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Cosh (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Cosh()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Cosh::Cosh (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Cosh::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Cosh::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Cosh::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Cosh::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Cosh::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Cosh::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Cosh::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Cosh::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_cosh.png b/docs/build/html/classmlx_1_1core_1_1_cosh.png new file mode 100644 index 000000000..69fffddab Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_cosh.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p-members.html b/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p-members.html new file mode 100644 index 000000000..eee5452e6 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p-members.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::CustomVJP Member List
+
+
+ +

This is the complete list of members for mlx::core::CustomVJP, including all inherited members.

+ + + + + + + + + + + + + + + + + + +
CustomVJP(Stream stream, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> fun)mlx::core::CustomVJPinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::CustomVJPvirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::CustomVJPvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::CustomVJPinlinevirtual
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::CustomVJPvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p.html b/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p.html new file mode 100644 index 000000000..8ed54c2a1 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p.html @@ -0,0 +1,320 @@ + + + + + + + +MLX: mlx::core::CustomVJP Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::CustomVJP Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::CustomVJP:
+
+
+ + +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 CustomVJP (Stream stream, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> fun)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ CustomVJP()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::CustomVJP::CustomVJP (Stream stream,
std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> fun )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::CustomVJP::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::CustomVJP::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::CustomVJP::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::CustomVJP::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p.png b/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p.png new file mode 100644 index 000000000..32bf6e7e9 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_depends-members.html b/docs/build/html/classmlx_1_1core_1_1_depends-members.html new file mode 100644 index 000000000..8cb2e5547 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_depends-members.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Depends Member List
+
+
+ +

This is the complete list of members for mlx::core::Depends, including all inherited members.

+ + + + + + + + + + + + + + + + + + +
Depends(Stream stream)mlx::core::Dependsinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::Dependsvirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::Dependsvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Dependsinlinevirtual
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Dependsvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_depends.html b/docs/build/html/classmlx_1_1core_1_1_depends.html new file mode 100644 index 000000000..45bef7a88 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_depends.html @@ -0,0 +1,316 @@ + + + + + + + +MLX: mlx::core::Depends Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Depends Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Depends:
+
+
+ + +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Depends (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Depends()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Depends::Depends (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Depends::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Depends::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Depends::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Depends::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_depends.png b/docs/build/html/classmlx_1_1core_1_1_depends.png new file mode 100644 index 000000000..8c1a6319a Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_depends.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_div_mod-members.html b/docs/build/html/classmlx_1_1core_1_1_div_mod-members.html new file mode 100644 index 000000000..bd2f77317 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_div_mod-members.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::DivMod Member List
+
+
+ +

This is the complete list of members for mlx::core::DivMod, including all inherited members.

+ + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
DivMod(Stream stream)mlx::core::DivModinlineexplicit
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::DivModvirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::DivModvirtual
is_equivalent(const Primitive &other) const overridemlx::core::DivModinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::DivModvirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::DivModinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::DivModinlinevirtual
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::DivModvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::DivModvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_div_mod.html b/docs/build/html/classmlx_1_1core_1_1_div_mod.html new file mode 100644 index 000000000..d4455f3ed --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_div_mod.html @@ -0,0 +1,447 @@ + + + + + + + +MLX: mlx::core::DivMod Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::DivMod Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::DivMod:
+
+
+ + +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 DivMod (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ DivMod()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::DivMod::DivMod (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::DivMod::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::DivMod::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::DivMod::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::DivMod::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::DivMod::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::DivMod::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::DivMod::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::DivMod::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_div_mod.png b/docs/build/html/classmlx_1_1core_1_1_div_mod.png new file mode 100644 index 000000000..11583cfa1 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_div_mod.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_divide-members.html b/docs/build/html/classmlx_1_1core_1_1_divide-members.html new file mode 100644 index 000000000..e2f1b1e2e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_divide-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Divide Member List
+
+
+ +

This is the complete list of members for mlx::core::Divide, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
Divide(Stream stream)mlx::core::Divideinlineexplicit
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Dividevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Dividevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Divideinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Dividevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Divideinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Divideinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Dividevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Dividevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_divide.html b/docs/build/html/classmlx_1_1core_1_1_divide.html new file mode 100644 index 000000000..be3840b28 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_divide.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Divide Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Divide Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Divide:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Divide (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Divide()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Divide::Divide (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Divide::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Divide::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Divide::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Divide::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Divide::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Divide::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Divide::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Divide::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_divide.png b/docs/build/html/classmlx_1_1core_1_1_divide.png new file mode 100644 index 000000000..f3946b16d Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_divide.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_equal-members.html b/docs/build/html/classmlx_1_1core_1_1_equal-members.html new file mode 100644 index 000000000..a47714196 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_equal-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Equal Member List
+
+
+ +

This is the complete list of members for mlx::core::Equal, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
Equal(Stream stream, bool equal_nan=false)mlx::core::Equalinlineexplicit
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Equalvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Equalvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Equalinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Equalvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Equalinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Equalinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Equalvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Equalvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_equal.html b/docs/build/html/classmlx_1_1core_1_1_equal.html new file mode 100644 index 000000000..0ed4a5122 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_equal.html @@ -0,0 +1,467 @@ + + + + + + + +MLX: mlx::core::Equal Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Equal Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Equal:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Equal (Stream stream, bool equal_nan=false)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Equal()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Equal::Equal (Stream stream,
bool equal_nan = false )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Equal::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Equal::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Equal::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Equal::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Equal::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Equal::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Equal::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Equal::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_equal.png b/docs/build/html/classmlx_1_1core_1_1_equal.png new file mode 100644 index 000000000..7c77a8836 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_equal.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_erf-members.html b/docs/build/html/classmlx_1_1core_1_1_erf-members.html new file mode 100644 index 000000000..896816975 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_erf-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Erf Member List
+
+
+ +

This is the complete list of members for mlx::core::Erf, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
Erf(Stream stream)mlx::core::Erfinlineexplicit
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Erfvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Erfvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Erfinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Erfvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Erfinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Erfinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Erfvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Erfvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_erf.html b/docs/build/html/classmlx_1_1core_1_1_erf.html new file mode 100644 index 000000000..07e13bab9 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_erf.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Erf Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Erf Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Erf:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Erf (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Erf()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Erf::Erf (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Erf::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Erf::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Erf::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Erf::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Erf::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Erf::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Erf::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Erf::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_erf.png b/docs/build/html/classmlx_1_1core_1_1_erf.png new file mode 100644 index 000000000..d21c1648a Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_erf.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_erf_inv-members.html b/docs/build/html/classmlx_1_1core_1_1_erf_inv-members.html new file mode 100644 index 000000000..1b4e58ece --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_erf_inv-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ErfInv Member List
+
+
+ +

This is the complete list of members for mlx::core::ErfInv, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
ErfInv(Stream stream)mlx::core::ErfInvinlineexplicit
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ErfInvvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ErfInvvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ErfInvinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::ErfInvvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ErfInvinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ErfInvinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::ErfInvvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ErfInvvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_erf_inv.html b/docs/build/html/classmlx_1_1core_1_1_erf_inv.html new file mode 100644 index 000000000..92b8c6bf2 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_erf_inv.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::ErfInv Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::ErfInv Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ErfInv:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ErfInv (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ErfInv()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::ErfInv::ErfInv (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ErfInv::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ErfInv::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ErfInv::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ErfInv::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ErfInv::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ErfInv::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ErfInv::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ErfInv::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_erf_inv.png b/docs/build/html/classmlx_1_1core_1_1_erf_inv.png new file mode 100644 index 000000000..2ed64aaf6 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_erf_inv.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_event-members.html b/docs/build/html/classmlx_1_1core_1_1_event-members.html new file mode 100644 index 000000000..02d8245bf --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_event-members.html @@ -0,0 +1,99 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Event Member List
+
+
+ +

This is the complete list of members for mlx::core::Event, including all inherited members.

+ + + + + + + + + + +
Event()mlx::core::Eventinline
Event(const Stream &steam)mlx::core::Event
raw_event()mlx::core::Eventinline
set_value(uint64_t v)mlx::core::Eventinline
signal()mlx::core::Event
stream()mlx::core::Eventinline
valid()mlx::core::Eventinline
value()mlx::core::Eventinline
wait()mlx::core::Event
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_event.html b/docs/build/html/classmlx_1_1core_1_1_event.html new file mode 100644 index 000000000..0b8a08aa0 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_event.html @@ -0,0 +1,320 @@ + + + + + + + +MLX: mlx::core::Event Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Event Class Reference
+
+
+ +

#include <event.h>

+ + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Event ()
 
 Event (const Stream &steam)
 
void wait ()
 
void signal ()
 
bool valid ()
 
uint64_t value ()
 
void set_value (uint64_t v)
 
const Streamstream ()
 
const std::shared_ptr< void > & raw_event ()
 
+

Constructor & Destructor Documentation

+ +

◆ Event() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Event::Event ()
+
+inline
+
+ +
+
+ +

◆ Event() [2/2]

+ +
+
+ + + + + + + +
mlx::core::Event::Event (const Stream & steam)
+
+ +
+
+

Member Function Documentation

+ +

◆ raw_event()

+ +
+
+ + + + + +
+ + + + + + + +
const std::shared_ptr< void > & mlx::core::Event::raw_event ()
+
+inline
+
+ +
+
+ +

◆ set_value()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Event::set_value (uint64_t v)
+
+inline
+
+ +
+
+ +

◆ signal()

+ +
+
+ + + + + + + +
void mlx::core::Event::signal ()
+
+ +
+
+ +

◆ stream()

+ +
+
+ + + + + +
+ + + + + + + +
const Stream & mlx::core::Event::stream ()
+
+inline
+
+ +
+
+ +

◆ valid()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Event::valid ()
+
+inline
+
+ +
+
+ +

◆ value()

+ +
+
+ + + + + +
+ + + + + + + +
uint64_t mlx::core::Event::value ()
+
+inline
+
+ +
+
+ +

◆ wait()

+ +
+
+ + + + + + + +
void mlx::core::Event::wait ()
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_exp-members.html b/docs/build/html/classmlx_1_1core_1_1_exp-members.html new file mode 100644 index 000000000..ff9c86910 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_exp-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Exp Member List
+
+
+ +

This is the complete list of members for mlx::core::Exp, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Expvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Expvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
Exp(Stream stream)mlx::core::Expinlineexplicit
is_equivalent(const Primitive &other) const overridemlx::core::Expinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Expvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Expinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Expinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Expvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Expvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_exp.html b/docs/build/html/classmlx_1_1core_1_1_exp.html new file mode 100644 index 000000000..7a5acdfdf --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_exp.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Exp Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Exp Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Exp:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Exp (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Exp()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Exp::Exp (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Exp::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Exp::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Exp::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Exp::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Exp::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Exp::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Exp::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Exp::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_exp.png b/docs/build/html/classmlx_1_1core_1_1_exp.png new file mode 100644 index 000000000..5072482be Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_exp.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_expm1-members.html b/docs/build/html/classmlx_1_1core_1_1_expm1-members.html new file mode 100644 index 000000000..2fff8d1c1 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_expm1-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Expm1 Member List
+
+
+ +

This is the complete list of members for mlx::core::Expm1, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Expm1virtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Expm1virtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
Expm1(Stream stream)mlx::core::Expm1inlineexplicit
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Expm1virtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Expm1inlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Expm1inlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Expm1virtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Expm1virtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_expm1.html b/docs/build/html/classmlx_1_1core_1_1_expm1.html new file mode 100644 index 000000000..7254a12ad --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_expm1.html @@ -0,0 +1,434 @@ + + + + + + + +MLX: mlx::core::Expm1 Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Expm1 Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Expm1:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Expm1 (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Expm1()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Expm1::Expm1 (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Expm1::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Expm1::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Expm1::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Expm1::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Expm1::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Expm1::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Expm1::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_expm1.png b/docs/build/html/classmlx_1_1core_1_1_expm1.png new file mode 100644 index 000000000..da566929f Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_expm1.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_f_f_t-members.html b/docs/build/html/classmlx_1_1core_1_1_f_f_t-members.html new file mode 100644 index 000000000..1b6c9e55c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_f_f_t-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::FFT Member List
+
+
+ +

This is the complete list of members for mlx::core::FFT, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::FFTvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::FFTvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)mlx::core::FFTinlineexplicit
is_equivalent(const Primitive &other) const overridemlx::core::FFTvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::FFTvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::FFTinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::FFTvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::FFTvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_f_f_t.html b/docs/build/html/classmlx_1_1core_1_1_f_f_t.html new file mode 100644 index 000000000..de279f0ed --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_f_f_t.html @@ -0,0 +1,447 @@ + + + + + + + +MLX: mlx::core::FFT Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::FFT Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::FFT:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 FFT (Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ FFT()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::FFT::FFT (Stream stream,
const std::vector< size_t > & axes,
bool inverse,
bool real )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::FFT::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::FFT::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::FFT::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::FFT::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::FFT::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::FFT::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::FFT::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_f_f_t.png b/docs/build/html/classmlx_1_1core_1_1_f_f_t.png new file mode 100644 index 000000000..aa05d735a Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_f_f_t.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_floor-members.html b/docs/build/html/classmlx_1_1core_1_1_floor-members.html new file mode 100644 index 000000000..bac2e36a5 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_floor-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Floor Member List
+
+
+ +

This is the complete list of members for mlx::core::Floor, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Floorvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Floorvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
Floor(Stream stream)mlx::core::Floorinlineexplicit
is_equivalent(const Primitive &other) const overridemlx::core::Floorinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Floorvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Floorinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Floorinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Floorvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Floorvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_floor.html b/docs/build/html/classmlx_1_1core_1_1_floor.html new file mode 100644 index 000000000..60f7e423d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_floor.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Floor Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Floor Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Floor:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Floor (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Floor()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Floor::Floor (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Floor::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Floor::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Floor::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Floor::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Floor::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Floor::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Floor::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Floor::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_floor.png b/docs/build/html/classmlx_1_1core_1_1_floor.png new file mode 100644 index 000000000..2b602e649 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_floor.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_full-members.html b/docs/build/html/classmlx_1_1core_1_1_full-members.html new file mode 100644 index 000000000..4b4ee2e94 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_full-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Full Member List
+
+
+ +

This is the complete list of members for mlx::core::Full, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Fullvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Fullvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
Full(Stream stream)mlx::core::Fullinlineexplicit
is_equivalent(const Primitive &other) const overridemlx::core::Fullinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Fullvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Fullinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Fullvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Fullvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_full.html b/docs/build/html/classmlx_1_1core_1_1_full.html new file mode 100644 index 000000000..7ad00b10e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_full.html @@ -0,0 +1,433 @@ + + + + + + + +MLX: mlx::core::Full Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Full Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Full:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Full (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Full()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Full::Full (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Full::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Full::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Full::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Full::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Full::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Full::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Full::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_full.png b/docs/build/html/classmlx_1_1core_1_1_full.png new file mode 100644 index 000000000..51e255780 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_full.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_gather-members.html b/docs/build/html/classmlx_1_1core_1_1_gather-members.html new file mode 100644 index 000000000..fdd2df31a --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_gather-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Gather Member List
+
+
+ +

This is the complete list of members for mlx::core::Gather, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Gathervirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Gathervirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
Gather(Stream stream, const std::vector< int > &axes, const std::vector< int > &slice_sizes)mlx::core::Gatherinlineexplicit
is_equivalent(const Primitive &other) const overridemlx::core::Gathervirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Gathervirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Gatherinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Gathervirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Gathervirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_gather.html b/docs/build/html/classmlx_1_1core_1_1_gather.html new file mode 100644 index 000000000..09cb2f1de --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_gather.html @@ -0,0 +1,442 @@ + + + + + + + +MLX: mlx::core::Gather Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Gather Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Gather:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Gather (Stream stream, const std::vector< int > &axes, const std::vector< int > &slice_sizes)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Gather()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::Gather::Gather (Stream stream,
const std::vector< int > & axes,
const std::vector< int > & slice_sizes )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Gather::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Gather::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Gather::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Gather::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Gather::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Gather::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Gather::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_gather.png b/docs/build/html/classmlx_1_1core_1_1_gather.png new file mode 100644 index 000000000..7840ba3e8 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_gather.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_greater-members.html b/docs/build/html/classmlx_1_1core_1_1_greater-members.html new file mode 100644 index 000000000..a88fdac80 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_greater-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Greater Member List
+
+
+ +

This is the complete list of members for mlx::core::Greater, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Greatervirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Greatervirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
Greater(Stream stream)mlx::core::Greaterinlineexplicit
is_equivalent(const Primitive &other) const overridemlx::core::Greaterinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Greatervirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Greaterinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Greaterinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Greatervirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Greatervirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_greater.html b/docs/build/html/classmlx_1_1core_1_1_greater.html new file mode 100644 index 000000000..4e738bd6d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_greater.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Greater Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Greater Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Greater:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Greater (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Greater()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Greater::Greater (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Greater::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Greater::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Greater::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Greater::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Greater::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Greater::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Greater::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Greater::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_greater.png b/docs/build/html/classmlx_1_1core_1_1_greater.png new file mode 100644 index 000000000..ed485df38 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_greater.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_greater_equal-members.html b/docs/build/html/classmlx_1_1core_1_1_greater_equal-members.html new file mode 100644 index 000000000..1fac8ce64 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_greater_equal-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::GreaterEqual Member List
+
+
+ +

This is the complete list of members for mlx::core::GreaterEqual, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::GreaterEqualvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::GreaterEqualvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
GreaterEqual(Stream stream)mlx::core::GreaterEqualinlineexplicit
is_equivalent(const Primitive &other) const overridemlx::core::GreaterEqualinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::GreaterEqualvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::GreaterEqualinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::GreaterEqualinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::GreaterEqualvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::GreaterEqualvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_greater_equal.html b/docs/build/html/classmlx_1_1core_1_1_greater_equal.html new file mode 100644 index 000000000..16aec8e67 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_greater_equal.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::GreaterEqual Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::GreaterEqual Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::GreaterEqual:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 GreaterEqual (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ GreaterEqual()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::GreaterEqual::GreaterEqual (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::GreaterEqual::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::GreaterEqual::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::GreaterEqual::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::GreaterEqual::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::GreaterEqual::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::GreaterEqual::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::GreaterEqual::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::GreaterEqual::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_greater_equal.png b/docs/build/html/classmlx_1_1core_1_1_greater_equal.png new file mode 100644 index 000000000..3b6862e59 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_greater_equal.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_inverse-members.html b/docs/build/html/classmlx_1_1core_1_1_inverse-members.html new file mode 100644 index 000000000..1fb2c948d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_inverse-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Inverse Member List
+
+
+ +

This is the complete list of members for mlx::core::Inverse, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &output) overridemlx::core::Inversevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &output) overridemlx::core::Inversevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
Inverse(Stream stream)mlx::core::Inverseinlineexplicit
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Inverseinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Inversevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_inverse.html b/docs/build/html/classmlx_1_1core_1_1_inverse.html new file mode 100644 index 000000000..ec23a1051 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_inverse.html @@ -0,0 +1,323 @@ + + + + + + + +MLX: mlx::core::Inverse Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Inverse Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Inverse:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Inverse (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &output) override
 
void eval_gpu (const std::vector< array > &inputs, array &output) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Inverse()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Inverse::Inverse (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Inverse::eval_cpu (const std::vector< array > & inputs,
array & output )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Inverse::eval_gpu (const std::vector< array > & inputs,
array & output )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Inverse::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Inverse::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_inverse.png b/docs/build/html/classmlx_1_1core_1_1_inverse.png new file mode 100644 index 000000000..c59ec21c0 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_inverse.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_less-members.html b/docs/build/html/classmlx_1_1core_1_1_less-members.html new file mode 100644 index 000000000..e3f748124 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_less-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Less Member List
+
+
+ +

This is the complete list of members for mlx::core::Less, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Lessvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Lessvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Lessinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Lessvirtual
Less(Stream stream)mlx::core::Lessinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Lessinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Lessinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Lessvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Lessvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_less.html b/docs/build/html/classmlx_1_1core_1_1_less.html new file mode 100644 index 000000000..076e7bfb9 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_less.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Less Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Less Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Less:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Less (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Less()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Less::Less (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Less::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Less::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Less::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Less::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Less::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Less::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Less::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Less::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_less.png b/docs/build/html/classmlx_1_1core_1_1_less.png new file mode 100644 index 000000000..5fde4667d Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_less.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_less_equal-members.html b/docs/build/html/classmlx_1_1core_1_1_less_equal-members.html new file mode 100644 index 000000000..bfa8083f6 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_less_equal-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::LessEqual Member List
+
+
+ +

This is the complete list of members for mlx::core::LessEqual, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::LessEqualvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::LessEqualvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::LessEqualinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::LessEqualvirtual
LessEqual(Stream stream)mlx::core::LessEqualinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::LessEqualinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::LessEqualinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::LessEqualvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::LessEqualvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_less_equal.html b/docs/build/html/classmlx_1_1core_1_1_less_equal.html new file mode 100644 index 000000000..ea0474c76 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_less_equal.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::LessEqual Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::LessEqual Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::LessEqual:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 LessEqual (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ LessEqual()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::LessEqual::LessEqual (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LessEqual::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LessEqual::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::LessEqual::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LessEqual::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::LessEqual::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::LessEqual::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LessEqual::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::LessEqual::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_less_equal.png b/docs/build/html/classmlx_1_1core_1_1_less_equal.png new file mode 100644 index 000000000..861844408 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_less_equal.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_load-members.html b/docs/build/html/classmlx_1_1core_1_1_load-members.html new file mode 100644 index 000000000..9494f3cdd --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_load-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Load Member List
+
+
+ +

This is the complete list of members for mlx::core::Load, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Loadvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Loadvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)mlx::core::Loadinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Loadinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_load.html b/docs/build/html/classmlx_1_1core_1_1_load.html new file mode 100644 index 000000000..f37a2c781 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_load.html @@ -0,0 +1,303 @@ + + + + + + + +MLX: mlx::core::Load Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Load Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Load:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Load (Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Load()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::Load::Load (Stream stream,
std::shared_ptr< io::Reader > reader,
size_t offset,
bool swap_endianness = false )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Load::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Load::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Load::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_load.png b/docs/build/html/classmlx_1_1core_1_1_load.png new file mode 100644 index 000000000..cb43b85d4 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_load.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_log-members.html b/docs/build/html/classmlx_1_1core_1_1_log-members.html new file mode 100644 index 000000000..b66e45b93 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_log-members.html @@ -0,0 +1,119 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Log Member List
+
+
+ +

This is the complete list of members for mlx::core::Log, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Base enum namemlx::core::Log
device()mlx::core::Primitiveinline
e enum valuemlx::core::Log
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Logvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Logvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Loginlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Logvirtual
Log(Stream stream, Base base)mlx::core::Loginlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Loginlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Loginlinevirtual
stream()mlx::core::Primitiveinline
ten enum valuemlx::core::Log
two enum valuemlx::core::Log
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Logvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Logvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_log.html b/docs/build/html/classmlx_1_1core_1_1_log.html new file mode 100644 index 000000000..0728bb815 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_log.html @@ -0,0 +1,496 @@ + + + + + + + +MLX: mlx::core::Log Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Types | +Public Member Functions | +List of all members
+
mlx::core::Log Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Log:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + +

+Public Types

enum  Base { two +, ten +, e + }
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Log (Stream stream, Base base)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Member Enumeration Documentation

+ +

◆ Base

+ +
+
+ + + + +
enum mlx::core::Log::Base
+
+ + + + +
Enumerator
two 
ten 
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ Log()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Log::Log (Stream stream,
Base base )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Log::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Log::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Log::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Log::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Log::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Log::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Log::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Log::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_log.png b/docs/build/html/classmlx_1_1core_1_1_log.png new file mode 100644 index 000000000..cc9ba7d85 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_log.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_log1p-members.html b/docs/build/html/classmlx_1_1core_1_1_log1p-members.html new file mode 100644 index 000000000..e9d9678a2 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_log1p-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Log1p Member List
+
+
+ +

This is the complete list of members for mlx::core::Log1p, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Log1pvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Log1pvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Log1pvirtual
Log1p(Stream stream)mlx::core::Log1pinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Log1pinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Log1pinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Log1pvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Log1pvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_log1p.html b/docs/build/html/classmlx_1_1core_1_1_log1p.html new file mode 100644 index 000000000..0eeaf6c9e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_log1p.html @@ -0,0 +1,434 @@ + + + + + + + +MLX: mlx::core::Log1p Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Log1p Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Log1p:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Log1p (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Log1p()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Log1p::Log1p (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Log1p::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Log1p::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Log1p::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Log1p::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Log1p::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Log1p::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Log1p::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_log1p.png b/docs/build/html/classmlx_1_1core_1_1_log1p.png new file mode 100644 index 000000000..fc2853680 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_log1p.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_log_add_exp-members.html b/docs/build/html/classmlx_1_1core_1_1_log_add_exp-members.html new file mode 100644 index 000000000..3c05b01b8 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_log_add_exp-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::LogAddExp Member List
+
+
+ +

This is the complete list of members for mlx::core::LogAddExp, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::LogAddExpvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::LogAddExpvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::LogAddExpinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::LogAddExpvirtual
LogAddExp(Stream stream)mlx::core::LogAddExpinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::LogAddExpinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::LogAddExpinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::LogAddExpvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::LogAddExpvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_log_add_exp.html b/docs/build/html/classmlx_1_1core_1_1_log_add_exp.html new file mode 100644 index 000000000..1cc6326ec --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_log_add_exp.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::LogAddExp Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::LogAddExp Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::LogAddExp:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 LogAddExp (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ LogAddExp()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::LogAddExp::LogAddExp (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LogAddExp::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LogAddExp::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::LogAddExp::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LogAddExp::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::LogAddExp::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::LogAddExp::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LogAddExp::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::LogAddExp::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_log_add_exp.png b/docs/build/html/classmlx_1_1core_1_1_log_add_exp.png new file mode 100644 index 000000000..28cb8ab01 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_log_add_exp.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_and-members.html b/docs/build/html/classmlx_1_1core_1_1_logical_and-members.html new file mode 100644 index 000000000..b96dfe648 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_logical_and-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::LogicalAnd Member List
+
+
+ +

This is the complete list of members for mlx::core::LogicalAnd, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::LogicalAndvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::LogicalAndvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::LogicalAndinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::LogicalAndvirtual
LogicalAnd(Stream stream)mlx::core::LogicalAndinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::LogicalAndinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::LogicalAndinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::LogicalAndvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::LogicalAndvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_and.html b/docs/build/html/classmlx_1_1core_1_1_logical_and.html new file mode 100644 index 000000000..1ea446dab --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_logical_and.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::LogicalAnd Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::LogicalAnd Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::LogicalAnd:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 LogicalAnd (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ LogicalAnd()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::LogicalAnd::LogicalAnd (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LogicalAnd::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LogicalAnd::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::LogicalAnd::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LogicalAnd::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::LogicalAnd::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::LogicalAnd::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LogicalAnd::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::LogicalAnd::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_and.png b/docs/build/html/classmlx_1_1core_1_1_logical_and.png new file mode 100644 index 000000000..65d5f0bb4 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_logical_and.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_not-members.html b/docs/build/html/classmlx_1_1core_1_1_logical_not-members.html new file mode 100644 index 000000000..1b749bde7 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_logical_not-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::LogicalNot Member List
+
+
+ +

This is the complete list of members for mlx::core::LogicalNot, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::LogicalNotvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::LogicalNotvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::LogicalNotinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::LogicalNotvirtual
LogicalNot(Stream stream)mlx::core::LogicalNotinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::LogicalNotinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::LogicalNotinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::LogicalNotvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::LogicalNotvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_not.html b/docs/build/html/classmlx_1_1core_1_1_logical_not.html new file mode 100644 index 000000000..a4650332d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_logical_not.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::LogicalNot Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::LogicalNot Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::LogicalNot:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 LogicalNot (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ LogicalNot()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::LogicalNot::LogicalNot (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LogicalNot::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LogicalNot::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::LogicalNot::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LogicalNot::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::LogicalNot::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::LogicalNot::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LogicalNot::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::LogicalNot::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_not.png b/docs/build/html/classmlx_1_1core_1_1_logical_not.png new file mode 100644 index 000000000..51f652755 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_logical_not.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_or-members.html b/docs/build/html/classmlx_1_1core_1_1_logical_or-members.html new file mode 100644 index 000000000..dc3a25f10 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_logical_or-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::LogicalOr Member List
+
+
+ +

This is the complete list of members for mlx::core::LogicalOr, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::LogicalOrvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::LogicalOrvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::LogicalOrinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::LogicalOrvirtual
LogicalOr(Stream stream)mlx::core::LogicalOrinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::LogicalOrinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::LogicalOrinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::LogicalOrvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::LogicalOrvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_or.html b/docs/build/html/classmlx_1_1core_1_1_logical_or.html new file mode 100644 index 000000000..205a8832d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_logical_or.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::LogicalOr Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::LogicalOr Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::LogicalOr:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 LogicalOr (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ LogicalOr()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::LogicalOr::LogicalOr (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LogicalOr::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LogicalOr::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::LogicalOr::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LogicalOr::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::LogicalOr::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::LogicalOr::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LogicalOr::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::LogicalOr::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_or.png b/docs/build/html/classmlx_1_1core_1_1_logical_or.png new file mode 100644 index 000000000..79dcbcb6f Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_logical_or.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_matmul-members.html b/docs/build/html/classmlx_1_1core_1_1_matmul-members.html new file mode 100644 index 000000000..fc9299e6b --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_matmul-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Matmul Member List
+
+
+ +

This is the complete list of members for mlx::core::Matmul, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Matmulvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Matmulvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Matmulinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
Matmul(Stream stream)mlx::core::Matmulinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Matmulinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Matmulvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Matmulvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_matmul.html b/docs/build/html/classmlx_1_1core_1_1_matmul.html new file mode 100644 index 000000000..490ff158a --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_matmul.html @@ -0,0 +1,395 @@ + + + + + + + +MLX: mlx::core::Matmul Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Matmul Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Matmul:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Matmul (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Matmul()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Matmul::Matmul (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Matmul::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Matmul::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Matmul::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Matmul::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Matmul::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Matmul::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_matmul.png b/docs/build/html/classmlx_1_1core_1_1_matmul.png new file mode 100644 index 000000000..eddf94c66 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_matmul.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_maximum-members.html b/docs/build/html/classmlx_1_1core_1_1_maximum-members.html new file mode 100644 index 000000000..d185625c2 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_maximum-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Maximum Member List
+
+
+ +

This is the complete list of members for mlx::core::Maximum, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Maximumvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Maximumvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Maximuminlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Maximumvirtual
Maximum(Stream stream)mlx::core::Maximuminlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Maximuminlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Maximuminlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Maximumvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Maximumvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_maximum.html b/docs/build/html/classmlx_1_1core_1_1_maximum.html new file mode 100644 index 000000000..a2374f87c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_maximum.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Maximum Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Maximum Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Maximum:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Maximum (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Maximum()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Maximum::Maximum (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Maximum::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Maximum::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Maximum::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Maximum::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Maximum::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Maximum::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Maximum::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Maximum::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_maximum.png b/docs/build/html/classmlx_1_1core_1_1_maximum.png new file mode 100644 index 000000000..d888f7091 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_maximum.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_minimum-members.html b/docs/build/html/classmlx_1_1core_1_1_minimum-members.html new file mode 100644 index 000000000..34236c406 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_minimum-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Minimum Member List
+
+
+ +

This is the complete list of members for mlx::core::Minimum, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Minimumvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Minimumvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Minimuminlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Minimumvirtual
Minimum(Stream stream)mlx::core::Minimuminlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Minimuminlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Minimuminlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Minimumvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Minimumvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_minimum.html b/docs/build/html/classmlx_1_1core_1_1_minimum.html new file mode 100644 index 000000000..3407dff49 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_minimum.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Minimum Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Minimum Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Minimum:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Minimum (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Minimum()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Minimum::Minimum (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Minimum::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Minimum::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Minimum::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Minimum::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Minimum::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Minimum::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Minimum::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Minimum::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_minimum.png b/docs/build/html/classmlx_1_1core_1_1_minimum.png new file mode 100644 index 000000000..46ca64b2d Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_minimum.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_multiply-members.html b/docs/build/html/classmlx_1_1core_1_1_multiply-members.html new file mode 100644 index 000000000..9754017cd --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_multiply-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Multiply Member List
+
+
+ +

This is the complete list of members for mlx::core::Multiply, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Multiplyvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Multiplyvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Multiplyinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Multiplyvirtual
Multiply(Stream stream)mlx::core::Multiplyinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Multiplyinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Multiplyinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Multiplyvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Multiplyvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_multiply.html b/docs/build/html/classmlx_1_1core_1_1_multiply.html new file mode 100644 index 000000000..85362801c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_multiply.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Multiply Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Multiply Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Multiply:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Multiply (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Multiply()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Multiply::Multiply (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Multiply::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Multiply::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Multiply::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Multiply::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Multiply::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Multiply::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Multiply::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Multiply::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_multiply.png b/docs/build/html/classmlx_1_1core_1_1_multiply.png new file mode 100644 index 000000000..518065cfb Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_multiply.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_negative-members.html b/docs/build/html/classmlx_1_1core_1_1_negative-members.html new file mode 100644 index 000000000..097013f9c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_negative-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Negative Member List
+
+
+ +

This is the complete list of members for mlx::core::Negative, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Negativevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Negativevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Negativeinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Negativevirtual
Negative(Stream stream)mlx::core::Negativeinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Negativeinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Negativeinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Negativevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Negativevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_negative.html b/docs/build/html/classmlx_1_1core_1_1_negative.html new file mode 100644 index 000000000..8d4d4cdb3 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_negative.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Negative Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Negative Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Negative:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Negative (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Negative()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Negative::Negative (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Negative::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Negative::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Negative::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Negative::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Negative::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Negative::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Negative::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Negative::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_negative.png b/docs/build/html/classmlx_1_1core_1_1_negative.png new file mode 100644 index 000000000..f7a0d33a0 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_negative.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_not_equal-members.html b/docs/build/html/classmlx_1_1core_1_1_not_equal-members.html new file mode 100644 index 000000000..93f8795dc --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_not_equal-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::NotEqual Member List
+
+
+ +

This is the complete list of members for mlx::core::NotEqual, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::NotEqualvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::NotEqualvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::NotEqualinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::NotEqualvirtual
NotEqual(Stream stream)mlx::core::NotEqualinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::NotEqualinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::NotEqualinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::NotEqualvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::NotEqualvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_not_equal.html b/docs/build/html/classmlx_1_1core_1_1_not_equal.html new file mode 100644 index 000000000..c394128dc --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_not_equal.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::NotEqual Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::NotEqual Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::NotEqual:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 NotEqual (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ NotEqual()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::NotEqual::NotEqual (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::NotEqual::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::NotEqual::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::NotEqual::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::NotEqual::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::NotEqual::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::NotEqual::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::NotEqual::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::NotEqual::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_not_equal.png b/docs/build/html/classmlx_1_1core_1_1_not_equal.png new file mode 100644 index 000000000..2067d9710 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_not_equal.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_number_of_elements-members.html b/docs/build/html/classmlx_1_1core_1_1_number_of_elements-members.html new file mode 100644 index 000000000..4b786a16f --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_number_of_elements-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::NumberOfElements Member List
+
+
+ +

This is the complete list of members for mlx::core::NumberOfElements, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::NumberOfElementsvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::NumberOfElementsvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::NumberOfElementsvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)mlx::core::NumberOfElementsinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::NumberOfElementsinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::NumberOfElementsinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::NumberOfElementsvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_number_of_elements.html b/docs/build/html/classmlx_1_1core_1_1_number_of_elements.html new file mode 100644 index 000000000..018531ce9 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_number_of_elements.html @@ -0,0 +1,396 @@ + + + + + + + +MLX: mlx::core::NumberOfElements Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::NumberOfElements Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::NumberOfElements:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 NumberOfElements (Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ NumberOfElements()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::NumberOfElements::NumberOfElements (Stream stream,
std::vector< int > axes,
bool inverted,
Dtype dtype )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::NumberOfElements::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::NumberOfElements::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::NumberOfElements::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::NumberOfElements::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::NumberOfElements::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::NumberOfElements::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_number_of_elements.png b/docs/build/html/classmlx_1_1core_1_1_number_of_elements.png new file mode 100644 index 000000000..2364d3242 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_number_of_elements.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_pad-members.html b/docs/build/html/classmlx_1_1core_1_1_pad-members.html new file mode 100644 index 000000000..083a1a03c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_pad-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Pad Member List
+
+
+ +

This is the complete list of members for mlx::core::Pad, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Padvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Padvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Padvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Padvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Pad(Stream stream, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size)mlx::core::Padinlineexplicit
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Padinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Padvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Padvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_pad.html b/docs/build/html/classmlx_1_1core_1_1_pad.html new file mode 100644 index 000000000..97fa86571 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_pad.html @@ -0,0 +1,447 @@ + + + + + + + +MLX: mlx::core::Pad Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Pad Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Pad:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Pad (Stream stream, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Pad()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::Pad::Pad (Stream stream,
const std::vector< int > & axes,
const std::vector< int > & low_pad_size,
const std::vector< int > & high_pad_size )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Pad::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Pad::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Pad::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Pad::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Pad::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Pad::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Pad::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_pad.png b/docs/build/html/classmlx_1_1core_1_1_pad.png new file mode 100644 index 000000000..13b64cbff Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_pad.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_partition-members.html b/docs/build/html/classmlx_1_1core_1_1_partition-members.html new file mode 100644 index 000000000..101ee7dc2 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_partition-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Partition Member List
+
+
+ +

This is the complete list of members for mlx::core::Partition, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Partitionvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Partitionvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Partitionvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Partitionvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Partitioninlinevirtual
Partition(Stream stream, int kth, int axis)mlx::core::Partitioninlineexplicit
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Partitioninlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Partitionvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Partitionvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_partition.html b/docs/build/html/classmlx_1_1core_1_1_partition.html new file mode 100644 index 000000000..018687d23 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_partition.html @@ -0,0 +1,472 @@ + + + + + + + +MLX: mlx::core::Partition Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Partition Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Partition:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Partition (Stream stream, int kth, int axis)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Partition()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::Partition::Partition (Stream stream,
int kth,
int axis )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Partition::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Partition::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Partition::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Partition::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Partition::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Partition::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Partition::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Partition::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_partition.png b/docs/build/html/classmlx_1_1core_1_1_partition.png new file mode 100644 index 000000000..4259b6ba0 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_partition.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_power-members.html b/docs/build/html/classmlx_1_1core_1_1_power-members.html new file mode 100644 index 000000000..6bb8d6502 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_power-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Power Member List
+
+
+ +

This is the complete list of members for mlx::core::Power, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Powervirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Powervirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Powerinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Powervirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Powerinlinevirtual
Power(Stream stream)mlx::core::Powerinlineexplicit
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Powerinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Powervirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Powervirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_power.html b/docs/build/html/classmlx_1_1core_1_1_power.html new file mode 100644 index 000000000..7f1bd4ba4 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_power.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Power Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Power Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Power:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Power (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Power()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Power::Power (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Power::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Power::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Power::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Power::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Power::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Power::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Power::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Power::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_power.png b/docs/build/html/classmlx_1_1core_1_1_power.png new file mode 100644 index 000000000..7ae727a29 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_power.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_primitive-members.html b/docs/build/html/classmlx_1_1core_1_1_primitive-members.html new file mode 100644 index 000000000..0e781c800 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_primitive-members.html @@ -0,0 +1,106 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Primitive Member List
+
+
+ +

This is the complete list of members for mlx::core::Primitive, including all inherited members.

+ + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0mlx::core::Primitivepure virtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0mlx::core::Primitivepure virtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_primitive.html b/docs/build/html/classmlx_1_1core_1_1_primitive.html new file mode 100644 index 000000000..12c2e073c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_primitive.html @@ -0,0 +1,631 @@ + + + + + + + +MLX: mlx::core::Primitive Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Primitive Class Referenceabstract
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Primitive:
+
+
+ + +mlx::core::Compiled +mlx::core::CustomVJP +mlx::core::Depends +mlx::core::DivMod +mlx::core::QRF +mlx::core::SVD +mlx::core::Split +mlx::core::UnaryPrimitive +mlx::core::fast::Custom + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs)=0
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
virtual void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs)=0
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Primitive() [1/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Primitive::Primitive (Stream stream)
+
+inlineexplicit
+
+ +
+
+ +

◆ ~Primitive()

+ +
+
+ + + + + +
+ + + + + + + +
virtual mlx::core::Primitive::~Primitive ()
+
+virtualdefault
+
+ +
+
+ +

◆ Primitive() [2/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Primitive::Primitive (const Primitive & other)
+
+delete
+
+ +
+
+ +

◆ Primitive() [3/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Primitive::Primitive (Primitive && other)
+
+delete
+
+ +
+
+

Member Function Documentation

+ +

◆ device()

+ +
+
+ + + + + +
+ + + + + + + +
const Device & mlx::core::Primitive::device ()
+
+inline
+
+ +

The device the primitive will run on.

+ +
+
+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual void mlx::core::Primitive::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+pure virtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implemented in mlx::core::fast::RMSNorm, mlx::core::fast::RMSNormVJP, mlx::core::fast::LayerNorm, mlx::core::fast::LayerNormVJP, mlx::core::fast::RoPE, mlx::core::fast::ScaledDotProductAttention, mlx::core::UnaryPrimitive, mlx::core::Compiled, mlx::core::CustomVJP, mlx::core::Depends, mlx::core::DivMod, mlx::core::Split, mlx::core::QRF, and mlx::core::SVD.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual void mlx::core::Primitive::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+pure virtual
+
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
virtual bool mlx::core::Primitive::is_equivalent (const Primitive & other) const
+
+inlinevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented in mlx::core::fast::ScaledDotProductAttention, mlx::core::Abs, mlx::core::Add, mlx::core::AddMM, mlx::core::Arange, mlx::core::ArcCos, mlx::core::ArcCosh, mlx::core::ArcSin, mlx::core::ArcSinh, mlx::core::ArcTan, mlx::core::ArcTan2, mlx::core::ArcTanh, mlx::core::ArgPartition, mlx::core::ArgReduce, mlx::core::ArgSort, mlx::core::AsType, mlx::core::AsStrided, mlx::core::BitwiseBinary, mlx::core::BlockMaskedMM, mlx::core::BlockSparseMM, mlx::core::Broadcast, mlx::core::Ceil, mlx::core::Compiled, mlx::core::Concatenate, mlx::core::Conjugate, mlx::core::Convolution, mlx::core::Copy, mlx::core::Cos, mlx::core::Cosh, mlx::core::Divide, mlx::core::DivMod, mlx::core::Select, mlx::core::Remainder, mlx::core::Equal, mlx::core::Erf, mlx::core::ErfInv, mlx::core::Exp, mlx::core::FFT, mlx::core::Floor, mlx::core::Full, mlx::core::Gather, mlx::core::Greater, mlx::core::GreaterEqual, mlx::core::Less, mlx::core::LessEqual, mlx::core::Log, mlx::core::LogicalNot, mlx::core::LogicalAnd, mlx::core::LogicalOr, mlx::core::LogAddExp, mlx::core::Matmul, mlx::core::Maximum, mlx::core::Minimum, mlx::core::Multiply, mlx::core::Negative, mlx::core::NotEqual, mlx::core::NumberOfElements, mlx::core::Pad, mlx::core::Partition, mlx::core::Power, mlx::core::QuantizedMatmul, mlx::core::RandomBits, mlx::core::Reshape, mlx::core::Reduce, mlx::core::Round, mlx::core::Scan, mlx::core::Scatter, mlx::core::Sigmoid, mlx::core::Sign, mlx::core::Sin, mlx::core::Sinh, mlx::core::Slice, mlx::core::SliceUpdate, mlx::core::Softmax, mlx::core::Sort, mlx::core::Split, mlx::core::Square, mlx::core::Sqrt, mlx::core::StopGradient, mlx::core::Subtract, mlx::core::Tan, mlx::core::Tanh, mlx::core::Uniform, and mlx::core::Transpose.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
virtual std::vector< array > mlx::core::Primitive::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+virtual
+
+
+ +

◆ operator=() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
Primitive & mlx::core::Primitive::operator= (const Primitive & other)
+
+delete
+
+ +
+
+ +

◆ operator=() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
Primitive & mlx::core::Primitive::operator= (Primitive && other)
+
+delete
+
+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
virtual std::vector< std::vector< int > > mlx::core::Primitive::output_shapes (const std::vector< array > & inputs)
+
+virtual
+
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
virtual void mlx::core::Primitive::print (std::ostream & os)
+
+pure virtual
+
+ +

Print the primitive.

+ +

Implemented in mlx::core::Abs, mlx::core::Add, mlx::core::AddMM, mlx::core::Arange, mlx::core::ArcCos, mlx::core::ArcCosh, mlx::core::ArcSin, mlx::core::ArcSinh, mlx::core::ArcTan, mlx::core::ArcTan2, mlx::core::ArcTanh, mlx::core::ArgPartition, mlx::core::ArgReduce, mlx::core::ArgSort, mlx::core::AsType, mlx::core::AsStrided, mlx::core::BitwiseBinary, mlx::core::BlockMaskedMM, mlx::core::BlockSparseMM, mlx::core::Broadcast, mlx::core::Ceil, mlx::core::Compiled, mlx::core::Concatenate, mlx::core::Conjugate, mlx::core::Convolution, mlx::core::Copy, mlx::core::Cos, mlx::core::Cosh, mlx::core::CustomVJP, mlx::core::Depends, mlx::core::Divide, mlx::core::DivMod, mlx::core::Select, mlx::core::Remainder, mlx::core::Equal, mlx::core::Erf, mlx::core::ErfInv, mlx::core::Exp, mlx::core::Expm1, mlx::core::FFT, mlx::core::Floor, mlx::core::Full, mlx::core::Gather, mlx::core::Greater, mlx::core::GreaterEqual, mlx::core::Less, mlx::core::LessEqual, mlx::core::Load, mlx::core::Log, mlx::core::Log1p, mlx::core::LogicalNot, mlx::core::LogicalAnd, mlx::core::LogicalOr, mlx::core::LogAddExp, mlx::core::Matmul, mlx::core::Maximum, mlx::core::Minimum, mlx::core::Multiply, mlx::core::Negative, mlx::core::NotEqual, mlx::core::NumberOfElements, mlx::core::Pad, mlx::core::Partition, mlx::core::Power, mlx::core::QuantizedMatmul, mlx::core::RandomBits, mlx::core::Reshape, mlx::core::Reduce, mlx::core::Round, mlx::core::Scan, mlx::core::Scatter, mlx::core::Sigmoid, mlx::core::Sign, mlx::core::Sin, mlx::core::Sinh, mlx::core::Slice, mlx::core::SliceUpdate, mlx::core::Softmax, mlx::core::Sort, mlx::core::Split, mlx::core::Square, mlx::core::Sqrt, mlx::core::StopGradient, mlx::core::Subtract, mlx::core::Tan, mlx::core::Tanh, mlx::core::Uniform, mlx::core::Transpose, mlx::core::QRF, mlx::core::SVD, and mlx::core::Inverse.

+ +
+
+ +

◆ stream()

+ +
+
+ + + + + +
+ + + + + + + +
const Stream & mlx::core::Primitive::stream ()
+
+inline
+
+ +

The stream the primitive will run on.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
virtual std::vector< array > mlx::core::Primitive::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+virtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented in mlx::core::CustomVJP, mlx::core::Depends, mlx::core::fast::Custom, mlx::core::fast::RMSNorm, mlx::core::fast::LayerNorm, mlx::core::fast::RoPE, mlx::core::Abs, mlx::core::Add, mlx::core::AddMM, mlx::core::ArcCos, mlx::core::ArcCosh, mlx::core::ArcSin, mlx::core::ArcSinh, mlx::core::ArcTan, mlx::core::ArcTan2, mlx::core::ArcTanh, mlx::core::AsType, mlx::core::AsStrided, mlx::core::BlockMaskedMM, mlx::core::BlockSparseMM, mlx::core::Broadcast, mlx::core::Ceil, mlx::core::Compiled, mlx::core::Concatenate, mlx::core::Convolution, mlx::core::Copy, mlx::core::Cos, mlx::core::Cosh, mlx::core::Divide, mlx::core::DivMod, mlx::core::Select, mlx::core::Remainder, mlx::core::Equal, mlx::core::Erf, mlx::core::ErfInv, mlx::core::Exp, mlx::core::Expm1, mlx::core::FFT, mlx::core::Floor, mlx::core::Full, mlx::core::Gather, mlx::core::Greater, mlx::core::GreaterEqual, mlx::core::Less, mlx::core::LessEqual, mlx::core::Log, mlx::core::Log1p, mlx::core::LogicalNot, mlx::core::LogicalAnd, mlx::core::LogicalOr, mlx::core::LogAddExp, mlx::core::Matmul, mlx::core::Maximum, mlx::core::Minimum, mlx::core::Multiply, mlx::core::Negative, mlx::core::NotEqual, mlx::core::Pad, mlx::core::Partition, mlx::core::Power, mlx::core::QuantizedMatmul, mlx::core::Reshape, mlx::core::Reduce, mlx::core::Round, mlx::core::Scan, mlx::core::Scatter, mlx::core::Sigmoid, mlx::core::Sign, mlx::core::Sin, mlx::core::Sinh, mlx::core::Slice, mlx::core::SliceUpdate, mlx::core::Softmax, mlx::core::Sort, mlx::core::Split, mlx::core::Square, mlx::core::Sqrt, mlx::core::Subtract, mlx::core::Tan, mlx::core::Tanh, and mlx::core::Transpose.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Primitive::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+virtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented in mlx::core::fast::Custom, mlx::core::Abs, mlx::core::Add, mlx::core::AddMM, mlx::core::ArcCos, mlx::core::ArcCosh, mlx::core::ArcSin, mlx::core::ArcSinh, mlx::core::ArcTan, mlx::core::ArcTan2, mlx::core::ArcTanh, mlx::core::ArgPartition, mlx::core::ArgReduce, mlx::core::ArgSort, mlx::core::AsType, mlx::core::BitwiseBinary, mlx::core::Broadcast, mlx::core::Ceil, mlx::core::Compiled, mlx::core::Concatenate, mlx::core::Conjugate, mlx::core::Copy, mlx::core::Cos, mlx::core::Cosh, mlx::core::Divide, mlx::core::DivMod, mlx::core::Select, mlx::core::Remainder, mlx::core::Equal, mlx::core::Erf, mlx::core::ErfInv, mlx::core::Exp, mlx::core::Expm1, mlx::core::FFT, mlx::core::Floor, mlx::core::Full, mlx::core::Gather, mlx::core::Greater, mlx::core::GreaterEqual, mlx::core::Less, mlx::core::LessEqual, mlx::core::Log, mlx::core::Log1p, mlx::core::LogicalNot, mlx::core::LogicalAnd, mlx::core::LogicalOr, mlx::core::LogAddExp, mlx::core::Matmul, mlx::core::Maximum, mlx::core::Minimum, mlx::core::Multiply, mlx::core::Negative, mlx::core::NotEqual, mlx::core::NumberOfElements, mlx::core::Pad, mlx::core::Partition, mlx::core::Power, mlx::core::QuantizedMatmul, mlx::core::RandomBits, mlx::core::Reshape, mlx::core::Reduce, mlx::core::Round, mlx::core::Scan, mlx::core::Sigmoid, mlx::core::Sign, mlx::core::Sin, mlx::core::Sinh, mlx::core::Slice, mlx::core::SliceUpdate, mlx::core::Softmax, mlx::core::Sort, mlx::core::Split, mlx::core::Square, mlx::core::Sqrt, mlx::core::StopGradient, mlx::core::Subtract, mlx::core::Tan, mlx::core::Tanh, mlx::core::Uniform, mlx::core::Transpose, mlx::core::SVD, and mlx::core::Inverse.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_primitive.png b/docs/build/html/classmlx_1_1core_1_1_primitive.png new file mode 100644 index 000000000..515b36a66 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_primitive.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_q_r_f-members.html b/docs/build/html/classmlx_1_1core_1_1_q_r_f-members.html new file mode 100644 index 000000000..b7b810ed1 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_q_r_f-members.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::QRF Member List
+
+
+ +

This is the complete list of members for mlx::core::QRF, including all inherited members.

+ + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::QRFvirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::QRFvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::QRFinlinevirtual
QRF(Stream stream)mlx::core::QRFinlineexplicit
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_q_r_f.html b/docs/build/html/classmlx_1_1core_1_1_q_r_f.html new file mode 100644 index 000000000..f65f45b1e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_q_r_f.html @@ -0,0 +1,273 @@ + + + + + + + +MLX: mlx::core::QRF Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::QRF Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::QRF:
+
+
+ + +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 QRF (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ QRF()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::QRF::QRF (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::QRF::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::QRF::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::QRF::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_q_r_f.png b/docs/build/html/classmlx_1_1core_1_1_q_r_f.png new file mode 100644 index 000000000..29056e86a Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_q_r_f.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_quantized_matmul-members.html b/docs/build/html/classmlx_1_1core_1_1_quantized_matmul-members.html new file mode 100644 index 000000000..0ea18fc30 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_quantized_matmul-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::QuantizedMatmul Member List
+
+
+ +

This is the complete list of members for mlx::core::QuantizedMatmul, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::QuantizedMatmulvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::QuantizedMatmulvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::QuantizedMatmulvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::QuantizedMatmulvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::QuantizedMatmulinlinevirtual
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)mlx::core::QuantizedMatmulinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::QuantizedMatmulvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::QuantizedMatmulvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_quantized_matmul.html b/docs/build/html/classmlx_1_1core_1_1_quantized_matmul.html new file mode 100644 index 000000000..dcd295798 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_quantized_matmul.html @@ -0,0 +1,447 @@ + + + + + + + +MLX: mlx::core::QuantizedMatmul Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::QuantizedMatmul Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::QuantizedMatmul:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 QuantizedMatmul (Stream stream, int group_size, int bits, bool transpose)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ QuantizedMatmul()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::QuantizedMatmul::QuantizedMatmul (Stream stream,
int group_size,
int bits,
bool transpose )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::QuantizedMatmul::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::QuantizedMatmul::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::QuantizedMatmul::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::QuantizedMatmul::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::QuantizedMatmul::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::QuantizedMatmul::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::QuantizedMatmul::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_quantized_matmul.png b/docs/build/html/classmlx_1_1core_1_1_quantized_matmul.png new file mode 100644 index 000000000..6b7d0c346 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_quantized_matmul.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_random_bits-members.html b/docs/build/html/classmlx_1_1core_1_1_random_bits-members.html new file mode 100644 index 000000000..fbd98bb2c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_random_bits-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::RandomBits Member List
+
+
+ +

This is the complete list of members for mlx::core::RandomBits, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::RandomBitsvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::RandomBitsvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::RandomBitsvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::RandomBitsinlinevirtual
RandomBits(Stream stream, const std::vector< int > &shape, int width)mlx::core::RandomBitsinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::RandomBitsvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_random_bits.html b/docs/build/html/classmlx_1_1core_1_1_random_bits.html new file mode 100644 index 000000000..08a208df6 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_random_bits.html @@ -0,0 +1,361 @@ + + + + + + + +MLX: mlx::core::RandomBits Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::RandomBits Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::RandomBits:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 RandomBits (Stream stream, const std::vector< int > &shape, int width)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ RandomBits()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::RandomBits::RandomBits (Stream stream,
const std::vector< int > & shape,
int width )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::RandomBits::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::RandomBits::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::RandomBits::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::RandomBits::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::RandomBits::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_random_bits.png b/docs/build/html/classmlx_1_1core_1_1_random_bits.png new file mode 100644 index 000000000..59b478af1 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_random_bits.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_reduce-members.html b/docs/build/html/classmlx_1_1core_1_1_reduce-members.html new file mode 100644 index 000000000..0d81bee6f --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_reduce-members.html @@ -0,0 +1,122 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Reduce Member List
+
+
+ +

This is the complete list of members for mlx::core::Reduce, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
And enum valuemlx::core::Reduce
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Reducevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Reducevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Reducevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
Max enum valuemlx::core::Reduce
Min enum valuemlx::core::Reduce
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
Or enum valuemlx::core::Reduce
output_shapes(const std::vector< array > &inputs) overridemlx::core::Reducevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Reduceinlinevirtual
Prod enum valuemlx::core::Reduce
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)mlx::core::Reduceinlineexplicit
ReduceType enum namemlx::core::Reduce
stream()mlx::core::Primitiveinline
Sum enum valuemlx::core::Reduce
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Reducevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Reducevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_reduce.html b/docs/build/html/classmlx_1_1core_1_1_reduce.html new file mode 100644 index 000000000..40838cb85 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_reduce.html @@ -0,0 +1,472 @@ + + + + + + + +MLX: mlx::core::Reduce Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Types | +Public Member Functions | +List of all members
+
mlx::core::Reduce Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Reduce:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + +

+Public Types

enum  ReduceType {
+  And +, Or +, Sum +, Prod +,
+  Min +, Max +
+ }
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Reduce (Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Member Enumeration Documentation

+ +

◆ ReduceType

+ +
+
+ + + + + + + +
Enumerator
And 
Or 
Sum 
Prod 
Min 
Max 
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ Reduce()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::Reduce::Reduce (Stream stream,
ReduceType reduce_type,
const std::vector< int > & axes )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Reduce::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Reduce::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Reduce::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Reduce::output_shapes (const std::vector< array > & inputs)
+
+overridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Reduce::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Reduce::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Reduce::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_reduce.png b/docs/build/html/classmlx_1_1core_1_1_reduce.png new file mode 100644 index 000000000..3c46700db Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_reduce.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_remainder-members.html b/docs/build/html/classmlx_1_1core_1_1_remainder-members.html new file mode 100644 index 000000000..e2fefec65 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_remainder-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Remainder Member List
+
+
+ +

This is the complete list of members for mlx::core::Remainder, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Remaindervirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Remaindervirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Remainderinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Remaindervirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Remainderinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Remainderinlinevirtual
Remainder(Stream stream)mlx::core::Remainderinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Remaindervirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Remaindervirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_remainder.html b/docs/build/html/classmlx_1_1core_1_1_remainder.html new file mode 100644 index 000000000..f0fc852ce --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_remainder.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Remainder Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Remainder Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Remainder:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Remainder (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Remainder()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Remainder::Remainder (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Remainder::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Remainder::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Remainder::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Remainder::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Remainder::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Remainder::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Remainder::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Remainder::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_remainder.png b/docs/build/html/classmlx_1_1core_1_1_remainder.png new file mode 100644 index 000000000..898cd6373 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_remainder.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_reshape-members.html b/docs/build/html/classmlx_1_1core_1_1_reshape-members.html new file mode 100644 index 000000000..77ec832ed --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_reshape-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Reshape Member List
+
+
+ +

This is the complete list of members for mlx::core::Reshape, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Reshapevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Reshapevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Reshapevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Reshapevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Reshapeinlinevirtual
Reshape(Stream stream, const std::vector< int > &shape)mlx::core::Reshapeinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Reshapevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Reshapevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_reshape.html b/docs/build/html/classmlx_1_1core_1_1_reshape.html new file mode 100644 index 000000000..c495d9df9 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_reshape.html @@ -0,0 +1,437 @@ + + + + + + + +MLX: mlx::core::Reshape Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Reshape Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Reshape:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Reshape (Stream stream, const std::vector< int > &shape)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Reshape()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Reshape::Reshape (Stream stream,
const std::vector< int > & shape )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Reshape::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Reshape::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Reshape::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Reshape::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Reshape::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Reshape::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Reshape::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_reshape.png b/docs/build/html/classmlx_1_1core_1_1_reshape.png new file mode 100644 index 000000000..1c30abb02 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_reshape.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_round-members.html b/docs/build/html/classmlx_1_1core_1_1_round-members.html new file mode 100644 index 000000000..0f74a51e6 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_round-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Round Member List
+
+
+ +

This is the complete list of members for mlx::core::Round, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Roundvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Roundvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Roundinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Roundvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Roundinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Roundinlinevirtual
Round(Stream stream)mlx::core::Roundinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Roundvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Roundvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_round.html b/docs/build/html/classmlx_1_1core_1_1_round.html new file mode 100644 index 000000000..61f17205c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_round.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Round Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Round Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Round:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Round (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Round()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Round::Round (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Round::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Round::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Round::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Round::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Round::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Round::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Round::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Round::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_round.png b/docs/build/html/classmlx_1_1core_1_1_round.png new file mode 100644 index 000000000..b24499cb7 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_round.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_s_v_d-members.html b/docs/build/html/classmlx_1_1core_1_1_s_v_d-members.html new file mode 100644 index 000000000..034664345 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_s_v_d-members.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::SVD Member List
+
+
+ +

This is the complete list of members for mlx::core::SVD, including all inherited members.

+ + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::SVDvirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::SVDvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::SVDinlinevirtual
stream()mlx::core::Primitiveinline
SVD(Stream stream)mlx::core::SVDinlineexplicit
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::SVDvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_s_v_d.html b/docs/build/html/classmlx_1_1core_1_1_s_v_d.html new file mode 100644 index 000000000..86fe0bf9d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_s_v_d.html @@ -0,0 +1,307 @@ + + + + + + + +MLX: mlx::core::SVD Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::SVD Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::SVD:
+
+
+ + +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 SVD (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ SVD()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::SVD::SVD (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::SVD::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::SVD::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::SVD::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::SVD::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_s_v_d.png b/docs/build/html/classmlx_1_1core_1_1_s_v_d.png new file mode 100644 index 000000000..428bbfa87 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_s_v_d.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_scan-members.html b/docs/build/html/classmlx_1_1core_1_1_scan-members.html new file mode 100644 index 000000000..55a51b51c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_scan-members.html @@ -0,0 +1,120 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Scan Member List
+
+
+ +

This is the complete list of members for mlx::core::Scan, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Scanvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Scanvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Scanvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Scanvirtual
Max enum valuemlx::core::Scan
Min enum valuemlx::core::Scan
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Scaninlinevirtual
Prod enum valuemlx::core::Scan
ReduceType enum namemlx::core::Scan
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)mlx::core::Scaninlineexplicit
stream()mlx::core::Primitiveinline
Sum enum valuemlx::core::Scan
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Scanvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Scanvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_scan.html b/docs/build/html/classmlx_1_1core_1_1_scan.html new file mode 100644 index 000000000..75ead8c56 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_scan.html @@ -0,0 +1,483 @@ + + + + + + + +MLX: mlx::core::Scan Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Types | +Public Member Functions | +List of all members
+
mlx::core::Scan Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Scan:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + +

+Public Types

enum  ReduceType { Max +, Min +, Sum +, Prod + }
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Scan (Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Member Enumeration Documentation

+ +

◆ ReduceType

+ +
+
+ + + + +
enum mlx::core::Scan::ReduceType
+
+ + + + + +
Enumerator
Max 
Min 
Sum 
Prod 
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ Scan()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
mlx::core::Scan::Scan (Stream stream,
ReduceType reduce_type,
int axis,
bool reverse,
bool inclusive )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Scan::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Scan::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Scan::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Scan::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Scan::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Scan::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Scan::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_scan.png b/docs/build/html/classmlx_1_1core_1_1_scan.png new file mode 100644 index 000000000..6926bd27e Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_scan.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_scatter-members.html b/docs/build/html/classmlx_1_1core_1_1_scatter-members.html new file mode 100644 index 000000000..f6db8360d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_scatter-members.html @@ -0,0 +1,121 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Scatter Member List
+
+
+ +

This is the complete list of members for mlx::core::Scatter, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Scattervirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Scattervirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Scattervirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Scattervirtual
Max enum valuemlx::core::Scatter
Min enum valuemlx::core::Scatter
None enum valuemlx::core::Scatter
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Scatterinlinevirtual
Prod enum valuemlx::core::Scatter
ReduceType enum namemlx::core::Scatter
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)mlx::core::Scatterinlineexplicit
stream()mlx::core::Primitiveinline
Sum enum valuemlx::core::Scatter
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Scattervirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_scatter.html b/docs/build/html/classmlx_1_1core_1_1_scatter.html new file mode 100644 index 000000000..a46a7ff99 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_scatter.html @@ -0,0 +1,444 @@ + + + + + + + +MLX: mlx::core::Scatter Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Types | +Public Member Functions | +List of all members
+
mlx::core::Scatter Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Scatter:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + +

+Public Types

enum  ReduceType {
+  Max +, Min +, Sum +, Prod +,
+  None +
+ }
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Scatter (Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Member Enumeration Documentation

+ +

◆ ReduceType

+ +
+
+ + + + + + +
Enumerator
Max 
Min 
Sum 
Prod 
None 
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ Scatter()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::Scatter::Scatter (Stream stream,
ReduceType reduce_type,
const std::vector< int > & axes )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Scatter::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Scatter::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Scatter::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Scatter::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Scatter::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Scatter::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_scatter.png b/docs/build/html/classmlx_1_1core_1_1_scatter.png new file mode 100644 index 000000000..bac72d5cb Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_scatter.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_select-members.html b/docs/build/html/classmlx_1_1core_1_1_select-members.html new file mode 100644 index 000000000..3b6c6ecf1 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_select-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Select Member List
+
+
+ +

This is the complete list of members for mlx::core::Select, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Selectvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Selectvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Selectinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Selectvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Selectinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Selectinlinevirtual
Select(Stream stream)mlx::core::Selectinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Selectvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Selectvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_select.html b/docs/build/html/classmlx_1_1core_1_1_select.html new file mode 100644 index 000000000..40727d22d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_select.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Select Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Select Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Select:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Select (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Select()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Select::Select (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Select::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Select::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Select::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Select::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Select::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Select::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Select::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Select::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_select.png b/docs/build/html/classmlx_1_1core_1_1_select.png new file mode 100644 index 000000000..86b98868b Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_select.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_sigmoid-members.html b/docs/build/html/classmlx_1_1core_1_1_sigmoid-members.html new file mode 100644 index 000000000..e598387d7 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sigmoid-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Sigmoid Member List
+
+
+ +

This is the complete list of members for mlx::core::Sigmoid, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sigmoidvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sigmoidvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Sigmoidinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Sigmoidvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Sigmoidinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Sigmoidinlinevirtual
Sigmoid(Stream stream)mlx::core::Sigmoidinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Sigmoidvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Sigmoidvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sigmoid.html b/docs/build/html/classmlx_1_1core_1_1_sigmoid.html new file mode 100644 index 000000000..66b45144b --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sigmoid.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Sigmoid Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Sigmoid Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Sigmoid:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Sigmoid (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Sigmoid()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Sigmoid::Sigmoid (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sigmoid::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sigmoid::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Sigmoid::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sigmoid::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Sigmoid::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Sigmoid::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sigmoid::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Sigmoid::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sigmoid.png b/docs/build/html/classmlx_1_1core_1_1_sigmoid.png new file mode 100644 index 000000000..31bcd54a1 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_sigmoid.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_sign-members.html b/docs/build/html/classmlx_1_1core_1_1_sign-members.html new file mode 100644 index 000000000..0bb890485 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sign-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Sign Member List
+
+
+ +

This is the complete list of members for mlx::core::Sign, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Signvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Signvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Signinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Signvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Signinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Signinlinevirtual
Sign(Stream stream)mlx::core::Signinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Signvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Signvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sign.html b/docs/build/html/classmlx_1_1core_1_1_sign.html new file mode 100644 index 000000000..bfa86de85 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sign.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Sign Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Sign Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Sign:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Sign (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Sign()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Sign::Sign (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sign::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sign::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Sign::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sign::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Sign::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Sign::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sign::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Sign::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sign.png b/docs/build/html/classmlx_1_1core_1_1_sign.png new file mode 100644 index 000000000..1489dbc9c Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_sign.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_sin-members.html b/docs/build/html/classmlx_1_1core_1_1_sin-members.html new file mode 100644 index 000000000..44d282d97 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sin-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Sin Member List
+
+
+ +

This is the complete list of members for mlx::core::Sin, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sinvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sinvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Sininlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Sinvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Sininlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Sininlinevirtual
Sin(Stream stream)mlx::core::Sininlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Sinvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Sinvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sin.html b/docs/build/html/classmlx_1_1core_1_1_sin.html new file mode 100644 index 000000000..d4d021f32 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sin.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Sin Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Sin Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Sin:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Sin (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Sin()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Sin::Sin (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sin::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sin::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Sin::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sin::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Sin::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Sin::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sin::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Sin::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sin.png b/docs/build/html/classmlx_1_1core_1_1_sin.png new file mode 100644 index 000000000..a532b6c8d Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_sin.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_sinh-members.html b/docs/build/html/classmlx_1_1core_1_1_sinh-members.html new file mode 100644 index 000000000..73748f4b1 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sinh-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Sinh Member List
+
+
+ +

This is the complete list of members for mlx::core::Sinh, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sinhvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sinhvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Sinhinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Sinhvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Sinhinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Sinhinlinevirtual
Sinh(Stream stream)mlx::core::Sinhinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Sinhvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Sinhvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sinh.html b/docs/build/html/classmlx_1_1core_1_1_sinh.html new file mode 100644 index 000000000..dbef15193 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sinh.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Sinh Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Sinh Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Sinh:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Sinh (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Sinh()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Sinh::Sinh (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sinh::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sinh::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Sinh::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sinh::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Sinh::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Sinh::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sinh::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Sinh::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sinh.png b/docs/build/html/classmlx_1_1core_1_1_sinh.png new file mode 100644 index 000000000..dcfa33426 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_sinh.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_slice-members.html b/docs/build/html/classmlx_1_1core_1_1_slice-members.html new file mode 100644 index 000000000..de7d4049c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_slice-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Slice Member List
+
+
+ +

This is the complete list of members for mlx::core::Slice, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Slicevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Slicevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Slicevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Slicevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Sliceinlinevirtual
Slice(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)mlx::core::Sliceinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Slicevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Slicevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_slice.html b/docs/build/html/classmlx_1_1core_1_1_slice.html new file mode 100644 index 000000000..6f6c64555 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_slice.html @@ -0,0 +1,447 @@ + + + + + + + +MLX: mlx::core::Slice Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Slice Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Slice:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Slice (Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Slice()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::Slice::Slice (Stream stream,
const std::vector< int > & start_indices,
const std::vector< int > & end_indices,
const std::vector< int > & strides )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Slice::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Slice::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Slice::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Slice::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Slice::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Slice::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Slice::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_slice.png b/docs/build/html/classmlx_1_1core_1_1_slice.png new file mode 100644 index 000000000..965022c81 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_slice.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_slice_update-members.html b/docs/build/html/classmlx_1_1core_1_1_slice_update-members.html new file mode 100644 index 000000000..f4abb0ccf --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_slice_update-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::SliceUpdate Member List
+
+
+ +

This is the complete list of members for mlx::core::SliceUpdate, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::SliceUpdatevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::SliceUpdatevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::SliceUpdatevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::SliceUpdatevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::SliceUpdateinlinevirtual
SliceUpdate(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)mlx::core::SliceUpdateinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::SliceUpdatevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::SliceUpdatevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_slice_update.html b/docs/build/html/classmlx_1_1core_1_1_slice_update.html new file mode 100644 index 000000000..b2014c218 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_slice_update.html @@ -0,0 +1,447 @@ + + + + + + + +MLX: mlx::core::SliceUpdate Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::SliceUpdate Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::SliceUpdate:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 SliceUpdate (Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ SliceUpdate()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::SliceUpdate::SliceUpdate (Stream stream,
const std::vector< int > & start_indices,
const std::vector< int > & end_indices,
const std::vector< int > & strides )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::SliceUpdate::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::SliceUpdate::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::SliceUpdate::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::SliceUpdate::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::SliceUpdate::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::SliceUpdate::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::SliceUpdate::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_slice_update.png b/docs/build/html/classmlx_1_1core_1_1_slice_update.png new file mode 100644 index 000000000..25254654e Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_slice_update.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_softmax-members.html b/docs/build/html/classmlx_1_1core_1_1_softmax-members.html new file mode 100644 index 000000000..c3e37e9ff --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_softmax-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Softmax Member List
+
+
+ +

This is the complete list of members for mlx::core::Softmax, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Softmaxvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Softmaxvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Softmaxvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Softmaxvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Softmaxinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Softmaxinlinevirtual
Softmax(Stream stream, bool precise)mlx::core::Softmaxinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Softmaxvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Softmaxvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_softmax.html b/docs/build/html/classmlx_1_1core_1_1_softmax.html new file mode 100644 index 000000000..5a96e8cbb --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_softmax.html @@ -0,0 +1,467 @@ + + + + + + + +MLX: mlx::core::Softmax Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Softmax Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Softmax:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Softmax (Stream stream, bool precise)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Softmax()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Softmax::Softmax (Stream stream,
bool precise )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Softmax::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Softmax::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Softmax::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Softmax::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Softmax::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Softmax::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Softmax::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Softmax::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_softmax.png b/docs/build/html/classmlx_1_1core_1_1_softmax.png new file mode 100644 index 000000000..643fb0d3d Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_softmax.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_sort-members.html b/docs/build/html/classmlx_1_1core_1_1_sort-members.html new file mode 100644 index 000000000..67dcaa64c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sort-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Sort Member List
+
+
+ +

This is the complete list of members for mlx::core::Sort, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sortvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sortvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Sortvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Sortvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Sortinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Sortinlinevirtual
Sort(Stream stream, int axis)mlx::core::Sortinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Sortvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Sortvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sort.html b/docs/build/html/classmlx_1_1core_1_1_sort.html new file mode 100644 index 000000000..197d6fdce --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sort.html @@ -0,0 +1,467 @@ + + + + + + + +MLX: mlx::core::Sort Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Sort Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Sort:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Sort (Stream stream, int axis)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Sort()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Sort::Sort (Stream stream,
int axis )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sort::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sort::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Sort::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sort::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Sort::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Sort::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sort::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Sort::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sort.png b/docs/build/html/classmlx_1_1core_1_1_sort.png new file mode 100644 index 000000000..fa624d110 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_sort.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_split-members.html b/docs/build/html/classmlx_1_1core_1_1_split-members.html new file mode 100644 index 000000000..9c7cd09dd --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_split-members.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Split Member List
+
+
+ +

This is the complete list of members for mlx::core::Split, including all inherited members.

+ + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::Splitvirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::Splitvirtual
is_equivalent(const Primitive &other) const overridemlx::core::Splitvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Splitvirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Splitinlinevirtual
Split(Stream stream, const std::vector< int > &indices, int axis)mlx::core::Splitinlineexplicit
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Splitvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Splitvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_split.html b/docs/build/html/classmlx_1_1core_1_1_split.html new file mode 100644 index 000000000..2889694e8 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_split.html @@ -0,0 +1,426 @@ + + + + + + + +MLX: mlx::core::Split Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Split Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Split:
+
+
+ + +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Split (Stream stream, const std::vector< int > &indices, int axis)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Split()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::Split::Split (Stream stream,
const std::vector< int > & indices,
int axis )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Split::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Split::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Split::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Split::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Split::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Split::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Split::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_split.png b/docs/build/html/classmlx_1_1core_1_1_split.png new file mode 100644 index 000000000..5b7fd768b Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_split.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_sqrt-members.html b/docs/build/html/classmlx_1_1core_1_1_sqrt-members.html new file mode 100644 index 000000000..4a2e605ce --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sqrt-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Sqrt Member List
+
+
+ +

This is the complete list of members for mlx::core::Sqrt, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sqrtvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sqrtvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Sqrtvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Sqrtvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Sqrtinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Sqrtinlinevirtual
Sqrt(Stream stream, bool recip=false)mlx::core::Sqrtinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Sqrtvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Sqrtvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sqrt.html b/docs/build/html/classmlx_1_1core_1_1_sqrt.html new file mode 100644 index 000000000..c178fcb3d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sqrt.html @@ -0,0 +1,467 @@ + + + + + + + +MLX: mlx::core::Sqrt Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Sqrt Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Sqrt:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Sqrt (Stream stream, bool recip=false)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Sqrt()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Sqrt::Sqrt (Stream stream,
bool recip = false )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sqrt::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sqrt::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Sqrt::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sqrt::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Sqrt::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Sqrt::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sqrt::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Sqrt::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sqrt.png b/docs/build/html/classmlx_1_1core_1_1_sqrt.png new file mode 100644 index 000000000..f30bd2b34 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_sqrt.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_square-members.html b/docs/build/html/classmlx_1_1core_1_1_square-members.html new file mode 100644 index 000000000..a447d33e6 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_square-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Square Member List
+
+
+ +

This is the complete list of members for mlx::core::Square, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Squarevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Squarevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Squareinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Squarevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Squareinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Squareinlinevirtual
Square(Stream stream)mlx::core::Squareinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Squarevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Squarevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_square.html b/docs/build/html/classmlx_1_1core_1_1_square.html new file mode 100644 index 000000000..24c7cd1c6 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_square.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Square Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Square Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Square:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Square (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Square()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Square::Square (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Square::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Square::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Square::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Square::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Square::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Square::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Square::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Square::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_square.png b/docs/build/html/classmlx_1_1core_1_1_square.png new file mode 100644 index 000000000..06ae832b8 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_square.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_stop_gradient-members.html b/docs/build/html/classmlx_1_1core_1_1_stop_gradient-members.html new file mode 100644 index 000000000..4d9503299 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_stop_gradient-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::StopGradient Member List
+
+
+ +

This is the complete list of members for mlx::core::StopGradient, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::StopGradientvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::StopGradientvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::StopGradientinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::StopGradientinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::StopGradientinlinevirtual
StopGradient(Stream stream)mlx::core::StopGradientinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::StopGradientvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_stop_gradient.html b/docs/build/html/classmlx_1_1core_1_1_stop_gradient.html new file mode 100644 index 000000000..2913a73b7 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_stop_gradient.html @@ -0,0 +1,382 @@ + + + + + + + +MLX: mlx::core::StopGradient Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::StopGradient Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::StopGradient:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 StopGradient (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ StopGradient()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::StopGradient::StopGradient (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::StopGradient::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::StopGradient::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::StopGradient::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::StopGradient::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::StopGradient::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::StopGradient::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_stop_gradient.png b/docs/build/html/classmlx_1_1core_1_1_stop_gradient.png new file mode 100644 index 000000000..082cc974a Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_stop_gradient.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_subtract-members.html b/docs/build/html/classmlx_1_1core_1_1_subtract-members.html new file mode 100644 index 000000000..b6f161513 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_subtract-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Subtract Member List
+
+
+ +

This is the complete list of members for mlx::core::Subtract, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Subtractvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Subtractvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Subtractinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Subtractvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Subtractinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Subtractinlinevirtual
stream()mlx::core::Primitiveinline
Subtract(Stream stream)mlx::core::Subtractinlineexplicit
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Subtractvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Subtractvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_subtract.html b/docs/build/html/classmlx_1_1core_1_1_subtract.html new file mode 100644 index 000000000..d41645a17 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_subtract.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Subtract Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Subtract Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Subtract:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Subtract (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Subtract()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Subtract::Subtract (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Subtract::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Subtract::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Subtract::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Subtract::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Subtract::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Subtract::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Subtract::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Subtract::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_subtract.png b/docs/build/html/classmlx_1_1core_1_1_subtract.png new file mode 100644 index 000000000..9a227b3b3 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_subtract.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_tan-members.html b/docs/build/html/classmlx_1_1core_1_1_tan-members.html new file mode 100644 index 000000000..92188e869 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_tan-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Tan Member List
+
+
+ +

This is the complete list of members for mlx::core::Tan, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Tanvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Tanvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Taninlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Tanvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Taninlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Taninlinevirtual
stream()mlx::core::Primitiveinline
Tan(Stream stream)mlx::core::Taninlineexplicit
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Tanvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Tanvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_tan.html b/docs/build/html/classmlx_1_1core_1_1_tan.html new file mode 100644 index 000000000..b8be27324 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_tan.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Tan Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Tan Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Tan:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Tan (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Tan()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Tan::Tan (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Tan::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Tan::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Tan::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Tan::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Tan::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Tan::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Tan::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Tan::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_tan.png b/docs/build/html/classmlx_1_1core_1_1_tan.png new file mode 100644 index 000000000..613c47aec Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_tan.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_tanh-members.html b/docs/build/html/classmlx_1_1core_1_1_tanh-members.html new file mode 100644 index 000000000..de13c8655 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_tanh-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Tanh Member List
+
+
+ +

This is the complete list of members for mlx::core::Tanh, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Tanhvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Tanhvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Tanhinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Tanhvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Tanhinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Tanhinlinevirtual
stream()mlx::core::Primitiveinline
Tanh(Stream stream)mlx::core::Tanhinlineexplicit
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Tanhvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Tanhvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_tanh.html b/docs/build/html/classmlx_1_1core_1_1_tanh.html new file mode 100644 index 000000000..d89fd4ce7 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_tanh.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Tanh Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Tanh Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Tanh:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Tanh (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Tanh()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Tanh::Tanh (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Tanh::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Tanh::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Tanh::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Tanh::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Tanh::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Tanh::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Tanh::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Tanh::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_tanh.png b/docs/build/html/classmlx_1_1core_1_1_tanh.png new file mode 100644 index 000000000..8e330c32f Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_tanh.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_transpose-members.html b/docs/build/html/classmlx_1_1core_1_1_transpose-members.html new file mode 100644 index 000000000..e7e2afdae --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_transpose-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Transpose Member List
+
+
+ +

This is the complete list of members for mlx::core::Transpose, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Transposevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Transposevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Transposevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Transposevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Transposeinlinevirtual
stream()mlx::core::Primitiveinline
Transpose(Stream stream, const std::vector< int > &axes)mlx::core::Transposeinlineexplicit
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Transposevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Transposevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_transpose.html b/docs/build/html/classmlx_1_1core_1_1_transpose.html new file mode 100644 index 000000000..839676ac3 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_transpose.html @@ -0,0 +1,437 @@ + + + + + + + +MLX: mlx::core::Transpose Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Transpose Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Transpose:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Transpose (Stream stream, const std::vector< int > &axes)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Transpose()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Transpose::Transpose (Stream stream,
const std::vector< int > & axes )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Transpose::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Transpose::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Transpose::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Transpose::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Transpose::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Transpose::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Transpose::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_transpose.png b/docs/build/html/classmlx_1_1core_1_1_transpose.png new file mode 100644 index 000000000..77c3b2288 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_transpose.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_unary_primitive-members.html b/docs/build/html/classmlx_1_1core_1_1_unary_primitive-members.html new file mode 100644 index 000000000..494b399d9 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_unary_primitive-members.html @@ -0,0 +1,114 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::UnaryPrimitive Member List
+
+
+ +

This is the complete list of members for mlx::core::UnaryPrimitive, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &output)=0mlx::core::UnaryPrimitivepure virtual
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &output)=0mlx::core::UnaryPrimitivepure virtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_unary_primitive.html b/docs/build/html/classmlx_1_1core_1_1_unary_primitive.html new file mode 100644 index 000000000..5644a664e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_unary_primitive.html @@ -0,0 +1,532 @@ + + + + + + + +MLX: mlx::core::UnaryPrimitive Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::UnaryPrimitive Class Referenceabstract
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::UnaryPrimitive:
+
+
+ + +mlx::core::Primitive +mlx::core::Abs +mlx::core::Add +mlx::core::AddMM +mlx::core::Arange +mlx::core::ArcCos +mlx::core::ArcCosh +mlx::core::ArcSin +mlx::core::ArcSinh +mlx::core::ArcTan +mlx::core::ArcTan2 +mlx::core::ArcTanh +mlx::core::ArgPartition +mlx::core::ArgReduce +mlx::core::ArgSort +mlx::core::AsStrided +mlx::core::AsType +mlx::core::BitwiseBinary +mlx::core::BlockMaskedMM +mlx::core::BlockSparseMM +mlx::core::Broadcast +mlx::core::Ceil +mlx::core::Concatenate +mlx::core::Conjugate +mlx::core::Convolution +mlx::core::Copy +mlx::core::Cos +mlx::core::Cosh +mlx::core::Divide +mlx::core::Equal +mlx::core::Erf +mlx::core::ErfInv +mlx::core::Exp +mlx::core::Expm1 +mlx::core::FFT +mlx::core::Floor +mlx::core::Full +mlx::core::Gather +mlx::core::Greater +mlx::core::GreaterEqual +mlx::core::Inverse +mlx::core::Less +mlx::core::LessEqual +mlx::core::Load +mlx::core::Log +mlx::core::Log1p +mlx::core::LogAddExp +mlx::core::LogicalAnd +mlx::core::LogicalNot +mlx::core::LogicalOr +mlx::core::Matmul +mlx::core::Maximum +mlx::core::Minimum +mlx::core::Multiply +mlx::core::Negative +mlx::core::NotEqual +mlx::core::NumberOfElements +mlx::core::Pad +mlx::core::Partition +mlx::core::Power +mlx::core::QuantizedMatmul +mlx::core::RandomBits +mlx::core::Reduce +mlx::core::Remainder +mlx::core::Reshape +mlx::core::Round +mlx::core::Scan +mlx::core::Scatter +mlx::core::Select +mlx::core::Sigmoid +mlx::core::Sign +mlx::core::Sin +mlx::core::Sinh +mlx::core::Slice +mlx::core::SliceUpdate +mlx::core::Softmax +mlx::core::Sort +mlx::core::Sqrt +mlx::core::Square +mlx::core::StopGradient +mlx::core::Subtract +mlx::core::Tan +mlx::core::Tanh +mlx::core::Transpose +mlx::core::Uniform + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
virtual void eval_cpu (const std::vector< array > &inputs, array &output)=0
 
virtual void eval_gpu (const std::vector< array > &inputs, array &output)=0
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ UnaryPrimitive() [1/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::UnaryPrimitive::UnaryPrimitive (Stream stream)
+
+inlineexplicit
+
+ +

An abstract base class for a primitive with a single output.

+ +
+
+ +

◆ ~UnaryPrimitive()

+ +
+
+ + + + + +
+ + + + + + + +
virtual mlx::core::UnaryPrimitive::~UnaryPrimitive ()
+
+virtualdefault
+
+ +
+
+ +

◆ UnaryPrimitive() [2/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::UnaryPrimitive::UnaryPrimitive (const UnaryPrimitive & other)
+
+delete
+
+ +
+
+ +

◆ UnaryPrimitive() [3/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::UnaryPrimitive::UnaryPrimitive (UnaryPrimitive && other)
+
+delete
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu() [1/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual void mlx::core::UnaryPrimitive::eval_cpu (const std::vector< array > & inputs,
array & output )
+
+pure virtual
+
+ +

Implemented in mlx::core::Abs, mlx::core::Add, mlx::core::AddMM, mlx::core::Arange, mlx::core::ArcCos, mlx::core::ArcCosh, mlx::core::ArcSin, mlx::core::ArcSinh, mlx::core::ArcTan, mlx::core::ArcTan2, mlx::core::ArcTanh, mlx::core::ArgPartition, mlx::core::ArgReduce, mlx::core::ArgSort, mlx::core::AsType, mlx::core::AsStrided, mlx::core::BitwiseBinary, mlx::core::BlockMaskedMM, mlx::core::BlockSparseMM, mlx::core::Broadcast, mlx::core::Ceil, mlx::core::Concatenate, mlx::core::Conjugate, mlx::core::Convolution, mlx::core::Copy, mlx::core::Cos, mlx::core::Cosh, mlx::core::Divide, mlx::core::Select, mlx::core::Remainder, mlx::core::Equal, mlx::core::Erf, mlx::core::ErfInv, mlx::core::Exp, mlx::core::Expm1, mlx::core::FFT, mlx::core::Floor, mlx::core::Full, mlx::core::Gather, mlx::core::Greater, mlx::core::GreaterEqual, mlx::core::Less, mlx::core::LessEqual, mlx::core::Load, mlx::core::Log, mlx::core::Log1p, mlx::core::LogicalNot, mlx::core::LogicalAnd, mlx::core::LogicalOr, mlx::core::LogAddExp, mlx::core::Matmul, mlx::core::Maximum, mlx::core::Minimum, mlx::core::Multiply, mlx::core::Negative, mlx::core::NotEqual, mlx::core::NumberOfElements, mlx::core::Pad, mlx::core::Partition, mlx::core::Power, mlx::core::QuantizedMatmul, mlx::core::RandomBits, mlx::core::Reshape, mlx::core::Reduce, mlx::core::Round, mlx::core::Scan, mlx::core::Scatter, mlx::core::Sigmoid, mlx::core::Sign, mlx::core::Sin, mlx::core::Sinh, mlx::core::Slice, mlx::core::SliceUpdate, mlx::core::Softmax, mlx::core::Sort, mlx::core::Square, mlx::core::Sqrt, mlx::core::StopGradient, mlx::core::Subtract, mlx::core::Tan, mlx::core::Tanh, mlx::core::Uniform, mlx::core::Transpose, and mlx::core::Inverse.

+ +
+
+ +

◆ eval_cpu() [2/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::UnaryPrimitive::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu() [1/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual void mlx::core::UnaryPrimitive::eval_gpu (const std::vector< array > & inputs,
array & output )
+
+pure virtual
+
+ +

Implemented in mlx::core::Abs, mlx::core::Add, mlx::core::AddMM, mlx::core::Arange, mlx::core::ArcCos, mlx::core::ArcCosh, mlx::core::ArcSin, mlx::core::ArcSinh, mlx::core::ArcTan, mlx::core::ArcTan2, mlx::core::ArcTanh, mlx::core::ArgPartition, mlx::core::ArgReduce, mlx::core::ArgSort, mlx::core::AsType, mlx::core::AsStrided, mlx::core::BitwiseBinary, mlx::core::BlockMaskedMM, mlx::core::BlockSparseMM, mlx::core::Broadcast, mlx::core::Ceil, mlx::core::Concatenate, mlx::core::Conjugate, mlx::core::Convolution, mlx::core::Copy, mlx::core::Cos, mlx::core::Cosh, mlx::core::Divide, mlx::core::Select, mlx::core::Remainder, mlx::core::Equal, mlx::core::Erf, mlx::core::ErfInv, mlx::core::Exp, mlx::core::Expm1, mlx::core::FFT, mlx::core::Floor, mlx::core::Full, mlx::core::Gather, mlx::core::Greater, mlx::core::GreaterEqual, mlx::core::Less, mlx::core::LessEqual, mlx::core::Load, mlx::core::Log, mlx::core::Log1p, mlx::core::LogicalNot, mlx::core::LogicalAnd, mlx::core::LogicalOr, mlx::core::LogAddExp, mlx::core::Matmul, mlx::core::Maximum, mlx::core::Minimum, mlx::core::Multiply, mlx::core::Negative, mlx::core::NotEqual, mlx::core::NumberOfElements, mlx::core::Pad, mlx::core::Partition, mlx::core::Power, mlx::core::QuantizedMatmul, mlx::core::RandomBits, mlx::core::Reshape, mlx::core::Reduce, mlx::core::Round, mlx::core::Scan, mlx::core::Scatter, mlx::core::Sigmoid, mlx::core::Sign, mlx::core::Sin, mlx::core::Sinh, mlx::core::Slice, mlx::core::SliceUpdate, mlx::core::Softmax, mlx::core::Sort, mlx::core::Square, mlx::core::Sqrt, mlx::core::StopGradient, mlx::core::Subtract, mlx::core::Tan, mlx::core::Tanh, mlx::core::Uniform, mlx::core::Transpose, and mlx::core::Inverse.

+ +
+
+ +

◆ eval_gpu() [2/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::UnaryPrimitive::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ operator=() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
UnaryPrimitive & mlx::core::UnaryPrimitive::operator= (const UnaryPrimitive & other)
+
+delete
+
+ +
+
+ +

◆ operator=() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
UnaryPrimitive & mlx::core::UnaryPrimitive::operator= (UnaryPrimitive && other)
+
+delete
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_unary_primitive.png b/docs/build/html/classmlx_1_1core_1_1_unary_primitive.png new file mode 100644 index 000000000..191d394b3 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_unary_primitive.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1_uniform-members.html b/docs/build/html/classmlx_1_1core_1_1_uniform-members.html new file mode 100644 index 000000000..97eece0ca --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_uniform-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Uniform Member List
+
+
+ +

This is the complete list of members for mlx::core::Uniform, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Uniformvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Uniformvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Uniforminlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Uniforminlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
Uniform(Stream stream)mlx::core::Uniforminlineexplicit
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Uniformvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_uniform.html b/docs/build/html/classmlx_1_1core_1_1_uniform.html new file mode 100644 index 000000000..45442f1a6 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_uniform.html @@ -0,0 +1,352 @@ + + + + + + + +MLX: mlx::core::Uniform Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::Uniform Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Uniform:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Uniform (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Uniform()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Uniform::Uniform (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Uniform::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Uniform::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Uniform::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Uniform::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Uniform::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_uniform.png b/docs/build/html/classmlx_1_1core_1_1_uniform.png new file mode 100644 index 000000000..1fe539076 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1_uniform.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator-members.html b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator-members.html new file mode 100644 index 000000000..b98657fd7 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator-members.html @@ -0,0 +1,98 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::allocator::Allocator Member List
+
+
+ +

This is the complete list of members for mlx::core::allocator::Allocator, including all inherited members.

+ + + + + + + + + +
Allocator()=defaultmlx::core::allocator::Allocator
Allocator(const Allocator &other)=deletemlx::core::allocator::Allocator
Allocator(Allocator &&other)=deletemlx::core::allocator::Allocator
free(Buffer buffer)=0mlx::core::allocator::Allocatorpure virtual
malloc(size_t size, bool allow_swap=false)=0mlx::core::allocator::Allocatorpure virtual
operator=(const Allocator &other)=deletemlx::core::allocator::Allocator
operator=(Allocator &&other)=deletemlx::core::allocator::Allocator
~Allocator()=defaultmlx::core::allocator::Allocatorvirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator.html b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator.html new file mode 100644 index 000000000..4d6d58379 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator.html @@ -0,0 +1,338 @@ + + + + + + + +MLX: mlx::core::allocator::Allocator Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::allocator::Allocator Class Referenceabstract
+
+
+ +

#include <allocator.h>

+
+Inheritance diagram for mlx::core::allocator::Allocator:
+
+
+ + +mlx::core::allocator::CommonAllocator +mlx::core::metal::MetalAllocator + +
+ + + + + + + + + + + + + + + + + + + +

+Public Member Functions

virtual Buffer malloc (size_t size, bool allow_swap=false)=0
 Abstract base class for a memory allocator.
 
virtual void free (Buffer buffer)=0
 
 Allocator ()=default
 
 Allocator (const Allocator &other)=delete
 
 Allocator (Allocator &&other)=delete
 
Allocatoroperator= (const Allocator &other)=delete
 
Allocatoroperator= (Allocator &&other)=delete
 
virtual ~Allocator ()=default
 
+

Constructor & Destructor Documentation

+ +

◆ Allocator() [1/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::allocator::Allocator::Allocator ()
+
+default
+
+ +
+
+ +

◆ Allocator() [2/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::allocator::Allocator::Allocator (const Allocator & other)
+
+delete
+
+ +
+
+ +

◆ Allocator() [3/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::allocator::Allocator::Allocator (Allocator && other)
+
+delete
+
+ +
+
+ +

◆ ~Allocator()

+ +
+
+ + + + + +
+ + + + + + + +
virtual mlx::core::allocator::Allocator::~Allocator ()
+
+virtualdefault
+
+ +
+
+

Member Function Documentation

+ +

◆ free()

+ +
+
+ + + + + +
+ + + + + + + +
virtual void mlx::core::allocator::Allocator::free (Buffer buffer)
+
+pure virtual
+
+
+ +

◆ malloc()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual Buffer mlx::core::allocator::Allocator::malloc (size_t size,
bool allow_swap = false )
+
+pure virtual
+
+ +

Abstract base class for a memory allocator.

+ +

Implemented in mlx::core::allocator::CommonAllocator, and mlx::core::metal::MetalAllocator.

+ +
+
+ +

◆ operator=() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
Allocator & mlx::core::allocator::Allocator::operator= (Allocator && other)
+
+delete
+
+ +
+
+ +

◆ operator=() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
Allocator & mlx::core::allocator::Allocator::operator= (const Allocator & other)
+
+delete
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator.png b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator.png new file mode 100644 index 000000000..a57dd9471 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1allocator_1_1_buffer-members.html b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_buffer-members.html new file mode 100644 index 000000000..e3bb6d0e3 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_buffer-members.html @@ -0,0 +1,94 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::allocator::Buffer Member List
+
+
+ +

This is the complete list of members for mlx::core::allocator::Buffer, including all inherited members.

+ + + + + +
Buffer(void *ptr)mlx::core::allocator::Bufferinline
ptr() constmlx::core::allocator::Bufferinline
ptr()mlx::core::allocator::Bufferinline
raw_ptr()mlx::core::allocator::Buffer
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1allocator_1_1_buffer.html b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_buffer.html new file mode 100644 index 000000000..263092690 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_buffer.html @@ -0,0 +1,201 @@ + + + + + + + +MLX: mlx::core::allocator::Buffer Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::allocator::Buffer Class Reference
+
+
+ +

#include <allocator.h>

+ + + + + + + + + + +

+Public Member Functions

 Buffer (void *ptr)
 
void * raw_ptr ()
 
const void * ptr () const
 
void * ptr ()
 
+

Constructor & Destructor Documentation

+ +

◆ Buffer()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::allocator::Buffer::Buffer (void * ptr)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ ptr() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
void * mlx::core::allocator::Buffer::ptr ()
+
+inline
+
+ +
+
+ +

◆ ptr() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
const void * mlx::core::allocator::Buffer::ptr () const
+
+inline
+
+ +
+
+ +

◆ raw_ptr()

+ +
+
+ + + + + + + +
void * mlx::core::allocator::Buffer::raw_ptr ()
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator-members.html b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator-members.html new file mode 100644 index 000000000..7a70d6422 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator-members.html @@ -0,0 +1,99 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::allocator::CommonAllocator Member List
+
+
+ +

This is the complete list of members for mlx::core::allocator::CommonAllocator, including all inherited members.

+ + + + + + + + + + +
allocatormlx::core::allocator::CommonAllocatorfriend
Allocator()=defaultmlx::core::allocator::Allocator
Allocator(const Allocator &other)=deletemlx::core::allocator::Allocator
Allocator(Allocator &&other)=deletemlx::core::allocator::Allocator
free(Buffer buffer) overridemlx::core::allocator::CommonAllocatorvirtual
malloc(size_t size, bool allow_swap=false) overridemlx::core::allocator::CommonAllocatorvirtual
operator=(const Allocator &other)=deletemlx::core::allocator::Allocator
operator=(Allocator &&other)=deletemlx::core::allocator::Allocator
~Allocator()=defaultmlx::core::allocator::Allocatorvirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator.html b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator.html new file mode 100644 index 000000000..8d6363adb --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator.html @@ -0,0 +1,219 @@ + + + + + + + +MLX: mlx::core::allocator::CommonAllocator Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +Friends | +List of all members
+
mlx::core::allocator::CommonAllocator Class Reference
+
+
+ +

#include <allocator.h>

+
+Inheritance diagram for mlx::core::allocator::CommonAllocator:
+
+
+ + +mlx::core::allocator::Allocator + +
+ + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

virtual Buffer malloc (size_t size, bool allow_swap=false) override
 A general CPU allocator.
 
virtual void free (Buffer buffer) override
 
- Public Member Functions inherited from mlx::core::allocator::Allocator
 Allocator ()=default
 
 Allocator (const Allocator &other)=delete
 
 Allocator (Allocator &&other)=delete
 
Allocatoroperator= (const Allocator &other)=delete
 
Allocatoroperator= (Allocator &&other)=delete
 
virtual ~Allocator ()=default
 
+ + + +

+Friends

Allocatorallocator ()
 
+

Member Function Documentation

+ +

◆ free()

+ +
+
+ + + + + +
+ + + + + + + +
virtual void mlx::core::allocator::CommonAllocator::free (Buffer buffer)
+
+overridevirtual
+
+
+ +

◆ malloc()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual Buffer mlx::core::allocator::CommonAllocator::malloc (size_t size,
bool allow_swap = false )
+
+overridevirtual
+
+ +

A general CPU allocator.

+ +

Implements mlx::core::allocator::Allocator.

+ +
+
+

Friends And Related Symbol Documentation

+ +

◆ allocator

+ +
+
+ + + + + +
+ + + + + + + +
Allocator & allocator ()
+
+friend
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator.png b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator.png new file mode 100644 index 000000000..8b609c844 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1array-members.html b/docs/build/html/classmlx_1_1core_1_1array-members.html new file mode 100644 index 000000000..bee74b0e3 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1array-members.html @@ -0,0 +1,159 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::array Member List
+
+
+ +

This is the complete list of members for mlx::core::array, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
array(T val, Dtype dtype=TypeToDtype< T >())mlx::core::arrayexplicit
array(const std::complex< float > &val, Dtype dtype=complex64)mlx::core::arrayexplicit
array(It data, std::vector< int > shape, Dtype dtype=TypeToDtype< typename std::iterator_traits< It >::value_type >())mlx::core::array
array(std::initializer_list< T > data, Dtype dtype=TypeToDtype< T >())mlx::core::array
array(std::initializer_list< float > data)mlx::core::array
array(std::initializer_list< int > data, Dtype dtype)mlx::core::array
array(std::initializer_list< T > data, std::vector< int > shape, Dtype dtype=TypeToDtype< T >())mlx::core::array
array(allocator::Buffer data, std::vector< int > shape, Dtype dtype, deleter_t deleter=allocator::free)mlx::core::array
array(const array &other)=defaultmlx::core::array
array(array &&other)=defaultmlx::core::array
array(std::vector< int > shape, Dtype dtype, std::shared_ptr< Primitive > primitive, std::vector< array > inputs)mlx::core::array
attach_event(Event e) constmlx::core::arrayinline
available enum valuemlx::core::array
begin() constmlx::core::arrayinline
buffer()mlx::core::arrayinline
buffer() constmlx::core::arrayinline
copy_shared_buffer(const array &other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)mlx::core::array
copy_shared_buffer(const array &other)mlx::core::array
data()mlx::core::arrayinline
data() constmlx::core::arrayinline
data_shared_ptr() constmlx::core::arrayinline
data_size() constmlx::core::arrayinline
detach()mlx::core::array
dtype() constmlx::core::arrayinline
end() constmlx::core::arrayinline
eval()mlx::core::array
event() constmlx::core::arrayinline
flags() constmlx::core::arrayinline
has_primitive() constmlx::core::arrayinline
id() constmlx::core::arrayinline
inputs() constmlx::core::arrayinline
inputs()mlx::core::arrayinline
is_available() constmlx::core::arrayinline
is_donatable() constmlx::core::arrayinline
is_tracer() constmlx::core::array
item()mlx::core::array
item() constmlx::core::array
itemsize() constmlx::core::arrayinline
make_arrays(std::vector< std::vector< int > > shapes, const std::vector< Dtype > &dtypes, const std::shared_ptr< Primitive > &primitive, const std::vector< array > &inputs)mlx::core::arraystatic
move_shared_buffer(array other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)mlx::core::array
move_shared_buffer(array other)mlx::core::array
nbytes() constmlx::core::arrayinline
ndim() constmlx::core::arrayinline
operator=(const array &other) &&=deletemlx::core::array
operator=(array &&other) &&=deletemlx::core::array
operator=(array &&other) &=defaultmlx::core::array
operator=(const array &other) &mlx::core::arrayinline
outputs() constmlx::core::arrayinline
overwrite_descriptor(const array &other)mlx::core::arrayinline
primitive() constmlx::core::arrayinline
primitive_id() constmlx::core::arrayinline
primitive_ptr() constmlx::core::arrayinline
scheduled enum valuemlx::core::array
set_data(allocator::Buffer buffer, deleter_t d=allocator::free)mlx::core::array
set_data(allocator::Buffer buffer, size_t data_size, std::vector< size_t > strides, Flags flags, deleter_t d=allocator::free)mlx::core::array
set_siblings(std::vector< array > siblings, uint16_t position)mlx::core::arrayinline
set_status(Status s) constmlx::core::arrayinline
set_tracer(bool is_tracer)mlx::core::arrayinline
shape() constmlx::core::arrayinline
shape(int dim) constmlx::core::arrayinline
siblings() constmlx::core::arrayinline
siblings()mlx::core::arrayinline
size() constmlx::core::arrayinline
status() constmlx::core::arrayinline
Status enum namemlx::core::array
strides() constmlx::core::arrayinline
strides(int dim) constmlx::core::arrayinline
unscheduled enum valuemlx::core::array
~array()mlx::core::array
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1array.html b/docs/build/html/classmlx_1_1core_1_1array.html new file mode 100644 index 000000000..25f218b83 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1array.html @@ -0,0 +1,2000 @@ + + + + + + + +MLX: mlx::core::array Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Classes | +Public Types | +Public Member Functions | +Static Public Member Functions | +List of all members
+
mlx::core::array Class Reference
+
+
+ +

#include <array.h>

+ + + + + + + + +

+Classes

struct  ArrayIterator
 
struct  Data
 
struct  Flags
 
+ + + +

+Public Types

enum  Status { unscheduled +, scheduled +, available + }
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

template<typename T >
 array (T val, Dtype dtype=TypeToDtype< T >())
 Construct a scalar array with zero dimensions.
 
 array (const std::complex< float > &val, Dtype dtype=complex64)
 
template<typename It >
 array (It data, std::vector< int > shape, Dtype dtype=TypeToDtype< typename std::iterator_traits< It >::value_type >())
 
template<typename T >
 array (std::initializer_list< T > data, Dtype dtype=TypeToDtype< T >())
 
 array (std::initializer_list< float > data)
 
 array (std::initializer_list< int > data, Dtype dtype)
 
template<typename T >
 array (std::initializer_list< T > data, std::vector< int > shape, Dtype dtype=TypeToDtype< T >())
 
 array (allocator::Buffer data, std::vector< int > shape, Dtype dtype, deleter_t deleter=allocator::free)
 
arrayoperator= (const array &other) &&=delete
 Assignment to rvalue does not compile.
 
arrayoperator= (array &&other) &&=delete
 
arrayoperator= (array &&other) &=default
 Default copy and move constructors otherwise.
 
 array (const array &other)=default
 
 array (array &&other)=default
 
arrayoperator= (const array &other) &
 
size_t itemsize () const
 The size of the array's datatype in bytes.
 
size_t size () const
 The number of elements in the array.
 
size_t nbytes () const
 The number of bytes in the array.
 
size_t ndim () const
 The number of dimensions of the array.
 
const std::vector< int > & shape () const
 The shape of the array as a vector of integers.
 
int shape (int dim) const
 Get the size of the corresponding dimension.
 
const std::vector< size_t > & strides () const
 The strides of the array.
 
size_t strides (int dim) const
 Get the stride of the corresponding dimension.
 
Dtype dtype () const
 Get the arrays data type.
 
void eval ()
 Evaluate the array.
 
template<typename T >
item ()
 Get the value from a scalar array.
 
template<typename T >
item () const
 
ArrayIterator begin () const
 
ArrayIterator end () const
 
 array (std::vector< int > shape, Dtype dtype, std::shared_ptr< Primitive > primitive, std::vector< array > inputs)
 The following methods should be used with caution.
 
std::uintptr_t id () const
 A unique identifier for an array.
 
std::uintptr_t primitive_id () const
 A unique identifier for an arrays primitive.
 
Primitiveprimitive () const
 The array's primitive.
 
std::shared_ptr< Primitive > & primitive_ptr () const
 A shared pointer to the array's primitive.
 
bool has_primitive () const
 Check if the array has an attached primitive or is a leaf node.
 
const std::vector< array > & inputs () const
 The array's inputs.
 
std::vector< array > & inputs ()
 
bool is_donatable () const
 True indicates the arrays buffer is safe to reuse.
 
const std::vector< array > & siblings () const
 The array's siblings.
 
std::vector< array > & siblings ()
 The array's siblings.
 
void set_siblings (std::vector< array > siblings, uint16_t position)
 
std::vector< arrayoutputs () const
 The outputs of the array's primitive (i.e.
 
void detach ()
 Detach the array from the graph.
 
const Flagsflags () const
 Get the Flags bit-field.
 
size_t data_size () const
 The size (in elements) of the underlying buffer the array points to.
 
allocator::Bufferbuffer ()
 
const allocator::Bufferbuffer () const
 
std::shared_ptr< Datadata_shared_ptr () const
 
template<typename T >
T * data ()
 
template<typename T >
const T * data () const
 
bool is_available () const
 
const Status status () const
 
void set_status (Status s) const
 
Eventevent () const
 
void attach_event (Event e) const
 
void set_tracer (bool is_tracer)
 
bool is_tracer () const
 
void set_data (allocator::Buffer buffer, deleter_t d=allocator::free)
 
void set_data (allocator::Buffer buffer, size_t data_size, std::vector< size_t > strides, Flags flags, deleter_t d=allocator::free)
 
void copy_shared_buffer (const array &other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)
 
void copy_shared_buffer (const array &other)
 
void move_shared_buffer (array other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)
 
void move_shared_buffer (array other)
 
void overwrite_descriptor (const array &other)
 
 ~array ()
 
+ + + +

+Static Public Member Functions

static std::vector< arraymake_arrays (std::vector< std::vector< int > > shapes, const std::vector< Dtype > &dtypes, const std::shared_ptr< Primitive > &primitive, const std::vector< array > &inputs)
 
+

Member Enumeration Documentation

+ +

◆ Status

+ +
+
+ + + + +
enum mlx::core::array::Status
+
+ + + + +
Enumerator
unscheduled 
scheduled 
available 
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ array() [1/11]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + +
mlx::core::array::array (T val,
Dtype dtype = TypeToDtype<T>() )
+
+explicit
+
+ +

Construct a scalar array with zero dimensions.

+ +
+
+ +

◆ array() [2/11]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::array::array (const std::complex< float > & val,
Dtype dtype = complex64 )
+
+explicit
+
+ +
+
+ +

◆ array() [3/11]

+ +
+
+
+template<typename It >
+ + + + + + + + + + + + + + + + +
mlx::core::array::array (It data,
std::vector< int > shape,
Dtype dtype = TypeToDtype<typename std::iterator_traits<It>::value_type>() )
+
+ +
+
+ +

◆ array() [4/11]

+ +
+
+
+template<typename T >
+ + + + + + + + + + + +
mlx::core::array::array (std::initializer_list< T > data,
Dtype dtype = TypeToDtype<T>() )
+
+ +
+
+ +

◆ array() [5/11]

+ +
+
+ + + + + + + +
mlx::core::array::array (std::initializer_list< float > data)
+
+ +
+
+ +

◆ array() [6/11]

+ +
+
+ + + + + + + + + + + +
mlx::core::array::array (std::initializer_list< int > data,
Dtype dtype )
+
+ +
+
+ +

◆ array() [7/11]

+ +
+
+
+template<typename T >
+ + + + + + + + + + + + + + + + +
mlx::core::array::array (std::initializer_list< T > data,
std::vector< int > shape,
Dtype dtype = TypeToDtype<T>() )
+
+ +
+
+ +

◆ array() [8/11]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::array::array (allocator::Buffer data,
std::vector< int > shape,
Dtype dtype,
deleter_t deleter = allocator::free )
+
+ +
+
+ +

◆ array() [9/11]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::array::array (const array & other)
+
+default
+
+ +
+
+ +

◆ array() [10/11]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::array::array (array && other)
+
+default
+
+ +
+
+ +

◆ array() [11/11]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::array::array (std::vector< int > shape,
Dtype dtype,
std::shared_ptr< Primitive > primitive,
std::vector< array > inputs )
+
+ +

The following methods should be used with caution.

+

They are intended for use by the backend implementation and the API may change.

+ +
+
+ +

◆ ~array()

+ +
+
+ + + + + + + +
mlx::core::array::~array ()
+
+ +
+
+

Member Function Documentation

+ +

◆ attach_event()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::array::attach_event (Event e) const
+
+inline
+
+ +
+
+ +

◆ begin()

+ +
+
+ + + + + +
+ + + + + + + +
ArrayIterator mlx::core::array::begin () const
+
+inline
+
+ +
+
+ +

◆ buffer() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
allocator::Buffer & mlx::core::array::buffer ()
+
+inline
+
+ +
+
+ +

◆ buffer() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
const allocator::Buffer & mlx::core::array::buffer () const
+
+inline
+
+ +
+
+ +

◆ copy_shared_buffer() [1/2]

+ +
+
+ + + + + + + +
void mlx::core::array::copy_shared_buffer (const array & other)
+
+ +
+
+ +

◆ copy_shared_buffer() [2/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
void mlx::core::array::copy_shared_buffer (const array & other,
const std::vector< size_t > & strides,
Flags flags,
size_t data_size,
size_t offset = 0 )
+
+ +
+
+ +

◆ data() [1/2]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
T * mlx::core::array::data ()
+
+inline
+
+ +
+
+ +

◆ data() [2/2]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
const T * mlx::core::array::data () const
+
+inline
+
+ +
+
+ +

◆ data_shared_ptr()

+ +
+
+ + + + + +
+ + + + + + + +
std::shared_ptr< Data > mlx::core::array::data_shared_ptr () const
+
+inline
+
+ +
+
+ +

◆ data_size()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::array::data_size () const
+
+inline
+
+ +

The size (in elements) of the underlying buffer the array points to.

+ +
+
+ +

◆ detach()

+ +
+
+ + + + + + + +
void mlx::core::array::detach ()
+
+ +

Detach the array from the graph.

+ +
+
+ +

◆ dtype()

+ +
+
+ + + + + +
+ + + + + + + +
Dtype mlx::core::array::dtype () const
+
+inline
+
+ +

Get the arrays data type.

+ +
+
+ +

◆ end()

+ +
+
+ + + + + +
+ + + + + + + +
ArrayIterator mlx::core::array::end () const
+
+inline
+
+ +
+
+ +

◆ eval()

+ +
+
+ + + + + + + +
void mlx::core::array::eval ()
+
+ +

Evaluate the array.

+ +
+
+ +

◆ event()

+ +
+
+ + + + + +
+ + + + + + + +
Event & mlx::core::array::event () const
+
+inline
+
+ +
+
+ +

◆ flags()

+ +
+
+ + + + + +
+ + + + + + + +
const Flags & mlx::core::array::flags () const
+
+inline
+
+ +

Get the Flags bit-field.

+ +
+
+ +

◆ has_primitive()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::array::has_primitive () const
+
+inline
+
+ +

Check if the array has an attached primitive or is a leaf node.

+ +
+
+ +

◆ id()

+ +
+
+ + + + + +
+ + + + + + + +
std::uintptr_t mlx::core::array::id () const
+
+inline
+
+ +

A unique identifier for an array.

+ +
+
+ +

◆ inputs() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< array > & mlx::core::array::inputs ()
+
+inline
+
+ +
+
+ +

◆ inputs() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
const std::vector< array > & mlx::core::array::inputs () const
+
+inline
+
+ +

The array's inputs.

+ +
+
+ +

◆ is_available()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::array::is_available () const
+
+inline
+
+ +
+
+ +

◆ is_donatable()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::array::is_donatable () const
+
+inline
+
+ +

True indicates the arrays buffer is safe to reuse.

+ +
+
+ +

◆ is_tracer()

+ +
+
+ + + + + + + +
bool mlx::core::array::is_tracer () const
+
+ +
+
+ +

◆ item() [1/2]

+ +
+
+
+template<typename T >
+ + + + + + + +
T mlx::core::array::item ()
+
+ +

Get the value from a scalar array.

+ +
+
+ +

◆ item() [2/2]

+ +
+
+
+template<typename T >
+ + + + + + + +
T mlx::core::array::item () const
+
+ +
+
+ +

◆ itemsize()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::array::itemsize () const
+
+inline
+
+ +

The size of the array's datatype in bytes.

+ +
+
+ +

◆ make_arrays()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
static std::vector< array > mlx::core::array::make_arrays (std::vector< std::vector< int > > shapes,
const std::vector< Dtype > & dtypes,
const std::shared_ptr< Primitive > & primitive,
const std::vector< array > & inputs )
+
+static
+
+ +
+
+ +

◆ move_shared_buffer() [1/2]

+ +
+
+ + + + + + + +
void mlx::core::array::move_shared_buffer (array other)
+
+ +
+
+ +

◆ move_shared_buffer() [2/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
void mlx::core::array::move_shared_buffer (array other,
const std::vector< size_t > & strides,
Flags flags,
size_t data_size,
size_t offset = 0 )
+
+ +
+
+ +

◆ nbytes()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::array::nbytes () const
+
+inline
+
+ +

The number of bytes in the array.

+ +
+
+ +

◆ ndim()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::array::ndim () const
+
+inline
+
+ +

The number of dimensions of the array.

+ +
+
+ +

◆ operator=() [1/4]

+ +
+
+ + + + + +
+ + + + + + + +
array & mlx::core::array::operator= (array && other) &&
+
+delete
+
+ +
+
+ +

◆ operator=() [2/4]

+ +
+
+ + + + + +
+ + + + + + + +
array & mlx::core::array::operator= (array && other) &
+
+default
+
+ +

Default copy and move constructors otherwise.

+ +
+
+ +

◆ operator=() [3/4]

+ +
+
+ + + + + +
+ + + + + + + +
array & mlx::core::array::operator= (const array & other) &
+
+inline
+
+ +
+
+ +

◆ operator=() [4/4]

+ +
+
+ + + + + +
+ + + + + + + +
array & mlx::core::array::operator= (const array & other) &&
+
+delete
+
+ +

Assignment to rvalue does not compile.

+ +
+
+ +

◆ outputs()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< array > mlx::core::array::outputs () const
+
+inline
+
+ +

The outputs of the array's primitive (i.e.

+

this array and its siblings) in the order the primitive expects.

+ +
+
+ +

◆ overwrite_descriptor()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::array::overwrite_descriptor (const array & other)
+
+inline
+
+ +
+
+ +

◆ primitive()

+ +
+
+ + + + + +
+ + + + + + + +
Primitive & mlx::core::array::primitive () const
+
+inline
+
+ +

The array's primitive.

+ +
+
+ +

◆ primitive_id()

+ +
+
+ + + + + +
+ + + + + + + +
std::uintptr_t mlx::core::array::primitive_id () const
+
+inline
+
+ +

A unique identifier for an arrays primitive.

+ +
+
+ +

◆ primitive_ptr()

+ +
+
+ + + + + +
+ + + + + + + +
std::shared_ptr< Primitive > & mlx::core::array::primitive_ptr () const
+
+inline
+
+ +

A shared pointer to the array's primitive.

+ +
+
+ +

◆ set_data() [1/2]

+ +
+
+ + + + + + + + + + + +
void mlx::core::array::set_data (allocator::Buffer buffer,
deleter_t d = allocator::free )
+
+ +
+
+ +

◆ set_data() [2/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
void mlx::core::array::set_data (allocator::Buffer buffer,
size_t data_size,
std::vector< size_t > strides,
Flags flags,
deleter_t d = allocator::free )
+
+ +
+
+ +

◆ set_siblings()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::array::set_siblings (std::vector< array > siblings,
uint16_t position )
+
+inline
+
+ +
+
+ +

◆ set_status()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::array::set_status (Status s) const
+
+inline
+
+ +
+
+ +

◆ set_tracer()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::array::set_tracer (bool is_tracer)
+
+inline
+
+ +
+
+ +

◆ shape() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
const std::vector< int > & mlx::core::array::shape () const
+
+inline
+
+ +

The shape of the array as a vector of integers.

+ +
+
+ +

◆ shape() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
int mlx::core::array::shape (int dim) const
+
+inline
+
+ +

Get the size of the corresponding dimension.

+

This function supports negative indexing and provides bounds checking.

+ +
+
+ +

◆ siblings() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< array > & mlx::core::array::siblings ()
+
+inline
+
+ +

The array's siblings.

+ +
+
+ +

◆ siblings() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
const std::vector< array > & mlx::core::array::siblings () const
+
+inline
+
+ +

The array's siblings.

+ +
+
+ +

◆ size()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::array::size () const
+
+inline
+
+ +

The number of elements in the array.

+ +
+
+ +

◆ status()

+ +
+
+ + + + + +
+ + + + + + + +
const Status mlx::core::array::status () const
+
+inline
+
+ +
+
+ +

◆ strides() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
const std::vector< size_t > & mlx::core::array::strides () const
+
+inline
+
+ +

The strides of the array.

+ +
+
+ +

◆ strides() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::array::strides (int dim) const
+
+inline
+
+ +

Get the stride of the corresponding dimension.

+

This function supports negative indexing and provides bounds checking.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom-members.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom-members.html new file mode 100644 index 000000000..2acd0e4fb --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom-members.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::fast::Custom Member List
+
+
+ +

This is the complete list of members for mlx::core::fast::Custom, including all inherited members.

+ + + + + + + + + + + + + + + + + + +
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)mlx::core::fast::Custominlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0mlx::core::Primitivepure virtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0mlx::core::Primitivepure virtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::fast::Customvirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::fast::Customvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::fast::Customvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom.html new file mode 100644 index 000000000..35642a556 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom.html @@ -0,0 +1,306 @@ + + + + + + + +MLX: mlx::core::fast::Custom Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::fast::Custom Class Reference
+
+
+ +

#include <fast_primitives.h>

+
+Inheritance diagram for mlx::core::fast::Custom:
+
+
+ + +mlx::core::Primitive +mlx::core::fast::LayerNorm +mlx::core::fast::LayerNormVJP +mlx::core::fast::RMSNorm +mlx::core::fast::RMSNormVJP +mlx::core::fast::RoPE +mlx::core::fast::ScaledDotProductAttention + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Custom (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs)=0
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
virtual void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs)=0
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Custom()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::fast::Custom::Custom (Stream stream,
std::function< std::vector< array >(std::vector< array >)> fallback )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
virtual std::vector< array > mlx::core::fast::Custom::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
virtual std::vector< array > mlx::core::fast::Custom::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +

Reimplemented in mlx::core::fast::RMSNorm, mlx::core::fast::LayerNorm, and mlx::core::fast::RoPE.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::fast::Custom::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom.png b/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom.png new file mode 100644 index 000000000..6d384f81c Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm-members.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm-members.html new file mode 100644 index 000000000..7c29baf14 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm-members.html @@ -0,0 +1,109 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::fast::LayerNorm Member List
+
+
+ +

This is the complete list of members for mlx::core::fast::LayerNorm, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + +
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)mlx::core::fast::Custominlineexplicit
DEFINE_PRINT(LayerNorm) bool is_equivalent(const Primitive &other) const overridemlx::core::fast::LayerNorm
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::LayerNorminlinevirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::LayerNormvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::fast::Customvirtual
LayerNorm(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)mlx::core::fast::LayerNorminline
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::fast::LayerNormvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::fast::Customvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm.html new file mode 100644 index 000000000..dd9db97ac --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm.html @@ -0,0 +1,327 @@ + + + + + + + +MLX: mlx::core::fast::LayerNorm Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::fast::LayerNorm Class Reference
+
+
+ +

#include <fast_primitives.h>

+
+Inheritance diagram for mlx::core::fast::LayerNorm:
+
+
+ + +mlx::core::fast::Custom +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 LayerNorm (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
 DEFINE_PRINT (LayerNorm) bool is_equivalent(const Primitive &other) const override
 
- Public Member Functions inherited from mlx::core::fast::Custom
 Custom (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ LayerNorm()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::fast::LayerNorm::LayerNorm (Stream stream,
std::function< std::vector< array >(std::vector< array >)> fallback,
float eps )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ DEFINE_PRINT()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::fast::LayerNorm::DEFINE_PRINT (LayerNorm ) const &
+
+override
+
+ +
+
+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::LayerNorm::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::LayerNorm::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::fast::LayerNorm::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::fast::Custom.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm.png b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm.png new file mode 100644 index 000000000..202a404d0 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p-members.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p-members.html new file mode 100644 index 000000000..fc31cfe70 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p-members.html @@ -0,0 +1,109 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::fast::LayerNormVJP Member List
+
+
+ +

This is the complete list of members for mlx::core::fast::LayerNormVJP, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + +
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)mlx::core::fast::Custominlineexplicit
DEFINE_PRINT(LayerNormVJP) bool is_equivalent(const Primitive &other) const overridemlx::core::fast::LayerNormVJP
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::LayerNormVJPinlinevirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::LayerNormVJPvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::fast::Customvirtual
LayerNormVJP(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)mlx::core::fast::LayerNormVJPinline
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::fast::Customvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::fast::Customvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p.html new file mode 100644 index 000000000..81a4b6e49 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p.html @@ -0,0 +1,284 @@ + + + + + + + +MLX: mlx::core::fast::LayerNormVJP Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::fast::LayerNormVJP Class Reference
+
+
+ +

#include <fast_primitives.h>

+
+Inheritance diagram for mlx::core::fast::LayerNormVJP:
+
+
+ + +mlx::core::fast::Custom +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 LayerNormVJP (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
 DEFINE_PRINT (LayerNormVJP) bool is_equivalent(const Primitive &other) const override
 
- Public Member Functions inherited from mlx::core::fast::Custom
 Custom (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ LayerNormVJP()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::fast::LayerNormVJP::LayerNormVJP (Stream stream,
std::function< std::vector< array >(std::vector< array >)> fallback,
float eps )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ DEFINE_PRINT()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::fast::LayerNormVJP::DEFINE_PRINT (LayerNormVJP ) const &
+
+override
+
+ +
+
+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::LayerNormVJP::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::LayerNormVJP::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p.png b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p.png new file mode 100644 index 000000000..e81afc142 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm-members.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm-members.html new file mode 100644 index 000000000..ab88c81d7 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm-members.html @@ -0,0 +1,109 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::fast::RMSNorm Member List
+
+
+ +

This is the complete list of members for mlx::core::fast::RMSNorm, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + +
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)mlx::core::fast::Custominlineexplicit
DEFINE_PRINT(RMSNorm) bool is_equivalent(const Primitive &other) const overridemlx::core::fast::RMSNorm
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::RMSNorminlinevirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::RMSNormvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::fast::Customvirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
RMSNorm(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)mlx::core::fast::RMSNorminline
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::fast::RMSNormvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::fast::Customvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm.html new file mode 100644 index 000000000..c476b2c96 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm.html @@ -0,0 +1,327 @@ + + + + + + + +MLX: mlx::core::fast::RMSNorm Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::fast::RMSNorm Class Reference
+
+
+ +

#include <fast_primitives.h>

+
+Inheritance diagram for mlx::core::fast::RMSNorm:
+
+
+ + +mlx::core::fast::Custom +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 RMSNorm (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
 DEFINE_PRINT (RMSNorm) bool is_equivalent(const Primitive &other) const override
 
- Public Member Functions inherited from mlx::core::fast::Custom
 Custom (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ RMSNorm()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::fast::RMSNorm::RMSNorm (Stream stream,
std::function< std::vector< array >(std::vector< array >)> fallback,
float eps )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ DEFINE_PRINT()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::fast::RMSNorm::DEFINE_PRINT (RMSNorm ) const &
+
+override
+
+ +
+
+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::RMSNorm::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::RMSNorm::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::fast::RMSNorm::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::fast::Custom.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm.png b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm.png new file mode 100644 index 000000000..0cb8e0a31 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p-members.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p-members.html new file mode 100644 index 000000000..826640f0b --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p-members.html @@ -0,0 +1,109 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::fast::RMSNormVJP Member List
+
+
+ +

This is the complete list of members for mlx::core::fast::RMSNormVJP, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + +
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)mlx::core::fast::Custominlineexplicit
DEFINE_PRINT(RMSNormVJP) bool is_equivalent(const Primitive &other) const overridemlx::core::fast::RMSNormVJP
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::RMSNormVJPinlinevirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::RMSNormVJPvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::fast::Customvirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
RMSNormVJP(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)mlx::core::fast::RMSNormVJPinline
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::fast::Customvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::fast::Customvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p.html new file mode 100644 index 000000000..4b010e9d1 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p.html @@ -0,0 +1,284 @@ + + + + + + + +MLX: mlx::core::fast::RMSNormVJP Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::fast::RMSNormVJP Class Reference
+
+
+ +

#include <fast_primitives.h>

+
+Inheritance diagram for mlx::core::fast::RMSNormVJP:
+
+
+ + +mlx::core::fast::Custom +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 RMSNormVJP (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
 DEFINE_PRINT (RMSNormVJP) bool is_equivalent(const Primitive &other) const override
 
- Public Member Functions inherited from mlx::core::fast::Custom
 Custom (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ RMSNormVJP()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::fast::RMSNormVJP::RMSNormVJP (Stream stream,
std::function< std::vector< array >(std::vector< array >)> fallback,
float eps )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ DEFINE_PRINT()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::fast::RMSNormVJP::DEFINE_PRINT (RMSNormVJP ) const &
+
+override
+
+ +
+
+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::RMSNormVJP::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::RMSNormVJP::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p.png b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p.png new file mode 100644 index 000000000..39e2b0d04 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e-members.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e-members.html new file mode 100644 index 000000000..2d2af6eab --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e-members.html @@ -0,0 +1,109 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::fast::RoPE Member List
+
+
+ +

This is the complete list of members for mlx::core::fast::RoPE, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + +
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)mlx::core::fast::Custominlineexplicit
DEFINE_PRINT(RoPE) bool is_equivalent(const Primitive &other) const overridemlx::core::fast::RoPE
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::RoPEinlinevirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::RoPEvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::fast::Customvirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
RoPE(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, int dims, bool traditional, float base, float scale, int offset, bool forward)mlx::core::fast::RoPEinline
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::fast::RoPEvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::fast::Customvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e.html new file mode 100644 index 000000000..745b490be --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e.html @@ -0,0 +1,352 @@ + + + + + + + +MLX: mlx::core::fast::RoPE Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::fast::RoPE Class Reference
+
+
+ +

#include <fast_primitives.h>

+
+Inheritance diagram for mlx::core::fast::RoPE:
+
+
+ + +mlx::core::fast::Custom +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 RoPE (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, int dims, bool traditional, float base, float scale, int offset, bool forward)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
 DEFINE_PRINT (RoPE) bool is_equivalent(const Primitive &other) const override
 
- Public Member Functions inherited from mlx::core::fast::Custom
 Custom (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ RoPE()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
mlx::core::fast::RoPE::RoPE (Stream stream,
std::function< std::vector< array >(std::vector< array >)> fallback,
int dims,
bool traditional,
float base,
float scale,
int offset,
bool forward )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ DEFINE_PRINT()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::fast::RoPE::DEFINE_PRINT (RoPE ) const &
+
+override
+
+ +
+
+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::RoPE::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::RoPE::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::fast::RoPE::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::fast::Custom.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e.png b/docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e.png new file mode 100644 index 000000000..62648d941 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention-members.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention-members.html new file mode 100644 index 000000000..f75d5bd4d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention-members.html @@ -0,0 +1,110 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::fast::ScaledDotProductAttention Member List
+
+
+ +

This is the complete list of members for mlx::core::fast::ScaledDotProductAttention, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + +
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)mlx::core::fast::Custominlineexplicit
DEFINE_PRINT(ScaledDotProductAttention)mlx::core::fast::ScaledDotProductAttention
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::ScaledDotProductAttentioninlinevirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::ScaledDotProductAttentioninlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out)mlx::core::fast::ScaledDotProductAttention
is_equivalent(const Primitive &other) const overridemlx::core::fast::ScaledDotProductAttentionvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::fast::Customvirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
ScaledDotProductAttention(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, const float scale, const bool needs_mask)mlx::core::fast::ScaledDotProductAttentioninlineexplicit
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::fast::Customvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::fast::Customvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention.html new file mode 100644 index 000000000..e5b515790 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention.html @@ -0,0 +1,333 @@ + + + + + + + +MLX: mlx::core::fast::ScaledDotProductAttention Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::fast::ScaledDotProductAttention Class Reference
+
+
+ +

#include <fast_primitives.h>

+
+Inheritance diagram for mlx::core::fast::ScaledDotProductAttention:
+
+
+ + +mlx::core::fast::Custom +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ScaledDotProductAttention (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, const float scale, const bool needs_mask)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
void eval_gpu (const std::vector< array > &inputs, array &out)
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
 DEFINE_PRINT (ScaledDotProductAttention)
 
- Public Member Functions inherited from mlx::core::fast::Custom
 Custom (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ScaledDotProductAttention()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::fast::ScaledDotProductAttention::ScaledDotProductAttention (Stream stream,
std::function< std::vector< array >(std::vector< array >)> fallback,
const float scale,
const bool needs_mask )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ DEFINE_PRINT()

+ +
+
+ + + + + + + +
mlx::core::fast::ScaledDotProductAttention::DEFINE_PRINT (ScaledDotProductAttention )
+
+ +
+
+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::ScaledDotProductAttention::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu() [1/2]

+ +
+
+ + + + + + + + + + + +
void mlx::core::fast::ScaledDotProductAttention::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+ +
+
+ +

◆ eval_gpu() [2/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::ScaledDotProductAttention::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::fast::ScaledDotProductAttention::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention.png b/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention.png new file mode 100644 index 000000000..65f61e4a0 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader-members.html b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader-members.html new file mode 100644 index 000000000..1c5df6b84 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader-members.html @@ -0,0 +1,98 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::io::FileReader Member List
+
+
+ +

This is the complete list of members for mlx::core::io::FileReader, including all inherited members.

+ + + + + + + + + +
FileReader(std::ifstream is)mlx::core::io::FileReaderinlineexplicit
FileReader(std::string file_path)mlx::core::io::FileReaderinlineexplicit
good() const overridemlx::core::io::FileReaderinlinevirtual
is_open() const overridemlx::core::io::FileReaderinlinevirtual
label() const overridemlx::core::io::FileReaderinlinevirtual
read(char *data, size_t n) overridemlx::core::io::FileReaderinlinevirtual
seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg) overridemlx::core::io::FileReaderinlinevirtual
tell() overridemlx::core::io::FileReaderinlinevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader.html b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader.html new file mode 100644 index 000000000..1ae86019b --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader.html @@ -0,0 +1,346 @@ + + + + + + + +MLX: mlx::core::io::FileReader Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::io::FileReader Class Reference
+
+
+ +

#include <load.h>

+
+Inheritance diagram for mlx::core::io::FileReader:
+
+
+ + +mlx::core::io::Reader + +
+ + + + + + + + + + + + + + + + + + +

+Public Member Functions

 FileReader (std::ifstream is)
 
 FileReader (std::string file_path)
 
bool is_open () const override
 
bool good () const override
 
size_t tell () override
 
void seek (int64_t off, std::ios_base::seekdir way=std::ios_base::beg) override
 
void read (char *data, size_t n) override
 
std::string label () const override
 
+

Constructor & Destructor Documentation

+ +

◆ FileReader() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::io::FileReader::FileReader (std::ifstream is)
+
+inlineexplicit
+
+ +
+
+ +

◆ FileReader() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::io::FileReader::FileReader (std::string file_path)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ good()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::io::FileReader::good () const
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Reader.

+ +
+
+ +

◆ is_open()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::io::FileReader::is_open () const
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Reader.

+ +
+
+ +

◆ label()

+ +
+
+ + + + + +
+ + + + + + + +
std::string mlx::core::io::FileReader::label () const
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Reader.

+ +
+
+ +

◆ read()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::io::FileReader::read (char * data,
size_t n )
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Reader.

+ +
+
+ +

◆ seek()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::io::FileReader::seek (int64_t off,
std::ios_base::seekdir way = std::ios_base::beg )
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Reader.

+ +
+
+ +

◆ tell()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::io::FileReader::tell ()
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Reader.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader.png b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader.png new file mode 100644 index 000000000..0a31de8a6 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer-members.html b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer-members.html new file mode 100644 index 000000000..5491f510a --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer-members.html @@ -0,0 +1,98 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::io::FileWriter Member List
+
+
+ +

This is the complete list of members for mlx::core::io::FileWriter, including all inherited members.

+ + + + + + + + + +
FileWriter(std::ofstream os)mlx::core::io::FileWriterinlineexplicit
FileWriter(std::string file_path)mlx::core::io::FileWriterinlineexplicit
good() const overridemlx::core::io::FileWriterinlinevirtual
is_open() const overridemlx::core::io::FileWriterinlinevirtual
label() const overridemlx::core::io::FileWriterinlinevirtual
seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg) overridemlx::core::io::FileWriterinlinevirtual
tell() overridemlx::core::io::FileWriterinlinevirtual
write(const char *data, size_t n) overridemlx::core::io::FileWriterinlinevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer.html b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer.html new file mode 100644 index 000000000..7367b3b87 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer.html @@ -0,0 +1,346 @@ + + + + + + + +MLX: mlx::core::io::FileWriter Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::io::FileWriter Class Reference
+
+
+ +

#include <load.h>

+
+Inheritance diagram for mlx::core::io::FileWriter:
+
+
+ + +mlx::core::io::Writer + +
+ + + + + + + + + + + + + + + + + + +

+Public Member Functions

 FileWriter (std::ofstream os)
 
 FileWriter (std::string file_path)
 
bool is_open () const override
 
bool good () const override
 
size_t tell () override
 
void seek (int64_t off, std::ios_base::seekdir way=std::ios_base::beg) override
 
void write (const char *data, size_t n) override
 
std::string label () const override
 
+

Constructor & Destructor Documentation

+ +

◆ FileWriter() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::io::FileWriter::FileWriter (std::ofstream os)
+
+inlineexplicit
+
+ +
+
+ +

◆ FileWriter() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::io::FileWriter::FileWriter (std::string file_path)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ good()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::io::FileWriter::good () const
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Writer.

+ +
+
+ +

◆ is_open()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::io::FileWriter::is_open () const
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Writer.

+ +
+
+ +

◆ label()

+ +
+
+ + + + + +
+ + + + + + + +
std::string mlx::core::io::FileWriter::label () const
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Writer.

+ +
+
+ +

◆ seek()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::io::FileWriter::seek (int64_t off,
std::ios_base::seekdir way = std::ios_base::beg )
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Writer.

+ +
+
+ +

◆ tell()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::io::FileWriter::tell ()
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Writer.

+ +
+
+ +

◆ write()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::io::FileWriter::write (const char * data,
size_t n )
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Writer.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer.png b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer.png new file mode 100644 index 000000000..3f1679897 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_reader-members.html b/docs/build/html/classmlx_1_1core_1_1io_1_1_reader-members.html new file mode 100644 index 000000000..98887ab2b --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1io_1_1_reader-members.html @@ -0,0 +1,96 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::io::Reader Member List
+
+
+ +

This is the complete list of members for mlx::core::io::Reader, including all inherited members.

+ + + + + + + +
good() const =0mlx::core::io::Readerpure virtual
is_open() const =0mlx::core::io::Readerpure virtual
label() const =0mlx::core::io::Readerpure virtual
read(char *data, size_t n)=0mlx::core::io::Readerpure virtual
seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg)=0mlx::core::io::Readerpure virtual
tell()=0mlx::core::io::Readerpure virtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_reader.html b/docs/build/html/classmlx_1_1core_1_1io_1_1_reader.html new file mode 100644 index 000000000..2cd2194a8 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1io_1_1_reader.html @@ -0,0 +1,291 @@ + + + + + + + +MLX: mlx::core::io::Reader Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::io::Reader Class Referenceabstract
+
+
+ +

#include <load.h>

+
+Inheritance diagram for mlx::core::io::Reader:
+
+
+ + +mlx::core::io::FileReader + +
+ + + + + + + + + + + + + + +

+Public Member Functions

virtual bool is_open () const =0
 
virtual bool good () const =0
 
virtual size_t tell ()=0
 
virtual void seek (int64_t off, std::ios_base::seekdir way=std::ios_base::beg)=0
 
virtual void read (char *data, size_t n)=0
 
virtual std::string label () const =0
 
+

Member Function Documentation

+ +

◆ good()

+ +
+
+ + + + + +
+ + + + + + + +
virtual bool mlx::core::io::Reader::good () const
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileReader.

+ +
+
+ +

◆ is_open()

+ +
+
+ + + + + +
+ + + + + + + +
virtual bool mlx::core::io::Reader::is_open () const
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileReader.

+ +
+
+ +

◆ label()

+ +
+
+ + + + + +
+ + + + + + + +
virtual std::string mlx::core::io::Reader::label () const
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileReader.

+ +
+
+ +

◆ read()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual void mlx::core::io::Reader::read (char * data,
size_t n )
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileReader.

+ +
+
+ +

◆ seek()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual void mlx::core::io::Reader::seek (int64_t off,
std::ios_base::seekdir way = std::ios_base::beg )
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileReader.

+ +
+
+ +

◆ tell()

+ +
+
+ + + + + +
+ + + + + + + +
virtual size_t mlx::core::io::Reader::tell ()
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileReader.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_reader.png b/docs/build/html/classmlx_1_1core_1_1io_1_1_reader.png new file mode 100644 index 000000000..a28b37482 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1io_1_1_reader.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_writer-members.html b/docs/build/html/classmlx_1_1core_1_1io_1_1_writer-members.html new file mode 100644 index 000000000..2ca6aa3bc --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1io_1_1_writer-members.html @@ -0,0 +1,96 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::io::Writer Member List
+
+
+ +

This is the complete list of members for mlx::core::io::Writer, including all inherited members.

+ + + + + + + +
good() const =0mlx::core::io::Writerpure virtual
is_open() const =0mlx::core::io::Writerpure virtual
label() const =0mlx::core::io::Writerpure virtual
seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg)=0mlx::core::io::Writerpure virtual
tell()=0mlx::core::io::Writerpure virtual
write(const char *data, size_t n)=0mlx::core::io::Writerpure virtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_writer.html b/docs/build/html/classmlx_1_1core_1_1io_1_1_writer.html new file mode 100644 index 000000000..991f1a823 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1io_1_1_writer.html @@ -0,0 +1,291 @@ + + + + + + + +MLX: mlx::core::io::Writer Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::io::Writer Class Referenceabstract
+
+
+ +

#include <load.h>

+
+Inheritance diagram for mlx::core::io::Writer:
+
+
+ + +mlx::core::io::FileWriter + +
+ + + + + + + + + + + + + + +

+Public Member Functions

virtual bool is_open () const =0
 
virtual bool good () const =0
 
virtual size_t tell ()=0
 
virtual void seek (int64_t off, std::ios_base::seekdir way=std::ios_base::beg)=0
 
virtual void write (const char *data, size_t n)=0
 
virtual std::string label () const =0
 
+

Member Function Documentation

+ +

◆ good()

+ +
+
+ + + + + +
+ + + + + + + +
virtual bool mlx::core::io::Writer::good () const
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileWriter.

+ +
+
+ +

◆ is_open()

+ +
+
+ + + + + +
+ + + + + + + +
virtual bool mlx::core::io::Writer::is_open () const
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileWriter.

+ +
+
+ +

◆ label()

+ +
+
+ + + + + +
+ + + + + + + +
virtual std::string mlx::core::io::Writer::label () const
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileWriter.

+ +
+
+ +

◆ seek()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual void mlx::core::io::Writer::seek (int64_t off,
std::ios_base::seekdir way = std::ios_base::beg )
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileWriter.

+ +
+
+ +

◆ tell()

+ +
+
+ + + + + +
+ + + + + + + +
virtual size_t mlx::core::io::Writer::tell ()
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileWriter.

+ +
+
+ +

◆ write()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual void mlx::core::io::Writer::write (const char * data,
size_t n )
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileWriter.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_writer.png b/docs/build/html/classmlx_1_1core_1_1io_1_1_writer.png new file mode 100644 index 000000000..70dfa5f68 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1io_1_1_writer.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1metal_1_1_device-members.html b/docs/build/html/classmlx_1_1core_1_1metal_1_1_device-members.html new file mode 100644 index 000000000..0d522e267 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1metal_1_1_device-members.html @@ -0,0 +1,112 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::metal::Device Member List
+
+
+ +

This is the complete list of members for mlx::core::metal::Device, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + +
argument_encoder(const std::vector< MTL::ArgumentDescriptor * > &arg_descs) constmlx::core::metal::Device
commit_command_buffer(int index)mlx::core::metal::Device
Device()mlx::core::metal::Device
Device(const Device &)=deletemlx::core::metal::Device
end_encoding(int index)mlx::core::metal::Device
get_command_buffer(int index)mlx::core::metal::Device
get_command_buffer_ops(int index)mlx::core::metal::Device
get_command_encoder(int index)mlx::core::metal::Device
get_function(const std::string &base_name, MTL::Library *mtl_lib, const std::string &specialized_name="", const MTLFCList &func_consts={})mlx::core::metal::Device
get_function(const std::string &base_name, const std::string &lib_name="mlx", const std::string &specialized_name="", const MTLFCList &func_consts={})mlx::core::metal::Device
get_kernel(const std::string &base_name, MTL::Library *mtl_lib, const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})mlx::core::metal::Device
get_kernel(const std::string &base_name, const std::string &lib_name="mlx", const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})mlx::core::metal::Device
get_library(const std::string &name)mlx::core::metal::Device
get_library(const std::string &name, const std::string &source_string, bool cache=true)mlx::core::metal::Device
get_library(const std::string &name, const MTL::StitchedLibraryDescriptor *desc, bool cache=true)mlx::core::metal::Device
increment_command_buffer_ops(int index)mlx::core::metal::Device
mtl_device()mlx::core::metal::Deviceinline
new_queue(int index)mlx::core::metal::Device
operator=(const Device &)=deletemlx::core::metal::Device
register_library(const std::string &lib_name, const std::string &lib_path)mlx::core::metal::Device
register_library(const std::string &lib_name, const std::function< std::string(const std::string &)> &lib_path_func=get_colocated_mtllib_path)mlx::core::metal::Device
~Device()mlx::core::metal::Device
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1metal_1_1_device.html b/docs/build/html/classmlx_1_1core_1_1metal_1_1_device.html new file mode 100644 index 000000000..9d6e636ad --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1metal_1_1_device.html @@ -0,0 +1,635 @@ + + + + + + + +MLX: mlx::core::metal::Device Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::metal::Device Class Reference
+
+
+ +

#include <device.h>

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Device ()
 
 Device (const Device &)=delete
 
Deviceoperator= (const Device &)=delete
 
 ~Device ()
 
MTL::Device * mtl_device ()
 
void new_queue (int index)
 
MTL::CommandBuffer * get_command_buffer (int index)
 
int get_command_buffer_ops (int index)
 
void increment_command_buffer_ops (int index)
 
void commit_command_buffer (int index)
 
CommandEncoderget_command_encoder (int index)
 
void end_encoding (int index)
 
void register_library (const std::string &lib_name, const std::string &lib_path)
 
void register_library (const std::string &lib_name, const std::function< std::string(const std::string &)> &lib_path_func=get_colocated_mtllib_path)
 
MTL::Library * get_library (const std::string &name)
 
MTL::Library * get_library (const std::string &name, const std::string &source_string, bool cache=true)
 
MTL::Library * get_library (const std::string &name, const MTL::StitchedLibraryDescriptor *desc, bool cache=true)
 
MTL::Function * get_function (const std::string &base_name, MTL::Library *mtl_lib, const std::string &specialized_name="", const MTLFCList &func_consts={})
 
MTL::Function * get_function (const std::string &base_name, const std::string &lib_name="mlx", const std::string &specialized_name="", const MTLFCList &func_consts={})
 
MTL::ComputePipelineState * get_kernel (const std::string &base_name, MTL::Library *mtl_lib, const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
 
MTL::ComputePipelineState * get_kernel (const std::string &base_name, const std::string &lib_name="mlx", const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
 
MTL::ArgumentEncoder * argument_encoder (const std::vector< MTL::ArgumentDescriptor * > &arg_descs) const
 
+

Constructor & Destructor Documentation

+ +

◆ Device() [1/2]

+ +
+
+ + + + + + + +
mlx::core::metal::Device::Device ()
+
+ +
+
+ +

◆ Device() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::metal::Device::Device (const Device & )
+
+delete
+
+ +
+
+ +

◆ ~Device()

+ +
+
+ + + + + + + +
mlx::core::metal::Device::~Device ()
+
+ +
+
+

Member Function Documentation

+ +

◆ argument_encoder()

+ +
+
+ + + + + + + +
MTL::ArgumentEncoder * mlx::core::metal::Device::argument_encoder (const std::vector< MTL::ArgumentDescriptor * > & arg_descs) const
+
+ +
+
+ +

◆ commit_command_buffer()

+ +
+
+ + + + + + + +
void mlx::core::metal::Device::commit_command_buffer (int index)
+
+ +
+
+ +

◆ end_encoding()

+ +
+
+ + + + + + + +
void mlx::core::metal::Device::end_encoding (int index)
+
+ +
+
+ +

◆ get_command_buffer()

+ +
+
+ + + + + + + +
MTL::CommandBuffer * mlx::core::metal::Device::get_command_buffer (int index)
+
+ +
+
+ +

◆ get_command_buffer_ops()

+ +
+
+ + + + + + + +
int mlx::core::metal::Device::get_command_buffer_ops (int index)
+
+ +
+
+ +

◆ get_command_encoder()

+ +
+
+ + + + + + + +
CommandEncoder & mlx::core::metal::Device::get_command_encoder (int index)
+
+ +
+
+ +

◆ get_function() [1/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + +
MTL::Function * mlx::core::metal::Device::get_function (const std::string & base_name,
const std::string & lib_name = "mlx",
const std::string & specialized_name = "",
const MTLFCList & func_consts = {} )
+
+ +
+
+ +

◆ get_function() [2/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + +
MTL::Function * mlx::core::metal::Device::get_function (const std::string & base_name,
MTL::Library * mtl_lib,
const std::string & specialized_name = "",
const MTLFCList & func_consts = {} )
+
+ +
+
+ +

◆ get_kernel() [1/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
MTL::ComputePipelineState * mlx::core::metal::Device::get_kernel (const std::string & base_name,
const std::string & lib_name = "mlx",
const std::string & hash_name = "",
const MTLFCList & func_consts = {},
const std::vector< MTL::Function * > & linked_functions = {} )
+
+ +
+
+ +

◆ get_kernel() [2/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
MTL::ComputePipelineState * mlx::core::metal::Device::get_kernel (const std::string & base_name,
MTL::Library * mtl_lib,
const std::string & hash_name = "",
const MTLFCList & func_consts = {},
const std::vector< MTL::Function * > & linked_functions = {} )
+
+ +
+
+ +

◆ get_library() [1/3]

+ +
+
+ + + + + + + +
MTL::Library * mlx::core::metal::Device::get_library (const std::string & name)
+
+ +
+
+ +

◆ get_library() [2/3]

+ +
+
+ + + + + + + + + + + + + + + + +
MTL::Library * mlx::core::metal::Device::get_library (const std::string & name,
const MTL::StitchedLibraryDescriptor * desc,
bool cache = true )
+
+ +
+
+ +

◆ get_library() [3/3]

+ +
+
+ + + + + + + + + + + + + + + + +
MTL::Library * mlx::core::metal::Device::get_library (const std::string & name,
const std::string & source_string,
bool cache = true )
+
+ +
+
+ +

◆ increment_command_buffer_ops()

+ +
+
+ + + + + + + +
void mlx::core::metal::Device::increment_command_buffer_ops (int index)
+
+ +
+
+ +

◆ mtl_device()

+ +
+
+ + + + + +
+ + + + + + + +
MTL::Device * mlx::core::metal::Device::mtl_device ()
+
+inline
+
+ +
+
+ +

◆ new_queue()

+ +
+
+ + + + + + + +
void mlx::core::metal::Device::new_queue (int index)
+
+ +
+
+ +

◆ operator=()

+ +
+
+ + + + + +
+ + + + + + + +
Device & mlx::core::metal::Device::operator= (const Device & )
+
+delete
+
+ +
+
+ +

◆ register_library() [1/2]

+ +
+
+ + + + + + + + + + + +
void mlx::core::metal::Device::register_library (const std::string & lib_name,
const std::function< std::string(const std::string &)> & lib_path_func = get_colocated_mtllib_path )
+
+ +
+
+ +

◆ register_library() [2/2]

+ +
+
+ + + + + + + + + + + +
void mlx::core::metal::Device::register_library (const std::string & lib_name,
const std::string & lib_path )
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator-members.html b/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator-members.html new file mode 100644 index 000000000..6ddf10230 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator-members.html @@ -0,0 +1,106 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::metal::MetalAllocator Member List
+
+
+ +

This is the complete list of members for mlx::core::metal::MetalAllocator, including all inherited members.

+ + + + + + + + + + + + + + + + + +
allocatormlx::core::metal::MetalAllocatorfriend
Allocator()=defaultmlx::core::allocator::Allocator
Allocator(const Allocator &other)=deletemlx::core::allocator::Allocator
Allocator(Allocator &&other)=deletemlx::core::allocator::Allocator
clear_cache()mlx::core::metal::MetalAllocator
free(Buffer buffer) overridemlx::core::metal::MetalAllocatorvirtual
get_active_memory()mlx::core::metal::MetalAllocatorinline
get_cache_memory()mlx::core::metal::MetalAllocatorinline
get_peak_memory()mlx::core::metal::MetalAllocatorinline
malloc(size_t size, bool allow_swap=false) overridemlx::core::metal::MetalAllocatorvirtual
operator=(const Allocator &other)=deletemlx::core::allocator::Allocator
operator=(Allocator &&other)=deletemlx::core::allocator::Allocator
reset_peak_memory()mlx::core::metal::MetalAllocatorinline
set_cache_limit(size_t limit)mlx::core::metal::MetalAllocator
set_memory_limit(size_t limit, bool relaxed)mlx::core::metal::MetalAllocator
~Allocator()=defaultmlx::core::allocator::Allocatorvirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator.html b/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator.html new file mode 100644 index 000000000..b19a22a49 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator.html @@ -0,0 +1,388 @@ + + + + + + + +MLX: mlx::core::metal::MetalAllocator Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +Friends | +List of all members
+
mlx::core::metal::MetalAllocator Class Reference
+
+
+ +

#include <allocator.h>

+
+Inheritance diagram for mlx::core::metal::MetalAllocator:
+
+
+ + +mlx::core::allocator::Allocator + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

virtual Buffer malloc (size_t size, bool allow_swap=false) override
 Allocator for Metal GPUs.
 
virtual void free (Buffer buffer) override
 
size_t get_active_memory ()
 
size_t get_peak_memory ()
 
void reset_peak_memory ()
 
size_t get_cache_memory ()
 
size_t set_cache_limit (size_t limit)
 
size_t set_memory_limit (size_t limit, bool relaxed)
 
void clear_cache ()
 
- Public Member Functions inherited from mlx::core::allocator::Allocator
 Allocator ()=default
 
 Allocator (const Allocator &other)=delete
 
 Allocator (Allocator &&other)=delete
 
Allocatoroperator= (const Allocator &other)=delete
 
Allocatoroperator= (Allocator &&other)=delete
 
virtual ~Allocator ()=default
 
+ + + +

+Friends

MetalAllocatorallocator ()
 
+

Member Function Documentation

+ +

◆ clear_cache()

+ +
+
+ + + + + + + +
void mlx::core::metal::MetalAllocator::clear_cache ()
+
+ +
+
+ +

◆ free()

+ +
+
+ + + + + +
+ + + + + + + +
virtual void mlx::core::metal::MetalAllocator::free (Buffer buffer)
+
+overridevirtual
+
+
+ +

◆ get_active_memory()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::metal::MetalAllocator::get_active_memory ()
+
+inline
+
+ +
+
+ +

◆ get_cache_memory()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::metal::MetalAllocator::get_cache_memory ()
+
+inline
+
+ +
+
+ +

◆ get_peak_memory()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::metal::MetalAllocator::get_peak_memory ()
+
+inline
+
+ +
+
+ +

◆ malloc()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual Buffer mlx::core::metal::MetalAllocator::malloc (size_t size,
bool allow_swap = false )
+
+overridevirtual
+
+ +

Allocator for Metal GPUs.

+ +

Implements mlx::core::allocator::Allocator.

+ +
+
+ +

◆ reset_peak_memory()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::metal::MetalAllocator::reset_peak_memory ()
+
+inline
+
+ +
+
+ +

◆ set_cache_limit()

+ +
+
+ + + + + + + +
size_t mlx::core::metal::MetalAllocator::set_cache_limit (size_t limit)
+
+ +
+
+ +

◆ set_memory_limit()

+ +
+
+ + + + + + + + + + + +
size_t mlx::core::metal::MetalAllocator::set_memory_limit (size_t limit,
bool relaxed )
+
+ +
+
+

Friends And Related Symbol Documentation

+ +

◆ allocator

+ +
+
+ + + + + +
+ + + + + + + +
MetalAllocator & allocator ()
+
+friend
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator.png b/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator.png new file mode 100644 index 000000000..c82190d62 Binary files /dev/null and b/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator.png differ diff --git a/docs/build/html/classmlx_1_1core_1_1random_1_1_key_sequence-members.html b/docs/build/html/classmlx_1_1core_1_1random_1_1_key_sequence-members.html new file mode 100644 index 000000000..a3e82b9eb --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1random_1_1_key_sequence-members.html @@ -0,0 +1,94 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::random::KeySequence Member List
+
+
+ +

This is the complete list of members for mlx::core::random::KeySequence, including all inherited members.

+ + + + + +
default_()mlx::core::random::KeySequenceinlinestatic
KeySequence(uint64_t seed)mlx::core::random::KeySequenceexplicit
next()mlx::core::random::KeySequence
seed(uint64_t seed)mlx::core::random::KeySequence
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1random_1_1_key_sequence.html b/docs/build/html/classmlx_1_1core_1_1random_1_1_key_sequence.html new file mode 100644 index 000000000..4662bb44f --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1random_1_1_key_sequence.html @@ -0,0 +1,197 @@ + + + + + + + +MLX: mlx::core::random::KeySequence Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +Static Public Member Functions | +List of all members
+
mlx::core::random::KeySequence Class Reference
+
+
+ +

#include <random.h>

+ + + + + + + + +

+Public Member Functions

 KeySequence (uint64_t seed)
 
void seed (uint64_t seed)
 
array next ()
 
+ + + +

+Static Public Member Functions

static KeySequencedefault_ ()
 
+

Constructor & Destructor Documentation

+ +

◆ KeySequence()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::random::KeySequence::KeySequence (uint64_t seed)
+
+explicit
+
+ +
+
+

Member Function Documentation

+ +

◆ default_()

+ +
+
+ + + + + +
+ + + + + + + +
static KeySequence & mlx::core::random::KeySequence::default_ ()
+
+inlinestatic
+
+ +
+
+ +

◆ next()

+ +
+
+ + + + + + + +
array mlx::core::random::KeySequence::next ()
+
+ +
+
+ +

◆ seed()

+ +
+
+ + + + + + + +
void mlx::core::random::KeySequence::seed (uint64_t seed)
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1scheduler_1_1_scheduler-members.html b/docs/build/html/classmlx_1_1core_1_1scheduler_1_1_scheduler-members.html new file mode 100644 index 000000000..7a145e396 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1scheduler_1_1_scheduler-members.html @@ -0,0 +1,104 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::scheduler::Scheduler Member List
+
+
+ +

This is the complete list of members for mlx::core::scheduler::Scheduler, including all inherited members.

+ + + + + + + + + + + + + + + +
enqueue(const Stream &stream, F &&f)mlx::core::scheduler::Scheduler
get_default_stream(const Device &d)mlx::core::scheduler::Schedulerinline
n_active_tasks() constmlx::core::scheduler::Schedulerinline
new_stream(const Device &d)mlx::core::scheduler::Schedulerinline
notify_new_task(const Stream &stream)mlx::core::scheduler::Schedulerinline
notify_task_completion(const Stream &stream)mlx::core::scheduler::Schedulerinline
operator=(const Scheduler &)=deletemlx::core::scheduler::Scheduler
operator=(Scheduler &&)=deletemlx::core::scheduler::Scheduler
Scheduler()mlx::core::scheduler::Schedulerinline
Scheduler(const Scheduler &)=deletemlx::core::scheduler::Scheduler
Scheduler(Scheduler &&)=deletemlx::core::scheduler::Scheduler
set_default_stream(const Stream &s)mlx::core::scheduler::Schedulerinline
wait_for_one()mlx::core::scheduler::Schedulerinline
~Scheduler()mlx::core::scheduler::Schedulerinline
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1scheduler_1_1_scheduler.html b/docs/build/html/classmlx_1_1core_1_1scheduler_1_1_scheduler.html new file mode 100644 index 000000000..b9ec8053b --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1scheduler_1_1_scheduler.html @@ -0,0 +1,478 @@ + + + + + + + +MLX: mlx::core::scheduler::Scheduler Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
mlx::core::scheduler::Scheduler Class Reference
+
+
+ +

#include <scheduler.h>

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Scheduler ()
 
 Scheduler (const Scheduler &)=delete
 
 Scheduler (Scheduler &&)=delete
 
Scheduleroperator= (const Scheduler &)=delete
 
Scheduleroperator= (Scheduler &&)=delete
 
Stream new_stream (const Device &d)
 
template<typename F >
void enqueue (const Stream &stream, F &&f)
 
Stream get_default_stream (const Device &d)
 
void set_default_stream (const Stream &s)
 
void notify_new_task (const Stream &stream)
 
void notify_task_completion (const Stream &stream)
 
int n_active_tasks () const
 
void wait_for_one ()
 
 ~Scheduler ()
 
+

Constructor & Destructor Documentation

+ +

◆ Scheduler() [1/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::scheduler::Scheduler::Scheduler ()
+
+inline
+
+ +
+
+ +

◆ Scheduler() [2/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::scheduler::Scheduler::Scheduler (const Scheduler & )
+
+delete
+
+ +
+
+ +

◆ Scheduler() [3/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::scheduler::Scheduler::Scheduler (Scheduler && )
+
+delete
+
+ +
+
+ +

◆ ~Scheduler()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::scheduler::Scheduler::~Scheduler ()
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ enqueue()

+ +
+
+
+template<typename F >
+ + + + + + + + + + + +
void mlx::core::scheduler::Scheduler::enqueue (const Stream & stream,
F && f )
+
+ +
+
+ +

◆ get_default_stream()

+ +
+
+ + + + + +
+ + + + + + + +
Stream mlx::core::scheduler::Scheduler::get_default_stream (const Device & d)
+
+inline
+
+ +
+
+ +

◆ n_active_tasks()

+ +
+
+ + + + + +
+ + + + + + + +
int mlx::core::scheduler::Scheduler::n_active_tasks () const
+
+inline
+
+ +
+
+ +

◆ new_stream()

+ +
+
+ + + + + +
+ + + + + + + +
Stream mlx::core::scheduler::Scheduler::new_stream (const Device & d)
+
+inline
+
+ +
+
+ +

◆ notify_new_task()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::scheduler::Scheduler::notify_new_task (const Stream & stream)
+
+inline
+
+ +
+
+ +

◆ notify_task_completion()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::scheduler::Scheduler::notify_task_completion (const Stream & stream)
+
+inline
+
+ +
+
+ +

◆ operator=() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
Scheduler & mlx::core::scheduler::Scheduler::operator= (const Scheduler & )
+
+delete
+
+ +
+
+ +

◆ operator=() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
Scheduler & mlx::core::scheduler::Scheduler::operator= (Scheduler && )
+
+delete
+
+ +
+
+ +

◆ set_default_stream()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::scheduler::Scheduler::set_default_stream (const Stream & s)
+
+inline
+
+ +
+
+ +

◆ wait_for_one()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::scheduler::Scheduler::wait_for_one ()
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst23-members.html b/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst23-members.html new file mode 100644 index 000000000..91ebfde05 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst23-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::T_dcst23< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::T_dcst23< T0 >, including all inherited members.

+ + + + +
exec(T c[], T0 fct, bool ortho, int type, bool cosine) constpocketfft::detail::T_dcst23< T0 >inline
length() constpocketfft::detail::T_dcst23< T0 >inline
T_dcst23(size_t length)pocketfft::detail::T_dcst23< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst23.html b/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst23.html new file mode 100644 index 000000000..c7c55e2aa --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst23.html @@ -0,0 +1,210 @@ + + + + + + + +MLX: pocketfft::detail::T_dcst23< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::T_dcst23< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + +

+Public Member Functions

 T_dcst23 (size_t length)
 
template<typename T >
void exec (T c[], T0 fct, bool ortho, int type, bool cosine) const
 
size_t length () const
 
+

Constructor & Destructor Documentation

+ +

◆ T_dcst23()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::T_dcst23< T0 >::T_dcst23 (size_t length)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
void pocketfft::detail::T_dcst23< T0 >::exec (T c[],
T0 fct,
bool ortho,
int type,
bool cosine ) const
+
+inline
+
+ +
+
+ +

◆ length()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::T_dcst23< T0 >::length () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst4-members.html b/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst4-members.html new file mode 100644 index 000000000..3df2ebce9 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst4-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::T_dcst4< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::T_dcst4< T0 >, including all inherited members.

+ + + + +
exec(T c[], T0 fct, bool, int, bool cosine) constpocketfft::detail::T_dcst4< T0 >inline
length() constpocketfft::detail::T_dcst4< T0 >inline
T_dcst4(size_t length)pocketfft::detail::T_dcst4< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst4.html b/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst4.html new file mode 100644 index 000000000..8d03410d4 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst4.html @@ -0,0 +1,210 @@ + + + + + + + +MLX: pocketfft::detail::T_dcst4< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::T_dcst4< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + +

+Public Member Functions

 T_dcst4 (size_t length)
 
template<typename T >
void exec (T c[], T0 fct, bool, int, bool cosine) const
 
size_t length () const
 
+

Constructor & Destructor Documentation

+ +

◆ T_dcst4()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::T_dcst4< T0 >::T_dcst4 (size_t length)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
void pocketfft::detail::T_dcst4< T0 >::exec (T c[],
T0 fct,
bool ,
int ,
bool cosine ) const
+
+inline
+
+ +
+
+ +

◆ length()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::T_dcst4< T0 >::length () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1_t__dct1-members.html b/docs/build/html/classpocketfft_1_1detail_1_1_t__dct1-members.html new file mode 100644 index 000000000..306b02f42 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1_t__dct1-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::T_dct1< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::T_dct1< T0 >, including all inherited members.

+ + + + +
exec(T c[], T0 fct, bool ortho, int, bool) constpocketfft::detail::T_dct1< T0 >inline
length() constpocketfft::detail::T_dct1< T0 >inline
T_dct1(size_t length)pocketfft::detail::T_dct1< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1_t__dct1.html b/docs/build/html/classpocketfft_1_1detail_1_1_t__dct1.html new file mode 100644 index 000000000..ad8c292fc --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1_t__dct1.html @@ -0,0 +1,210 @@ + + + + + + + +MLX: pocketfft::detail::T_dct1< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::T_dct1< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + +

+Public Member Functions

 T_dct1 (size_t length)
 
template<typename T >
void exec (T c[], T0 fct, bool ortho, int, bool) const
 
size_t length () const
 
+

Constructor & Destructor Documentation

+ +

◆ T_dct1()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::T_dct1< T0 >::T_dct1 (size_t length)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
void pocketfft::detail::T_dct1< T0 >::exec (T c[],
T0 fct,
bool ortho,
int ,
bool  ) const
+
+inline
+
+ +
+
+ +

◆ length()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::T_dct1< T0 >::length () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1_t__dst1-members.html b/docs/build/html/classpocketfft_1_1detail_1_1_t__dst1-members.html new file mode 100644 index 000000000..5f72dd266 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1_t__dst1-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::T_dst1< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::T_dst1< T0 >, including all inherited members.

+ + + + +
exec(T c[], T0 fct, bool, int, bool) constpocketfft::detail::T_dst1< T0 >inline
length() constpocketfft::detail::T_dst1< T0 >inline
T_dst1(size_t length)pocketfft::detail::T_dst1< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1_t__dst1.html b/docs/build/html/classpocketfft_1_1detail_1_1_t__dst1.html new file mode 100644 index 000000000..678dde855 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1_t__dst1.html @@ -0,0 +1,210 @@ + + + + + + + +MLX: pocketfft::detail::T_dst1< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::T_dst1< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + +

+Public Member Functions

 T_dst1 (size_t length)
 
template<typename T >
void exec (T c[], T0 fct, bool, int, bool) const
 
size_t length () const
 
+

Constructor & Destructor Documentation

+ +

◆ T_dst1()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::T_dst1< T0 >::T_dst1 (size_t length)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
void pocketfft::detail::T_dst1< T0 >::exec (T c[],
T0 fct,
bool ,
int ,
bool  ) const
+
+inline
+
+ +
+
+ +

◆ length()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::T_dst1< T0 >::length () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1arr-members.html b/docs/build/html/classpocketfft_1_1detail_1_1arr-members.html new file mode 100644 index 000000000..0a83715c7 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1arr-members.html @@ -0,0 +1,100 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::arr< T > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::arr< T >, including all inherited members.

+ + + + + + + + + + + +
arr()pocketfft::detail::arr< T >inline
arr(size_t n)pocketfft::detail::arr< T >inline
arr(arr &&other)pocketfft::detail::arr< T >inline
data()pocketfft::detail::arr< T >inline
data() constpocketfft::detail::arr< T >inline
operator[](size_t idx)pocketfft::detail::arr< T >inline
operator[](size_t idx) constpocketfft::detail::arr< T >inline
resize(size_t n)pocketfft::detail::arr< T >inline
size() constpocketfft::detail::arr< T >inline
~arr()pocketfft::detail::arr< T >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1arr.html b/docs/build/html/classpocketfft_1_1detail_1_1arr.html new file mode 100644 index 000000000..d81388ba2 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1arr.html @@ -0,0 +1,391 @@ + + + + + + + +MLX: pocketfft::detail::arr< T > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::arr< T > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 arr ()
 
 arr (size_t n)
 
 arr (arr &&other)
 
 ~arr ()
 
void resize (size_t n)
 
Toperator[] (size_t idx)
 
const Toperator[] (size_t idx) const
 
Tdata ()
 
const Tdata () const
 
size_t size () const
 
+

Constructor & Destructor Documentation

+ +

◆ arr() [1/3]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
pocketfft::detail::arr< T >::arr ()
+
+inline
+
+ +
+
+ +

◆ arr() [2/3]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
pocketfft::detail::arr< T >::arr (size_t n)
+
+inline
+
+ +
+
+ +

◆ arr() [3/3]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
pocketfft::detail::arr< T >::arr (arr< T > && other)
+
+inline
+
+ +
+
+ +

◆ ~arr()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
pocketfft::detail::arr< T >::~arr ()
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ data() [1/2]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
T * pocketfft::detail::arr< T >::data ()
+
+inline
+
+ +
+
+ +

◆ data() [2/2]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
const T * pocketfft::detail::arr< T >::data () const
+
+inline
+
+ +
+
+ +

◆ operator[]() [1/2]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
T & pocketfft::detail::arr< T >::operator[] (size_t idx)
+
+inline
+
+ +
+
+ +

◆ operator[]() [2/2]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
const T & pocketfft::detail::arr< T >::operator[] (size_t idx) const
+
+inline
+
+ +
+
+ +

◆ resize()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
void pocketfft::detail::arr< T >::resize (size_t n)
+
+inline
+
+ +
+
+ +

◆ size()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::arr< T >::size () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1arr__info-members.html b/docs/build/html/classpocketfft_1_1detail_1_1arr__info-members.html new file mode 100644 index 000000000..38eb5333c --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1arr__info-members.html @@ -0,0 +1,99 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::arr_info Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::arr_info, including all inherited members.

+ + + + + + + + + + +
arr_info(const shape_t &shape_, const stride_t &stride_)pocketfft::detail::arr_infoinline
ndim() constpocketfft::detail::arr_infoinline
shape() constpocketfft::detail::arr_infoinline
shape(size_t i) constpocketfft::detail::arr_infoinline
shppocketfft::detail::arr_infoprotected
size() constpocketfft::detail::arr_infoinline
strpocketfft::detail::arr_infoprotected
stride() constpocketfft::detail::arr_infoinline
stride(size_t i) constpocketfft::detail::arr_infoinline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1arr__info.html b/docs/build/html/classpocketfft_1_1detail_1_1arr__info.html new file mode 100644 index 000000000..08df02140 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1arr__info.html @@ -0,0 +1,357 @@ + + + + + + + +MLX: pocketfft::detail::arr_info Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +Protected Attributes | +List of all members
+
pocketfft::detail::arr_info Class Reference
+
+
+ +

#include <pocketfft.h>

+
+Inheritance diagram for pocketfft::detail::arr_info:
+
+
+ + +pocketfft::detail::cndarr< T > +pocketfft::detail::ndarr< T > + +
+ + + + + + + + + + + + + + + + +

+Public Member Functions

 arr_info (const shape_t &shape_, const stride_t &stride_)
 
size_t ndim () const
 
size_t size () const
 
const shape_tshape () const
 
size_t shape (size_t i) const
 
const stride_tstride () const
 
const ptrdiff_t & stride (size_t i) const
 
+ + + + + +

+Protected Attributes

shape_t shp
 
stride_t str
 
+

Constructor & Destructor Documentation

+ +

◆ arr_info()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
pocketfft::detail::arr_info::arr_info (const shape_t & shape_,
const stride_t & stride_ )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ ndim()

+ +
+
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::arr_info::ndim () const
+
+inline
+
+ +
+
+ +

◆ shape() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
const shape_t & pocketfft::detail::arr_info::shape () const
+
+inline
+
+ +
+
+ +

◆ shape() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::arr_info::shape (size_t i) const
+
+inline
+
+ +
+
+ +

◆ size()

+ +
+
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::arr_info::size () const
+
+inline
+
+ +
+
+ +

◆ stride() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
const stride_t & pocketfft::detail::arr_info::stride () const
+
+inline
+
+ +
+
+ +

◆ stride() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
const ptrdiff_t & pocketfft::detail::arr_info::stride (size_t i) const
+
+inline
+
+ +
+
+

Member Data Documentation

+ +

◆ shp

+ +
+
+ + + + + +
+ + + + +
shape_t pocketfft::detail::arr_info::shp
+
+protected
+
+ +
+
+ +

◆ str

+ +
+
+ + + + + +
+ + + + +
stride_t pocketfft::detail::arr_info::str
+
+protected
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1arr__info.png b/docs/build/html/classpocketfft_1_1detail_1_1arr__info.png new file mode 100644 index 000000000..8cf1d3ced Binary files /dev/null and b/docs/build/html/classpocketfft_1_1detail_1_1arr__info.png differ diff --git a/docs/build/html/classpocketfft_1_1detail_1_1cfftp-members.html b/docs/build/html/classpocketfft_1_1detail_1_1cfftp-members.html new file mode 100644 index 000000000..65676b3eb --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1cfftp-members.html @@ -0,0 +1,92 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::cfftp< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::cfftp< T0 >, including all inherited members.

+ + + +
cfftp(size_t length_)pocketfft::detail::cfftp< T0 >inline
exec(T c[], T0 fct, bool fwd) constpocketfft::detail::cfftp< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1cfftp.html b/docs/build/html/classpocketfft_1_1detail_1_1cfftp.html new file mode 100644 index 000000000..77e4017fe --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1cfftp.html @@ -0,0 +1,171 @@ + + + + + + + +MLX: pocketfft::detail::cfftp< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::cfftp< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + +

+Public Member Functions

template<typename T >
void exec (T c[], T0 fct, bool fwd) const
 
 cfftp (size_t length_)
 
+

Constructor & Destructor Documentation

+ +

◆ cfftp()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::cfftp< T0 >::cfftp (size_t length_)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + +
void pocketfft::detail::cfftp< T0 >::exec (T c[],
T0 fct,
bool fwd ) const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1cndarr-members.html b/docs/build/html/classpocketfft_1_1detail_1_1cndarr-members.html new file mode 100644 index 000000000..1d2ac7085 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1cndarr-members.html @@ -0,0 +1,102 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::cndarr< T > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::cndarr< T >, including all inherited members.

+ + + + + + + + + + + + + +
arr_info(const shape_t &shape_, const stride_t &stride_)pocketfft::detail::arr_infoinline
cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_)pocketfft::detail::cndarr< T >inline
dpocketfft::detail::cndarr< T >protected
ndim() constpocketfft::detail::arr_infoinline
operator[](ptrdiff_t ofs) constpocketfft::detail::cndarr< T >inline
shape() constpocketfft::detail::arr_infoinline
shape(size_t i) constpocketfft::detail::arr_infoinline
shppocketfft::detail::arr_infoprotected
size() constpocketfft::detail::arr_infoinline
strpocketfft::detail::arr_infoprotected
stride() constpocketfft::detail::arr_infoinline
stride(size_t i) constpocketfft::detail::arr_infoinline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1cndarr.html b/docs/build/html/classpocketfft_1_1detail_1_1cndarr.html new file mode 100644 index 000000000..742ba67f6 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1cndarr.html @@ -0,0 +1,229 @@ + + + + + + + +MLX: pocketfft::detail::cndarr< T > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +Protected Attributes | +List of all members
+
pocketfft::detail::cndarr< T > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+
+Inheritance diagram for pocketfft::detail::cndarr< T >:
+
+
+ + +pocketfft::detail::arr_info +pocketfft::detail::ndarr< T > + +
+ + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 cndarr (const void *data_, const shape_t &shape_, const stride_t &stride_)
 
const Toperator[] (ptrdiff_t ofs) const
 
- Public Member Functions inherited from pocketfft::detail::arr_info
 arr_info (const shape_t &shape_, const stride_t &stride_)
 
size_t ndim () const
 
size_t size () const
 
const shape_tshape () const
 
size_t shape (size_t i) const
 
const stride_tstride () const
 
const ptrdiff_t & stride (size_t i) const
 
+ + + + + + + + +

+Protected Attributes

const chard
 
- Protected Attributes inherited from pocketfft::detail::arr_info
shape_t shp
 
stride_t str
 
+

Constructor & Destructor Documentation

+ +

◆ cndarr()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + +
pocketfft::detail::cndarr< T >::cndarr (const void * data_,
const shape_t & shape_,
const stride_t & stride_ )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ operator[]()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
const T & pocketfft::detail::cndarr< T >::operator[] (ptrdiff_t ofs) const
+
+inline
+
+ +
+
+

Member Data Documentation

+ +

◆ d

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + +
const char* pocketfft::detail::cndarr< T >::d
+
+protected
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1cndarr.png b/docs/build/html/classpocketfft_1_1detail_1_1cndarr.png new file mode 100644 index 000000000..268d77caf Binary files /dev/null and b/docs/build/html/classpocketfft_1_1detail_1_1cndarr.png differ diff --git a/docs/build/html/classpocketfft_1_1detail_1_1fftblue-members.html b/docs/build/html/classpocketfft_1_1detail_1_1fftblue-members.html new file mode 100644 index 000000000..ba6b7bcdd --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1fftblue-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::fftblue< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::fftblue< T0 >, including all inherited members.

+ + + + +
exec(cmplx< T > c[], T0 fct, bool fwd) constpocketfft::detail::fftblue< T0 >inline
exec_r(T c[], T0 fct, bool fwd)pocketfft::detail::fftblue< T0 >inline
fftblue(size_t length)pocketfft::detail::fftblue< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1fftblue.html b/docs/build/html/classpocketfft_1_1detail_1_1fftblue.html new file mode 100644 index 000000000..519a05ae5 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1fftblue.html @@ -0,0 +1,212 @@ + + + + + + + +MLX: pocketfft::detail::fftblue< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::fftblue< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + + +

+Public Member Functions

 fftblue (size_t length)
 
template<typename T >
void exec (cmplx< T > c[], T0 fct, bool fwd) const
 
template<typename T >
void exec_r (T c[], T0 fct, bool fwd)
 
+

Constructor & Destructor Documentation

+ +

◆ fftblue()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::fftblue< T0 >::fftblue (size_t length)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + +
void pocketfft::detail::fftblue< T0 >::exec (cmplx< T > c[],
T0 fct,
bool fwd ) const
+
+inline
+
+ +
+
+ +

◆ exec_r()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + +
void pocketfft::detail::fftblue< T0 >::exec_r (T c[],
T0 fct,
bool fwd )
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1multi__iter-members.html b/docs/build/html/classpocketfft_1_1detail_1_1multi__iter-members.html new file mode 100644 index 000000000..fc671bdf6 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1multi__iter-members.html @@ -0,0 +1,101 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::multi_iter< N > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::multi_iter< N >, including all inherited members.

+ + + + + + + + + + + + +
advance(size_t n)pocketfft::detail::multi_iter< N >inline
iofs(size_t i) constpocketfft::detail::multi_iter< N >inline
iofs(size_t j, size_t i) constpocketfft::detail::multi_iter< N >inline
length_in() constpocketfft::detail::multi_iter< N >inline
length_out() constpocketfft::detail::multi_iter< N >inline
multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_)pocketfft::detail::multi_iter< N >inline
oofs(size_t i) constpocketfft::detail::multi_iter< N >inline
oofs(size_t j, size_t i) constpocketfft::detail::multi_iter< N >inline
remaining() constpocketfft::detail::multi_iter< N >inline
stride_in() constpocketfft::detail::multi_iter< N >inline
stride_out() constpocketfft::detail::multi_iter< N >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1multi__iter.html b/docs/build/html/classpocketfft_1_1detail_1_1multi__iter.html new file mode 100644 index 000000000..1558f0d92 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1multi__iter.html @@ -0,0 +1,437 @@ + + + + + + + +MLX: pocketfft::detail::multi_iter< N > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::multi_iter< N > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 multi_iter (const arr_info &iarr_, const arr_info &oarr_, size_t idim_)
 
void advance (size_t n)
 
ptrdiff_t iofs (size_t i) const
 
ptrdiff_t iofs (size_t j, size_t i) const
 
ptrdiff_t oofs (size_t i) const
 
ptrdiff_t oofs (size_t j, size_t i) const
 
size_t length_in () const
 
size_t length_out () const
 
ptrdiff_t stride_in () const
 
ptrdiff_t stride_out () const
 
size_t remaining () const
 
+

Constructor & Destructor Documentation

+ +

◆ multi_iter()

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + + + + + + + + + + +
pocketfft::detail::multi_iter< N >::multi_iter (const arr_info & iarr_,
const arr_info & oarr_,
size_t idim_ )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ advance()

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + +
void pocketfft::detail::multi_iter< N >::advance (size_t n)
+
+inline
+
+ +
+
+ +

◆ iofs() [1/2]

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + +
ptrdiff_t pocketfft::detail::multi_iter< N >::iofs (size_t i) const
+
+inline
+
+ +
+
+ +

◆ iofs() [2/2]

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + + + + + +
ptrdiff_t pocketfft::detail::multi_iter< N >::iofs (size_t j,
size_t i ) const
+
+inline
+
+ +
+
+ +

◆ length_in()

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::multi_iter< N >::length_in () const
+
+inline
+
+ +
+
+ +

◆ length_out()

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::multi_iter< N >::length_out () const
+
+inline
+
+ +
+
+ +

◆ oofs() [1/2]

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + +
ptrdiff_t pocketfft::detail::multi_iter< N >::oofs (size_t i) const
+
+inline
+
+ +
+
+ +

◆ oofs() [2/2]

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + + + + + +
ptrdiff_t pocketfft::detail::multi_iter< N >::oofs (size_t j,
size_t i ) const
+
+inline
+
+ +
+
+ +

◆ remaining()

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::multi_iter< N >::remaining () const
+
+inline
+
+ +
+
+ +

◆ stride_in()

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + +
ptrdiff_t pocketfft::detail::multi_iter< N >::stride_in () const
+
+inline
+
+ +
+
+ +

◆ stride_out()

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + +
ptrdiff_t pocketfft::detail::multi_iter< N >::stride_out () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1ndarr-members.html b/docs/build/html/classpocketfft_1_1detail_1_1ndarr-members.html new file mode 100644 index 000000000..695ada829 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1ndarr-members.html @@ -0,0 +1,104 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::ndarr< T > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::ndarr< T >, including all inherited members.

+ + + + + + + + + + + + + + + +
arr_info(const shape_t &shape_, const stride_t &stride_)pocketfft::detail::arr_infoinline
cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_)pocketfft::detail::cndarr< T >inline
dpocketfft::detail::cndarr< T >protected
ndarr(void *data_, const shape_t &shape_, const stride_t &stride_)pocketfft::detail::ndarr< T >inline
ndim() constpocketfft::detail::arr_infoinline
operator[](ptrdiff_t ofs)pocketfft::detail::ndarr< T >inline
pocketfft::detail::cndarr::operator[](ptrdiff_t ofs) constpocketfft::detail::cndarr< T >inline
shape() constpocketfft::detail::arr_infoinline
shape(size_t i) constpocketfft::detail::arr_infoinline
shppocketfft::detail::arr_infoprotected
size() constpocketfft::detail::arr_infoinline
strpocketfft::detail::arr_infoprotected
stride() constpocketfft::detail::arr_infoinline
stride(size_t i) constpocketfft::detail::arr_infoinline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1ndarr.html b/docs/build/html/classpocketfft_1_1detail_1_1ndarr.html new file mode 100644 index 000000000..515d2c773 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1ndarr.html @@ -0,0 +1,209 @@ + + + + + + + +MLX: pocketfft::detail::ndarr< T > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::ndarr< T > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+
+Inheritance diagram for pocketfft::detail::ndarr< T >:
+
+
+ + +pocketfft::detail::cndarr< T > +pocketfft::detail::arr_info + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ndarr (void *data_, const shape_t &shape_, const stride_t &stride_)
 
Toperator[] (ptrdiff_t ofs)
 
- Public Member Functions inherited from pocketfft::detail::cndarr< T >
 cndarr (const void *data_, const shape_t &shape_, const stride_t &stride_)
 
const Toperator[] (ptrdiff_t ofs) const
 
- Public Member Functions inherited from pocketfft::detail::arr_info
 arr_info (const shape_t &shape_, const stride_t &stride_)
 
size_t ndim () const
 
size_t size () const
 
const shape_tshape () const
 
size_t shape (size_t i) const
 
const stride_tstride () const
 
const ptrdiff_t & stride (size_t i) const
 
+ + + + + + + + + +

+Additional Inherited Members

- Protected Attributes inherited from pocketfft::detail::cndarr< T >
const chard
 
- Protected Attributes inherited from pocketfft::detail::arr_info
shape_t shp
 
stride_t str
 
+

Constructor & Destructor Documentation

+ +

◆ ndarr()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + +
pocketfft::detail::ndarr< T >::ndarr (void * data_,
const shape_t & shape_,
const stride_t & stride_ )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ operator[]()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
T & pocketfft::detail::ndarr< T >::operator[] (ptrdiff_t ofs)
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1ndarr.png b/docs/build/html/classpocketfft_1_1detail_1_1ndarr.png new file mode 100644 index 000000000..96f688ccd Binary files /dev/null and b/docs/build/html/classpocketfft_1_1detail_1_1ndarr.png differ diff --git a/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__c-members.html b/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__c-members.html new file mode 100644 index 000000000..cd88f36a7 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__c-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::pocketfft_c< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::pocketfft_c< T0 >, including all inherited members.

+ + + + +
exec(cmplx< T > c[], T0 fct, bool fwd) constpocketfft::detail::pocketfft_c< T0 >inline
length() constpocketfft::detail::pocketfft_c< T0 >inline
pocketfft_c(size_t length)pocketfft::detail::pocketfft_c< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__c.html b/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__c.html new file mode 100644 index 000000000..8eac32ad6 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__c.html @@ -0,0 +1,200 @@ + + + + + + + +MLX: pocketfft::detail::pocketfft_c< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::pocketfft_c< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + +

+Public Member Functions

 pocketfft_c (size_t length)
 
template<typename T >
void exec (cmplx< T > c[], T0 fct, bool fwd) const
 
size_t length () const
 
+

Constructor & Destructor Documentation

+ +

◆ pocketfft_c()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::pocketfft_c< T0 >::pocketfft_c (size_t length)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + +
void pocketfft::detail::pocketfft_c< T0 >::exec (cmplx< T > c[],
T0 fct,
bool fwd ) const
+
+inline
+
+ +
+
+ +

◆ length()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::pocketfft_c< T0 >::length () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__r-members.html b/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__r-members.html new file mode 100644 index 000000000..14d5e1b2f --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__r-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::pocketfft_r< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::pocketfft_r< T0 >, including all inherited members.

+ + + + +
exec(T c[], T0 fct, bool fwd) constpocketfft::detail::pocketfft_r< T0 >inline
length() constpocketfft::detail::pocketfft_r< T0 >inline
pocketfft_r(size_t length)pocketfft::detail::pocketfft_r< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__r.html b/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__r.html new file mode 100644 index 000000000..7cf6d4e6d --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__r.html @@ -0,0 +1,200 @@ + + + + + + + +MLX: pocketfft::detail::pocketfft_r< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::pocketfft_r< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + +

+Public Member Functions

 pocketfft_r (size_t length)
 
template<typename T >
void exec (T c[], T0 fct, bool fwd) const
 
size_t length () const
 
+

Constructor & Destructor Documentation

+ +

◆ pocketfft_r()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::pocketfft_r< T0 >::pocketfft_r (size_t length)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + +
void pocketfft::detail::pocketfft_r< T0 >::exec (T c[],
T0 fct,
bool fwd ) const
+
+inline
+
+ +
+
+ +

◆ length()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::pocketfft_r< T0 >::length () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1rev__iter-members.html b/docs/build/html/classpocketfft_1_1detail_1_1rev__iter-members.html new file mode 100644 index 000000000..39dfa4688 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1rev__iter-members.html @@ -0,0 +1,95 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::rev_iter Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::rev_iter, including all inherited members.

+ + + + + + +
advance()pocketfft::detail::rev_iterinline
ofs() constpocketfft::detail::rev_iterinline
remaining() constpocketfft::detail::rev_iterinline
rev_iter(const arr_info &arr_, const shape_t &axes)pocketfft::detail::rev_iterinline
rev_ofs() constpocketfft::detail::rev_iterinline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1rev__iter.html b/docs/build/html/classpocketfft_1_1detail_1_1rev__iter.html new file mode 100644 index 000000000..3775ffc19 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1rev__iter.html @@ -0,0 +1,240 @@ + + + + + + + +MLX: pocketfft::detail::rev_iter Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::rev_iter Class Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + + + + +

+Public Member Functions

 rev_iter (const arr_info &arr_, const shape_t &axes)
 
void advance ()
 
ptrdiff_t ofs () const
 
ptrdiff_t rev_ofs () const
 
size_t remaining () const
 
+

Constructor & Destructor Documentation

+ +

◆ rev_iter()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
pocketfft::detail::rev_iter::rev_iter (const arr_info & arr_,
const shape_t & axes )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ advance()

+ +
+
+ + + + + +
+ + + + + + + +
void pocketfft::detail::rev_iter::advance ()
+
+inline
+
+ +
+
+ +

◆ ofs()

+ +
+
+ + + + + +
+ + + + + + + +
ptrdiff_t pocketfft::detail::rev_iter::ofs () const
+
+inline
+
+ +
+
+ +

◆ remaining()

+ +
+
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::rev_iter::remaining () const
+
+inline
+
+ +
+
+ +

◆ rev_ofs()

+ +
+
+ + + + + +
+ + + + + + + +
ptrdiff_t pocketfft::detail::rev_iter::rev_ofs () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1rfftp-members.html b/docs/build/html/classpocketfft_1_1detail_1_1rfftp-members.html new file mode 100644 index 000000000..778f37aa0 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1rfftp-members.html @@ -0,0 +1,92 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::rfftp< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::rfftp< T0 >, including all inherited members.

+ + + +
exec(T c[], T0 fct, bool r2hc) constpocketfft::detail::rfftp< T0 >inline
rfftp(size_t length_)pocketfft::detail::rfftp< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1rfftp.html b/docs/build/html/classpocketfft_1_1detail_1_1rfftp.html new file mode 100644 index 000000000..deddc38d0 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1rfftp.html @@ -0,0 +1,171 @@ + + + + + + + +MLX: pocketfft::detail::rfftp< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::rfftp< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + +

+Public Member Functions

template<typename T >
void exec (T c[], T0 fct, bool r2hc) const
 
 rfftp (size_t length_)
 
+

Constructor & Destructor Documentation

+ +

◆ rfftp()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::rfftp< T0 >::rfftp (size_t length_)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + +
void pocketfft::detail::rfftp< T0 >::exec (T c[],
T0 fct,
bool r2hc ) const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1simple__iter-members.html b/docs/build/html/classpocketfft_1_1detail_1_1simple__iter-members.html new file mode 100644 index 000000000..0bd71a255 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1simple__iter-members.html @@ -0,0 +1,94 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::simple_iter Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::simple_iter, including all inherited members.

+ + + + + +
advance()pocketfft::detail::simple_iterinline
ofs() constpocketfft::detail::simple_iterinline
remaining() constpocketfft::detail::simple_iterinline
simple_iter(const arr_info &arr_)pocketfft::detail::simple_iterinline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1simple__iter.html b/docs/build/html/classpocketfft_1_1detail_1_1simple__iter.html new file mode 100644 index 000000000..973bde1d8 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1simple__iter.html @@ -0,0 +1,209 @@ + + + + + + + +MLX: pocketfft::detail::simple_iter Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::simple_iter Class Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + + +

+Public Member Functions

 simple_iter (const arr_info &arr_)
 
void advance ()
 
ptrdiff_t ofs () const
 
size_t remaining () const
 
+

Constructor & Destructor Documentation

+ +

◆ simple_iter()

+ +
+
+ + + + + +
+ + + + + + + +
pocketfft::detail::simple_iter::simple_iter (const arr_info & arr_)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ advance()

+ +
+
+ + + + + +
+ + + + + + + +
void pocketfft::detail::simple_iter::advance ()
+
+inline
+
+ +
+
+ +

◆ ofs()

+ +
+
+ + + + + +
+ + + + + + + +
ptrdiff_t pocketfft::detail::simple_iter::ofs () const
+
+inline
+
+ +
+
+ +

◆ remaining()

+ +
+
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::simple_iter::remaining () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1sincos__2pibyn-members.html b/docs/build/html/classpocketfft_1_1detail_1_1sincos__2pibyn-members.html new file mode 100644 index 000000000..8220ef462 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1sincos__2pibyn-members.html @@ -0,0 +1,92 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::sincos_2pibyn< T > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::sincos_2pibyn< T >, including all inherited members.

+ + + +
operator[](size_t idx) constpocketfft::detail::sincos_2pibyn< T >inline
sincos_2pibyn(size_t n)pocketfft::detail::sincos_2pibyn< T >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1sincos__2pibyn.html b/docs/build/html/classpocketfft_1_1detail_1_1sincos__2pibyn.html new file mode 100644 index 000000000..c39332681 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1sincos__2pibyn.html @@ -0,0 +1,159 @@ + + + + + + + +MLX: pocketfft::detail::sincos_2pibyn< T > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::sincos_2pibyn< T > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + +

+Public Member Functions

 sincos_2pibyn (size_t n)
 
cmplx< Toperator[] (size_t idx) const
 
+

Constructor & Destructor Documentation

+ +

◆ sincos_2pibyn()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
pocketfft::detail::sincos_2pibyn< T >::sincos_2pibyn (size_t n)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ operator[]()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
cmplx< T > pocketfft::detail::sincos_2pibyn< T >::operator[] (size_t idx) const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1concurrent__queue-members.html b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1concurrent__queue-members.html new file mode 100644 index 000000000..119b98e56 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1concurrent__queue-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::threading::concurrent_queue< T > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::threading::concurrent_queue< T >, including all inherited members.

+ + + + +
empty() constpocketfft::detail::threading::concurrent_queue< T >inline
push(T val)pocketfft::detail::threading::concurrent_queue< T >inline
try_pop(T &val)pocketfft::detail::threading::concurrent_queue< T >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1concurrent__queue.html b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1concurrent__queue.html new file mode 100644 index 000000000..624dd5302 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1concurrent__queue.html @@ -0,0 +1,187 @@ + + + + + + + +MLX: pocketfft::detail::threading::concurrent_queue< T > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::threading::concurrent_queue< T > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + +

+Public Member Functions

void push (T val)
 
bool try_pop (T &val)
 
bool empty () const
 
+

Member Function Documentation

+ +

◆ empty()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
bool pocketfft::detail::threading::concurrent_queue< T >::empty () const
+
+inline
+
+ +
+
+ +

◆ push()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
void pocketfft::detail::threading::concurrent_queue< T >::push (T val)
+
+inline
+
+ +
+
+ +

◆ try_pop()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
bool pocketfft::detail::threading::concurrent_queue< T >::try_pop (T & val)
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1latch-members.html b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1latch-members.html new file mode 100644 index 000000000..1a1651f32 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1latch-members.html @@ -0,0 +1,94 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::threading::latch Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::threading::latch, including all inherited members.

+ + + + + +
count_down()pocketfft::detail::threading::latchinline
is_ready()pocketfft::detail::threading::latchinline
latch(size_t n)pocketfft::detail::threading::latchinline
wait()pocketfft::detail::threading::latchinline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1latch.html b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1latch.html new file mode 100644 index 000000000..b4bc2ced1 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1latch.html @@ -0,0 +1,209 @@ + + + + + + + +MLX: pocketfft::detail::threading::latch Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::threading::latch Class Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + + +

+Public Member Functions

 latch (size_t n)
 
void count_down ()
 
void wait ()
 
bool is_ready ()
 
+

Constructor & Destructor Documentation

+ +

◆ latch()

+ +
+
+ + + + + +
+ + + + + + + +
pocketfft::detail::threading::latch::latch (size_t n)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ count_down()

+ +
+
+ + + + + +
+ + + + + + + +
void pocketfft::detail::threading::latch::count_down ()
+
+inline
+
+ +
+
+ +

◆ is_ready()

+ +
+
+ + + + + +
+ + + + + + + +
bool pocketfft::detail::threading::latch::is_ready ()
+
+inline
+
+ +
+
+ +

◆ wait()

+ +
+
+ + + + + +
+ + + + + + + +
void pocketfft::detail::threading::latch::wait ()
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1thread__pool-members.html b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1thread__pool-members.html new file mode 100644 index 000000000..a0affd7d6 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1thread__pool-members.html @@ -0,0 +1,96 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::threading::thread_pool Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::threading::thread_pool, including all inherited members.

+ + + + + + + +
restart()pocketfft::detail::threading::thread_poolinline
shutdown()pocketfft::detail::threading::thread_poolinline
submit(std::function< void()> work)pocketfft::detail::threading::thread_poolinline
thread_pool(size_t nthreads)pocketfft::detail::threading::thread_poolinlineexplicit
thread_pool()pocketfft::detail::threading::thread_poolinline
~thread_pool()pocketfft::detail::threading::thread_poolinline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1thread__pool.html b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1thread__pool.html new file mode 100644 index 000000000..d09be49b4 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1thread__pool.html @@ -0,0 +1,263 @@ + + + + + + + +MLX: pocketfft::detail::threading::thread_pool Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Public Member Functions | +List of all members
+
pocketfft::detail::threading::thread_pool Class Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + + + + + + +

+Public Member Functions

 thread_pool (size_t nthreads)
 
 thread_pool ()
 
 ~thread_pool ()
 
void submit (std::function< void()> work)
 
void shutdown ()
 
void restart ()
 
+

Constructor & Destructor Documentation

+ +

◆ thread_pool() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
pocketfft::detail::threading::thread_pool::thread_pool (size_t nthreads)
+
+inlineexplicit
+
+ +
+
+ +

◆ thread_pool() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
pocketfft::detail::threading::thread_pool::thread_pool ()
+
+inline
+
+ +
+
+ +

◆ ~thread_pool()

+ +
+
+ + + + + +
+ + + + + + + +
pocketfft::detail::threading::thread_pool::~thread_pool ()
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ restart()

+ +
+
+ + + + + +
+ + + + + + + +
void pocketfft::detail::threading::thread_pool::restart ()
+
+inline
+
+ +
+
+ +

◆ shutdown()

+ +
+
+ + + + + +
+ + + + + + + +
void pocketfft::detail::threading::thread_pool::shutdown ()
+
+inline
+
+ +
+
+ +

◆ submit()

+ +
+
+ + + + + +
+ + + + + + + +
void pocketfft::detail::threading::thread_pool::submit (std::function< void()> work)
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/clipboard.js b/docs/build/html/clipboard.js new file mode 100644 index 000000000..42c1fb0e0 --- /dev/null +++ b/docs/build/html/clipboard.js @@ -0,0 +1,61 @@ +/** + +The code below is based on the Doxygen Awesome project, see +https://github.com/jothepro/doxygen-awesome-css + +MIT License + +Copyright (c) 2021 - 2022 jothepro + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +*/ + +let clipboard_title = "Copy to clipboard" +let clipboard_icon = `` +let clipboard_successIcon = `` +let clipboard_successDuration = 1000 + +$(function() { + if(navigator.clipboard) { + const fragments = document.getElementsByClassName("fragment") + for(const fragment of fragments) { + const clipboard_div = document.createElement("div") + clipboard_div.classList.add("clipboard") + clipboard_div.innerHTML = clipboard_icon + clipboard_div.title = clipboard_title + $(clipboard_div).click(function() { + const content = this.parentNode.cloneNode(true) + // filter out line number and folded fragments from file listings + content.querySelectorAll(".lineno, .ttc, .foldclosed").forEach((node) => { node.remove() }) + let text = content.textContent + // remove trailing newlines and trailing spaces from empty lines + text = text.replace(/^\s*\n/gm,'\n').replace(/\n*$/,'') + navigator.clipboard.writeText(text); + this.classList.add("success") + this.innerHTML = clipboard_successIcon + window.setTimeout(() => { // switch back to normal icon after timeout + this.classList.remove("success") + this.innerHTML = clipboard_icon + }, clipboard_successDuration); + }) + fragment.insertBefore(clipboard_div, fragment.firstChild) + } + } +}) diff --git a/docs/build/html/closed.png b/docs/build/html/closed.png new file mode 100644 index 000000000..98cc2c909 Binary files /dev/null and b/docs/build/html/closed.png differ diff --git a/docs/build/html/common_2binary_8h.html b/docs/build/html/common_2binary_8h.html new file mode 100644 index 000000000..b8410e482 --- /dev/null +++ b/docs/build/html/common_2binary_8h.html @@ -0,0 +1,117 @@ + + + + + + + +MLX: mlx/backend/common/binary.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Namespaces
+
binary.h File Reference
+
+
+
#include "mlx/allocator.h"
+#include "mlx/array.h"
+#include "mlx/backend/common/utils.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+

Variable Documentation

+ +

◆ op

+ +
+
+ + + + +
Op op
+
+ +
+
+
+ + + + diff --git a/docs/build/html/common_2binary_8h_source.html b/docs/build/html/common_2binary_8h_source.html new file mode 100644 index 000000000..10e707cd3 --- /dev/null +++ b/docs/build/html/common_2binary_8h_source.html @@ -0,0 +1,764 @@ + + + + + + + +MLX: mlx/backend/common/binary.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
binary.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4#include "mlx/allocator.h"
+
5#include "mlx/array.h"
+ +
7
+
8namespace mlx::core {
+
9
+
10namespace {
+
11
+
12enum class BinaryOpType {
+
13 ScalarScalar,
+
14 ScalarVector,
+
15 VectorScalar,
+
16 VectorVector,
+
17 General,
+
18};
+
19
+
20BinaryOpType get_binary_op_type(const array& a, const array& b) {
+
21 BinaryOpType bopt;
+
22 if (a.data_size() == 1 && b.data_size() == 1) {
+
23 bopt = BinaryOpType::ScalarScalar;
+
24 } else if (a.data_size() == 1 && b.flags().contiguous) {
+
25 bopt = BinaryOpType::ScalarVector;
+
26 } else if (b.data_size() == 1 && a.flags().contiguous) {
+
27 bopt = BinaryOpType::VectorScalar;
+
28 } else if (
+
29 a.flags().row_contiguous && b.flags().row_contiguous ||
+
30 a.flags().col_contiguous && b.flags().col_contiguous) {
+
31 bopt = BinaryOpType::VectorVector;
+
32 } else {
+
33 bopt = BinaryOpType::General;
+
34 }
+
35 return bopt;
+
36}
+
37
+
38void set_binary_op_output_data(
+
39 const array& a,
+
40 const array& b,
+
41 array& out,
+
42 BinaryOpType bopt,
+
43 bool donate_with_move = false) {
+
44 switch (bopt) {
+
45 case BinaryOpType::ScalarScalar:
+
46 out.set_data(
+
47 allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
+
48 break;
+
49 case BinaryOpType::ScalarVector:
+
50 if (b.is_donatable() && b.itemsize() == out.itemsize()) {
+
51 if (donate_with_move) {
+
52 out.move_shared_buffer(b);
+
53 } else {
+
54 out.copy_shared_buffer(b);
+
55 }
+
56 } else {
+
57 out.set_data(
+
58 allocator::malloc_or_wait(b.data_size() * out.itemsize()),
+
59 b.data_size(),
+
60 b.strides(),
+
61 b.flags());
+
62 }
+
63 break;
+
64 case BinaryOpType::VectorScalar:
+
65 if (a.is_donatable() && a.itemsize() == out.itemsize()) {
+
66 if (donate_with_move) {
+
67 out.move_shared_buffer(a);
+
68 } else {
+
69 out.copy_shared_buffer(a);
+
70 }
+
71 } else {
+
72 out.set_data(
+
73 allocator::malloc_or_wait(a.data_size() * out.itemsize()),
+
74 a.data_size(),
+
75 a.strides(),
+
76 a.flags());
+
77 }
+
78 break;
+
79 case BinaryOpType::VectorVector:
+
80 if (a.is_donatable() && a.itemsize() == out.itemsize()) {
+
81 if (donate_with_move) {
+
82 out.move_shared_buffer(a);
+
83 } else {
+
84 out.copy_shared_buffer(a);
+
85 }
+
86 } else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
+
87 if (donate_with_move) {
+
88 out.move_shared_buffer(b);
+
89 } else {
+
90 out.copy_shared_buffer(b);
+
91 }
+
92 } else {
+
93 out.set_data(
+
94 allocator::malloc_or_wait(a.data_size() * out.itemsize()),
+
95 a.data_size(),
+
96 a.strides(),
+
97 a.flags());
+
98 }
+
99 break;
+
100 case BinaryOpType::General:
+
101 if (a.is_donatable() && a.flags().row_contiguous &&
+
102 a.itemsize() == out.itemsize() && a.size() == out.size()) {
+
103 if (donate_with_move) {
+
104 out.move_shared_buffer(a);
+
105 } else {
+
106 out.copy_shared_buffer(a);
+
107 }
+
108 } else if (
+
109 b.is_donatable() && b.flags().row_contiguous &&
+
110 b.itemsize() == out.itemsize() && b.size() == out.size()) {
+
111 if (donate_with_move) {
+
112 out.move_shared_buffer(b);
+
113 } else {
+
114 out.copy_shared_buffer(b);
+
115 }
+
116 } else {
+
117 out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
118 }
+
119 break;
+
120 }
+
121}
+
122
+
123struct UseDefaultBinaryOp {
+
124 template <typename T, typename U>
+
125 void operator()(const T* a, const T* b, U* dst, int size) {
+
126 // Should we throw? This should normally never be called.
+
127 assert(false);
+
128 }
+
129
+
130 template <typename T, typename U>
+
131 void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+
132 // Should we throw? This should normally never be called.
+
133 assert(false);
+
134 }
+
135};
+
136
+
137template <typename T, typename U, typename Op>
+
138struct DefaultVectorScalar {
+
139 Op op;
+
140
+
141 DefaultVectorScalar(Op op_) : op(op_) {}
+
142
+
143 void operator()(const T* a, const T* b, U* dst, int size) {
+
144 T scalar = *b;
+
145 while (size-- > 0) {
+
146 *dst = op(*a, scalar);
+
147 dst++;
+
148 a++;
+
149 }
+
150 }
+
151
+
152 void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+
153 T scalar = *b;
+
154 while (size-- > 0) {
+
155 auto dst = op(*a, scalar);
+
156 *dst_a = dst.first;
+
157 *dst_b = dst.second;
+
158 dst_a++;
+
159 dst_b++;
+
160 a++;
+
161 }
+
162 }
+
163};
+
164
+
165template <typename T, typename U, typename Op>
+
166struct DefaultScalarVector {
+
167 Op op;
+
168
+
169 DefaultScalarVector(Op op_) : op(op_) {}
+
170
+
171 void operator()(const T* a, const T* b, U* dst, int size) {
+
172 T scalar = *a;
+
173 while (size-- > 0) {
+
174 *dst = op(scalar, *b);
+
175 dst++;
+
176 b++;
+
177 }
+
178 }
+
179
+
180 void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+
181 T scalar = *a;
+
182 while (size-- > 0) {
+
183 auto dst = op(scalar, *b);
+
184 *dst_a = dst.first;
+
185 *dst_b = dst.second;
+
186 dst_a++;
+
187 dst_b++;
+
188 b++;
+
189 }
+
190 }
+
191};
+
192
+
193template <typename T, typename U, typename Op>
+
194struct DefaultVectorVector {
+
195 Op op;
+
196
+
197 DefaultVectorVector(Op op_) : op(op_) {}
+
198
+
199 void operator()(const T* a, const T* b, U* dst, int size) {
+
200 while (size-- > 0) {
+
201 *dst = op(*a, *b);
+
202 dst++;
+
203 a++;
+
204 b++;
+
205 }
+
206 }
+
207
+
208 void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+
209 while (size-- > 0) {
+
210 auto dst = op(*a, *b);
+
211 *dst_a = dst.first;
+
212 *dst_b = dst.second;
+
213 dst_a++;
+
214 dst_b++;
+
215 a++;
+
216 b++;
+
217 }
+
218 }
+
219};
+
220
+
221template <typename T, typename U, typename Op>
+
222void binary_op_dims1(const array& a, const array& b, array& out, Op op) {
+
223 const T* a_ptr = a.data<T>();
+
224 const T* b_ptr = b.data<T>();
+
225 U* dst = out.data<U>();
+
226 size_t a_idx = 0;
+
227 size_t b_idx = 0;
+
228 for (size_t i = 0; i < out.size(); ++i) {
+
229 dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
+
230 a_idx += a.strides()[0];
+
231 b_idx += b.strides()[0];
+
232 }
+
233}
+
234
+
235template <typename T, typename U, typename Op>
+
236void binary_op_dims1(
+
237 const array& a,
+
238 const array& b,
+
239 array& out,
+
240 Op op,
+
241 int stride) {
+
242 const T* a_ptr = a.data<T>();
+
243 const T* b_ptr = b.data<T>();
+
244 U* dst = out.data<U>();
+
245 size_t a_idx = 0;
+
246 size_t b_idx = 0;
+
247 for (size_t i = 0; i < a.shape()[0]; i++) {
+
248 op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
+
249 a_idx += a.strides()[0];
+
250 b_idx += b.strides()[0];
+
251 dst += stride;
+
252 }
+
253}
+
254
+
255template <typename T, typename U, typename Op>
+
256void binary_op_dims2(const array& a, const array& b, array& out, Op op) {
+
257 const T* a_ptr = a.data<T>();
+
258 const T* b_ptr = b.data<T>();
+
259 U* dst = out.data<U>();
+
260 size_t a_idx = 0;
+
261 size_t b_idx = 0;
+
262 size_t out_idx = 0;
+
263 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
264 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
265 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
+
266 a_idx += a.strides()[1];
+
267 b_idx += b.strides()[1];
+
268 }
+
269 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
270 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
271 }
+
272}
+
273
+
274template <typename T, typename U, typename Op>
+
275void binary_op_dims2(
+
276 const array& a,
+
277 const array& b,
+
278 array& out,
+
279 Op op,
+
280 int stride) {
+
281 const T* a_ptr = a.data<T>();
+
282 const T* b_ptr = b.data<T>();
+
283 U* dst = out.data<U>();
+
284 size_t a_idx = 0;
+
285 size_t b_idx = 0;
+
286 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
287 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
288 op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
+
289 a_idx += a.strides()[1];
+
290 b_idx += b.strides()[1];
+
291 dst += stride;
+
292 }
+
293 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
294 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
295 }
+
296}
+
297
+
298template <typename T, typename U, typename Op>
+
299void binary_op_dims3(const array& a, const array& b, array& out, Op op) {
+
300 const T* a_ptr = a.data<T>();
+
301 const T* b_ptr = b.data<T>();
+
302 U* dst = out.data<U>();
+
303 size_t a_idx = 0;
+
304 size_t b_idx = 0;
+
305 size_t out_idx = 0;
+
306 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
307 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
308 for (size_t k = 0; k < a.shape()[2]; ++k) {
+
309 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
+
310 a_idx += a.strides()[2];
+
311 b_idx += b.strides()[2];
+
312 }
+
313 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+
314 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+
315 }
+
316 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
317 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
318 }
+
319}
+
320
+
321template <typename T, typename U, typename Op>
+
322void binary_op_dims4(const array& a, const array& b, array& out, Op op) {
+
323 const T* a_ptr = a.data<T>();
+
324 const T* b_ptr = b.data<T>();
+
325 U* dst = out.data<U>();
+
326 size_t a_idx = 0;
+
327 size_t b_idx = 0;
+
328 size_t out_idx = 0;
+
329 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
330 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
331 for (size_t k = 0; k < a.shape()[2]; ++k) {
+
332 for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
+
333 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
+
334 a_idx += a.strides()[3];
+
335 b_idx += b.strides()[3];
+
336 }
+
337 a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
+
338 b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
+
339 }
+
340 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+
341 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+
342 }
+
343 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
344 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
345 }
+
346}
+
347
+
348template <typename T, typename U, typename Op>
+
349void binary_op_dispatch_dims(
+
350 const array& a,
+
351 const array& b,
+
352 array& out,
+
353 Op op) {
+
354 switch (out.ndim()) {
+
355 case 1:
+
356 binary_op_dims1<T, U, Op>(a, b, out, op);
+
357 return;
+
358 case 2:
+
359 binary_op_dims2<T, U, Op>(a, b, out, op);
+
360 return;
+
361 case 3:
+
362 binary_op_dims3<T, U, Op>(a, b, out, op);
+
363 return;
+
364 case 4:
+
365 binary_op_dims4<T, U, Op>(a, b, out, op);
+
366 return;
+
367 }
+
368
+
369 const T* a_ptr = a.data<T>();
+
370 const T* b_ptr = b.data<T>();
+
371 U* dst = out.data<U>();
+
372 for (size_t i = 0; i < out.size(); i++) {
+
373 int a_idx = elem_to_loc(i, a.shape(), a.strides());
+
374 int b_idx = elem_to_loc(i, b.shape(), b.strides());
+
375 dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
+
376 }
+
377}
+
378
+
379template <typename T, typename U, typename Op>
+
380void binary_op_dispatch_dims(
+
381 const array& a,
+
382 const array& b,
+
383 array& out,
+
384 Op op,
+
385 int dim,
+
386 int stride) {
+
387 // Number of dimensions to loop over for vectorized ops
+
388 switch (dim) {
+
389 case 1:
+
390 binary_op_dims1<T, U, Op>(a, b, out, op, stride);
+
391 return;
+
392 case 2:
+
393 binary_op_dims2<T, U, Op>(a, b, out, op, stride);
+
394 return;
+
395 }
+
396
+
397 const T* a_ptr = a.data<T>();
+
398 const T* b_ptr = b.data<T>();
+
399 U* dst = out.data<U>();
+
400 for (size_t i = 0; i < out.size(); i += stride) {
+
401 int a_idx = elem_to_loc(i, a.shape(), a.strides());
+
402 int b_idx = elem_to_loc(i, b.shape(), b.strides());
+
403 op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
+
404 dst += stride;
+
405 }
+
406}
+
407
+
408template <
+
409 typename T,
+
410 typename U,
+
411 typename Op,
+
412 typename OpSV,
+
413 typename OpVS,
+
414 typename OpVV>
+
415void binary_op(
+
416 const array& a,
+
417 const array& b,
+
418 array& out,
+
419 Op op,
+
420 OpSV opsv,
+
421 OpVS opvs,
+
422 OpVV opvv) {
+
423 auto bopt = get_binary_op_type(a, b);
+
424 set_binary_op_output_data(a, b, out, bopt);
+
425
+
426 // The full computation is scalar scalar so call the base op once
+
427 if (bopt == BinaryOpType::ScalarScalar) {
+
428 *(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
+
429 return;
+
430 }
+
431
+
432 // The full computation is scalar vector so delegate to the op
+
433 if (bopt == BinaryOpType::ScalarVector) {
+
434 opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
+
435 return;
+
436 }
+
437
+
438 // The full computation is vector scalar so delegate to the op
+
439 if (bopt == BinaryOpType::VectorScalar) {
+
440 opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
+
441 return;
+
442 }
+
443
+
444 // The full computation is vector vector so delegate to the op
+
445 if (bopt == BinaryOpType::VectorVector) {
+
446 opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
+
447 return;
+
448 }
+
449
+
450 // General computation so let's try to optimize
+
451
+
452 // Get the left-most dim such that the array is row contiguous after
+
453 auto& strides = out.strides();
+
454 auto leftmost_rc_dim = [&strides](const array& arr) {
+
455 int d = arr.ndim() - 1;
+
456 for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
+
457 }
+
458 return d + 1;
+
459 };
+
460 auto a_rc_dim = leftmost_rc_dim(a);
+
461 auto b_rc_dim = leftmost_rc_dim(b);
+
462
+
463 // Get the left-most dim such that the array is a broadcasted "scalar" after
+
464 auto leftmost_s_dim = [](const array& arr) {
+
465 int d = arr.ndim() - 1;
+
466 for (; d >= 0 && arr.strides()[d] == 0; d--) {
+
467 }
+
468 return d + 1;
+
469 };
+
470 auto a_s_dim = leftmost_s_dim(a);
+
471 auto b_s_dim = leftmost_s_dim(b);
+
472
+
473 auto ndim = out.ndim();
+
474
+
475 // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
+
476 int dim = ndim;
+
477 if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
+
478 bopt = BinaryOpType::VectorVector;
+
479 dim = d;
+
480 // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
+
481 // contiguous
+
482 } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
+
483 bopt = BinaryOpType::VectorScalar;
+
484 dim = d;
+
485 // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
+
486 // contiguous
+
487 } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
+
488 bopt = BinaryOpType::ScalarVector;
+
489 dim = d;
+
490 }
+
491
+
492 // Can be sure dim > 0 since otherwise we would have used one of the fully
+
493 // contiguous methods above. Except for the case that the flags do not
+
494 // correspond to the underlying contiguity.
+
495 size_t stride;
+
496 if (dim == 0 || strides[dim - 1] < 16) {
+
497 stride = 1;
+
498 bopt = BinaryOpType::General;
+
499 dim = ndim;
+
500 } else {
+
501 stride = strides[dim - 1];
+
502 }
+
503
+
504 switch (bopt) {
+
505 case BinaryOpType::VectorVector:
+
506 binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
+
507 break;
+
508 case BinaryOpType::VectorScalar:
+
509 binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
+
510 break;
+
511 case BinaryOpType::ScalarVector:
+
512 binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
+
513 break;
+
514 default:
+
515 binary_op_dispatch_dims<T, U>(a, b, out, op);
+
516 break;
+
517 }
+
518}
+
519
+
520template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
+
521void binary_op(
+
522 const array& a,
+
523 const array& b,
+
524 array& out,
+
525 Op op,
+
526 OpSV opsv,
+
527 OpVS opvs,
+
528 OpVV opvv) {
+
529 // TODO: The following mess of constexpr evaluations can probably be achieved
+
530 // with template specializations and overloading. Would it be simpler?
+
531
+
532 if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
+
533 if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
+
534 if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+
535 // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
+
536 binary_op<T, T>(
+
537 a,
+
538 b,
+
539 out,
+
540 op,
+
541 DefaultScalarVector<T, T, Op>(op),
+
542 DefaultVectorScalar<T, T, Op>(op),
+
543 DefaultVectorVector<T, T, Op>(op));
+
544 } else {
+
545 // opsv and opvs were UseDefaultBinaryOp
+
546 binary_op<T, T>(
+
547 a,
+
548 b,
+
549 out,
+
550 op,
+
551 DefaultScalarVector<T, T, Op>(op),
+
552 DefaultVectorScalar<T, T, Op>(op),
+
553 opvv);
+
554 }
+
555 } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+
556 // opsv and opvv were UseDefaultBinaryOp
+
557 binary_op<T, T>(
+
558 a,
+
559 b,
+
560 out,
+
561 op,
+
562 DefaultScalarVector<T, T, Op>(op),
+
563 opvs,
+
564 DefaultVectorVector<T, T, Op>(op));
+
565 } else {
+
566 // opsv was UseDefaultBinaryOp
+
567 binary_op<T, T>(
+
568 a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
+
569 }
+
570 } else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
+
571 if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+
572 // opvs and opvv were UseDefaultBinaryOp
+
573 binary_op<T, T>(
+
574 a,
+
575 b,
+
576 out,
+
577 op,
+
578 opsv,
+
579 DefaultVectorScalar<T, T, Op>(op),
+
580 DefaultVectorVector<T, T, Op>(op));
+
581 } else {
+
582 // opvs was UseDefaultBinaryOp
+
583 binary_op<T, T>(
+
584 a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
+
585 }
+
586 } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+
587 // opvv was UseDefaultBinaryOp
+
588 binary_op<T, T>(
+
589 a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
+
590 } else {
+
591 // All ops provided
+
592 binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
+
593 }
+
594}
+
595
+
596template <typename T, typename Op>
+
597void binary_op(const array& a, const array& b, array& out, Op op) {
+
598 DefaultScalarVector<T, T, Op> opsv(op);
+
599 DefaultVectorScalar<T, T, Op> opvs(op);
+
600 DefaultVectorVector<T, T, Op> opvv(op);
+
601 binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
+
602}
+
603
+
604template <typename... Ops>
+
605void binary(const array& a, const array& b, array& out, Ops... ops) {
+
606 switch (out.dtype()) {
+
607 case bool_:
+
608 binary_op<bool>(a, b, out, ops...);
+
609 break;
+
610 case uint8:
+
611 binary_op<uint8_t>(a, b, out, ops...);
+
612 break;
+
613 case uint16:
+
614 binary_op<uint16_t>(a, b, out, ops...);
+
615 break;
+
616 case uint32:
+
617 binary_op<uint32_t>(a, b, out, ops...);
+
618 break;
+
619 case uint64:
+
620 binary_op<uint64_t>(a, b, out, ops...);
+
621 break;
+
622 case int8:
+
623 binary_op<int8_t>(a, b, out, ops...);
+
624 break;
+
625 case int16:
+
626 binary_op<int16_t>(a, b, out, ops...);
+
627 break;
+
628 case int32:
+
629 binary_op<int32_t>(a, b, out, ops...);
+
630 break;
+
631 case int64:
+
632 binary_op<int64_t>(a, b, out, ops...);
+
633 break;
+
634 case float16:
+
635 binary_op<float16_t>(a, b, out, ops...);
+
636 break;
+
637 case float32:
+
638 binary_op<float>(a, b, out, ops...);
+
639 break;
+
640 case bfloat16:
+
641 binary_op<bfloat16_t>(a, b, out, ops...);
+
642 break;
+
643 case complex64:
+
644 binary_op<complex64_t>(a, b, out, ops...);
+
645 break;
+
646 }
+
647}
+
648
+
649} // namespace
+
650
+
651} // namespace mlx::core
+ + + +
Op op
Definition binary.h:139
+
Buffer malloc_or_wait(size_t size)
+
Definition allocator.h:7
+
constexpr Dtype bool_
Definition dtype.h:60
+
constexpr Dtype uint64
Definition dtype.h:65
+
constexpr Dtype uint16
Definition dtype.h:63
+
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
+
constexpr Dtype bfloat16
Definition dtype.h:74
+
constexpr Dtype int32
Definition dtype.h:69
+
constexpr Dtype float32
Definition dtype.h:73
+
constexpr Dtype int16
Definition dtype.h:68
+
constexpr Dtype int8
Definition dtype.h:67
+
constexpr Dtype int64
Definition dtype.h:70
+
constexpr Dtype uint8
Definition dtype.h:62
+ +
constexpr Dtype float16
Definition dtype.h:72
+
constexpr Dtype uint32
Definition dtype.h:64
+
constexpr Dtype complex64
Definition dtype.h:75
+
+ + + + diff --git a/docs/build/html/common_2compiled__preamble_8h.html b/docs/build/html/common_2compiled__preamble_8h.html new file mode 100644 index 000000000..824d31aae --- /dev/null +++ b/docs/build/html/common_2compiled__preamble_8h.html @@ -0,0 +1,118 @@ + + + + + + + +MLX: mlx/backend/common/compiled_preamble.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Functions
+
compiled_preamble.h File Reference
+
+
+
#include "mlx/types/half_types.h"
+#include "mlx/types/complex.h"
+#include "mlx/backend/common/ops.h"
+
+

Go to the source code of this file.

+ + + + +

+Functions

const char * get_kernel_preamble ()
 
+

Function Documentation

+ +

◆ get_kernel_preamble()

+ +
+
+ + + + + + + +
const char * get_kernel_preamble ()
+
+ +
+
+
+ + + + diff --git a/docs/build/html/common_2compiled__preamble_8h_source.html b/docs/build/html/common_2compiled__preamble_8h_source.html new file mode 100644 index 000000000..3fecd12e0 --- /dev/null +++ b/docs/build/html/common_2compiled__preamble_8h_source.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: mlx/backend/common/compiled_preamble.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
compiled_preamble.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-24 Apple Inc.
+
2
+
3#pragma once
+
4
+
5// clang-format off
+ +
7#include "mlx/types/complex.h"
+ +
9// clang-format on
+
10
+
11const char* get_kernel_preamble();
+ +
const char * get_kernel_preamble()
+ + +
+ + + + diff --git a/docs/build/html/common_2copy_8h.html b/docs/build/html/common_2copy_8h.html new file mode 100644 index 000000000..2a90f4911 --- /dev/null +++ b/docs/build/html/common_2copy_8h.html @@ -0,0 +1,122 @@ + + + + + + + +MLX: mlx/backend/common/copy.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Namespaces | +Enumerations | +Functions
+
copy.h File Reference
+
+
+
#include "mlx/array.h"
+#include "mlx/backend/common/utils.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+ + + +

+Enumerations

enum class  mlx::core::CopyType { mlx::core::Scalar +, mlx::core::Vector +, mlx::core::General +, mlx::core::GeneralGeneral + }
 
+ + + + + + + + +

+Functions

void mlx::core::copy (const array &src, array &dst, CopyType ctype)
 
void mlx::core::copy_inplace (const array &src, array &dst, CopyType ctype)
 
template<typename stride_t >
void mlx::core::copy_inplace (const array &src, array &dst, const std::vector< int > &data_shape, const std::vector< stride_t > &i_strides, const std::vector< stride_t > &o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype)
 
+
+ + + + diff --git a/docs/build/html/common_2copy_8h_source.html b/docs/build/html/common_2copy_8h_source.html new file mode 100644 index 000000000..050fc5aa9 --- /dev/null +++ b/docs/build/html/common_2copy_8h_source.html @@ -0,0 +1,145 @@ + + + + + + + +MLX: mlx/backend/common/copy.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
copy.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include "mlx/array.h"
+ +
7
+
8namespace mlx::core {
+
9
+
+
10enum class CopyType {
+
11 // Copy a raw scalar input into the full contiguous output
+
12 Scalar,
+
13
+
14 // Copy the raw input buffer contiguously into a raw output buffer of the same
+
15 // size
+
16 Vector,
+
17
+
18 // Copy the full virtual input to the full contiguous output
+
19 General,
+
20
+
21 // Copy the full virtual input to the full virtual output. We assume the
+
22 // input and output have the same shape.
+ +
24};
+
+
25
+
26void copy(const array& src, array& dst, CopyType ctype);
+
27void copy_inplace(const array& src, array& dst, CopyType ctype);
+
28
+
29template <typename stride_t>
+ +
31 const array& src,
+
32 array& dst,
+
33 const std::vector<int>& data_shape,
+
34 const std::vector<stride_t>& i_strides,
+
35 const std::vector<stride_t>& o_strides,
+
36 int64_t i_offset,
+
37 int64_t o_offset,
+
38 CopyType ctype);
+
39
+
40} // namespace mlx::core
+ + +
Definition array.h:20
+
Definition allocator.h:7
+
void copy(const array &src, array &dst, CopyType ctype)
+
void copy_inplace(const array &src, array &dst, CopyType ctype)
+
CopyType
Definition copy.h:10
+ + + + +
+ + + + diff --git a/docs/build/html/common_2reduce_8h.html b/docs/build/html/common_2reduce_8h.html new file mode 100644 index 000000000..951bd9ebd --- /dev/null +++ b/docs/build/html/common_2reduce_8h.html @@ -0,0 +1,136 @@ + + + + + + + +MLX: mlx/backend/common/reduce.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Classes | +Namespaces | +Enumerations
+
reduce.h File Reference
+
+
+
#include "mlx/backend/common/utils.h"
+
+

Go to the source code of this file.

+ + + + +

+Classes

struct  mlx::core::ReductionPlan
 
+ + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+ + + +

+Enumerations

enum  mlx::core::ReductionOpType {
+  mlx::core::ContiguousAllReduce +, mlx::core::ContiguousReduce +, mlx::core::ContiguousStridedReduce +, mlx::core::GeneralContiguousReduce +,
+  mlx::core::GeneralStridedReduce +, mlx::core::GeneralReduce +
+ }
 
+

Variable Documentation

+ +

◆ op

+ +
+
+ + + + +
Op op
+
+ +
+
+
+ + + + diff --git a/docs/build/html/common_2reduce_8h_source.html b/docs/build/html/common_2reduce_8h_source.html new file mode 100644 index 000000000..813b7d68b --- /dev/null +++ b/docs/build/html/common_2reduce_8h_source.html @@ -0,0 +1,483 @@ + + + + + + + +MLX: mlx/backend/common/reduce.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
reduce.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+ +
6
+
7namespace mlx::core {
+
8
+
+ +
10 // Self-explanatory. Read everything and produce 1 output.
+ +
12
+
13 // The input is contiguous and the last axis is reduced
+
14 // N1xR1xN2xR2x...xNnxRn
+ +
16
+
17 // The input is contiguous and the last axis is not reduced
+
18 // R1xN1xR2xN2x...xRnxNn
+ +
20
+
21 // The input is not contiguous but the last axis is and it is reduced so we
+
22 // need to figure out the offsets but we can call the contiguous reduce after
+
23 // that.
+
24 // N3xR1xN1xR4x...xRn
+ +
26
+
27 // The input is not contiguous but the last reduction axis and the last axis
+
28 // are so we need to figure out the offset but we can call the strided reduce
+
29 // after that.
+ +
31
+
32 // The input is not contiguous after the reduction axis and it may contain
+
33 // 0-stride axes or transpositions. We could copy the strides and produce a
+
34 // transposed outcome or we can read the input out of order and write the
+
35 // output in order.
+ +
37};
+
+
38
+
+ + +
41 std::vector<int> shape;
+
42 std::vector<size_t> strides;
+
43
+
+ +
45 ReductionOpType type_,
+
46 std::vector<int> shape_,
+
47 std::vector<size_t> strides_)
+
48 : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
+
+ +
50};
+
+
51
+
52namespace {
+
53
+
54// Helper for the ndimensional strided loop
+
55// Should this be in utils?
+
56inline void nd_loop(
+
57 std::function<void(int)> callback,
+
58 const std::vector<int>& shape,
+
59 const std::vector<size_t>& strides) {
+
60 std::function<void(int, int)> loop_inner;
+
61 loop_inner = [&](int dim, int offset) {
+
62 if (dim < shape.size() - 1) {
+
63 int size = shape[dim];
+
64 size_t stride = strides[dim];
+
65 for (int i = 0; i < size; i++) {
+
66 loop_inner(dim + 1, offset + i * stride);
+
67 }
+
68 } else {
+
69 int size = shape[dim];
+
70 size_t stride = strides[dim];
+
71 for (int i = 0; i < size; i++) {
+
72 callback(offset + i * stride);
+
73 }
+
74 }
+
75 };
+
76 loop_inner(0, 0);
+
77}
+
78
+
79std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
+
80 const array& x,
+
81 const std::vector<int>& axes) {
+
82 std::vector<int> shape = x.shape();
+
83 std::vector<size_t> strides = x.strides();
+
84
+
85 for (int i = axes.size() - 1; i >= 0; i--) {
+
86 int a = axes[i];
+
87 shape.erase(shape.begin() + a);
+
88 strides.erase(strides.begin() + a);
+
89 }
+
90
+
91 return std::make_pair(shape, strides);
+
92}
+
93
+
94template <typename T, typename U, typename Op>
+
95struct DefaultStridedReduce {
+
96 Op op;
+
97
+
98 DefaultStridedReduce(Op op_) : op(op_) {}
+
99
+
100 void operator()(const T* x, U* accumulator, int size, size_t stride) {
+
101 for (int i = 0; i < size; i++) {
+
102 U* moving_accumulator = accumulator;
+
103 for (int j = 0; j < stride; j++) {
+
104 op(moving_accumulator, *x);
+
105 moving_accumulator++;
+
106 x++;
+
107 }
+
108 }
+
109 }
+
110};
+
111
+
112template <typename T, typename U, typename Op>
+
113struct DefaultContiguousReduce {
+
114 Op op;
+
115
+
116 DefaultContiguousReduce(Op op_) : op(op_) {}
+
117
+
118 void operator()(const T* x, U* accumulator, int size) {
+
119 while (size-- > 0) {
+
120 op(accumulator, *x);
+
121 x++;
+
122 }
+
123 }
+
124};
+
125
+
126ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
+
127 // The data is all there and we are reducing over everything
+
128 if (x.size() == x.data_size() && axes.size() == x.ndim() &&
+
129 x.flags().contiguous) {
+
130 return ContiguousAllReduce;
+
131 }
+
132
+
133 // Row contiguous input so the output is row contiguous
+
134 if (x.flags().row_contiguous) {
+
135 // Merge consecutive axes
+
136 std::vector<int> shape = {x.shape(axes[0])};
+
137 std::vector<size_t> strides = {x.strides()[axes[0]]};
+
138 for (int i = 1; i < axes.size(); i++) {
+
139 if (axes[i] - 1 == axes[i - 1]) {
+
140 shape.back() *= x.shape(axes[i]);
+
141 strides.back() = x.strides()[axes[i]];
+
142 } else {
+
143 shape.push_back(x.shape(axes[i]));
+
144 strides.push_back(x.strides()[axes[i]]);
+
145 }
+
146 }
+
147
+
148 if (strides.back() == 1) {
+
149 return ReductionPlan(ContiguousReduce, shape, strides);
+
150 } else if (strides.back() > 1) {
+
151 return ReductionPlan(ContiguousStridedReduce, shape, strides);
+
152 }
+
153 }
+
154
+
155 // Let's check if we can optimize our access patterns
+
156 //
+
157 // 1. We have a reduction axis with stride 1. Simply call
+
158 // GeneralContiguousReduce and be done with it.
+
159 // 2. We have transpositions and we are not reducing over the axis with
+
160 // stride 1. However, we are reducing over an axis where everything is
+
161 // contiguous in memory to the right of that axis. We can call strided
+
162 // reduce and be done with it.
+
163 // 2. We have weird transpositions and expands. Copy the strides to the
+
164 // output, then call strided reduce.
+
165
+
166 // Sort reduction axes by stride in order to merge them and figure out if we
+
167 // have a contiguous reduction.
+
168 std::vector<std::pair<int, size_t>> reductions;
+
169 for (auto a : axes) {
+
170 reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
+
171 }
+
172 std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
+
173 return a.second > b.second;
+
174 });
+
175 // Extract the two smallest and try to merge them in case the contiguous
+
176 // reduction can be bigger than just the last axis.
+
177 for (int i = reductions.size() - 1; i >= 1; i--) {
+
178 auto a = reductions[i];
+
179 auto b = reductions[i - 1];
+
180
+
181 // b.stride = a.shape * a.stride then a and b are contiguous
+
182 if (b.second == a.first * a.second) {
+
183 reductions.erase(reductions.begin() + i);
+
184 reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
+
185 }
+
186 }
+
187
+
188 std::vector<int> shape;
+
189 std::vector<size_t> strides;
+
190 for (auto r : reductions) {
+
191 shape.push_back(r.first);
+
192 strides.push_back(r.second);
+
193 }
+
194
+
195 // We can call the contiguous reduction op for every weird way the input is
+
196 // structured in the rest of the axes.
+
197 if (strides.back() == 1) {
+
198 return ReductionPlan(GeneralContiguousReduce, shape, strides);
+
199 }
+
200
+
201 // Delegate to the general strided reduction op if the axes after
+
202 // strides.back() are contiguous.
+
203 if (strides.back() > 1) {
+
204 int size = 1;
+
205 for (int i = x.ndim() - 1; i >= 0; i--) {
+
206 if (axes.back() == i) {
+
207 continue;
+
208 }
+
209 if (x.strides()[i] != size) {
+
210 break;
+
211 }
+
212 size *= x.shape(i);
+
213 }
+
214 if (size >= strides.back()) {
+
215 return ReductionPlan(GeneralStridedReduce, shape, strides);
+
216 }
+
217 }
+
218
+
219 return ReductionPlan(GeneralReduce, shape, strides);
+
220}
+
221
+
222template <typename T, typename U, typename OpS, typename OpC, typename Op>
+
223void reduction_op(
+
224 const array& x,
+
225 array& out,
+
226 const std::vector<int>& axes,
+
227 U init,
+
228 OpS ops,
+
229 OpC opc,
+
230 Op op) {
+
231 out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
232 ReductionPlan plan = get_reduction_plan(x, axes);
+
233
+
234 if (plan.type == ContiguousAllReduce) {
+
235 U* out_ptr = out.data<U>();
+
236 *out_ptr = init;
+
237 opc(x.data<T>(), out_ptr, x.size());
+
238 return;
+
239 }
+
240
+
241 std::vector<int> shape;
+
242 std::vector<size_t> strides;
+
243
+
244 if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
+
245 int reduction_size = plan.shape[0];
+
246 const T* x_ptr = x.data<T>();
+
247 U* out_ptr = out.data<U>();
+
248 for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
+
249 *out_ptr = init;
+
250 opc(x_ptr, out_ptr, reduction_size);
+
251 }
+
252 return;
+
253 }
+
254
+
255 if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) {
+
256 int reduction_size = plan.shape.back();
+
257 plan.shape.pop_back();
+
258 plan.strides.pop_back();
+
259 const T* x_ptr = x.data<T>();
+
260 U* out_ptr = out.data<U>();
+
261 // Unrolling the following loop (and implementing it in order for
+
262 // ContiguousReduce) should hold extra performance boost.
+
263 std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
+
264 if (plan.shape.size() == 0) {
+
265 for (int i = 0; i < out.size(); i++, out_ptr++) {
+
266 int offset = elem_to_loc(i, shape, strides);
+
267 *out_ptr = init;
+
268 opc(x_ptr + offset, out_ptr, reduction_size);
+
269 }
+
270 } else {
+
271 for (int i = 0; i < out.size(); i++, out_ptr++) {
+
272 int offset = elem_to_loc(i, shape, strides);
+
273 *out_ptr = init;
+
274 nd_loop(
+
275 [&](int extra_offset) {
+
276 opc(x_ptr + offset + extra_offset, out_ptr, reduction_size);
+
277 },
+
278 plan.shape,
+
279 plan.strides);
+
280 }
+
281 }
+
282 return;
+
283 }
+
284
+
285 if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) {
+
286 int reduction_size = plan.shape.back();
+
287 size_t reduction_stride = plan.strides.back();
+
288 plan.shape.pop_back();
+
289 plan.strides.pop_back();
+
290 const T* x_ptr = x.data<T>();
+
291 U* out_ptr = out.data<U>();
+
292 for (int i = 0; i < out.size(); i += reduction_stride) {
+
293 std::fill_n(out_ptr, reduction_stride, init);
+
294 ops(x_ptr, out_ptr, reduction_size, reduction_stride);
+
295 x_ptr += reduction_stride * reduction_size;
+
296 out_ptr += reduction_stride;
+
297 }
+
298 return;
+
299 }
+
300
+
301 if (plan.type == GeneralStridedReduce ||
+
302 plan.type == ContiguousStridedReduce) {
+
303 int reduction_size = plan.shape.back();
+
304 size_t reduction_stride = plan.strides.back();
+
305 plan.shape.pop_back();
+
306 plan.strides.pop_back();
+
307 const T* x_ptr = x.data<T>();
+
308 U* out_ptr = out.data<U>();
+
309 std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
+
310 if (plan.shape.size() == 0) {
+
311 for (int i = 0; i < out.size(); i += reduction_stride) {
+
312 int offset = elem_to_loc(i, shape, strides);
+
313 std::fill_n(out_ptr, reduction_stride, init);
+
314 ops(x_ptr + offset, out_ptr, reduction_size, reduction_stride);
+
315 out_ptr += reduction_stride;
+
316 }
+
317 } else {
+
318 for (int i = 0; i < out.size(); i += reduction_stride) {
+
319 int offset = elem_to_loc(i, shape, strides);
+
320 std::fill_n(out_ptr, reduction_stride, init);
+
321 nd_loop(
+
322 [&](int extra_offset) {
+
323 ops(x_ptr + offset + extra_offset,
+
324 out_ptr,
+
325 reduction_size,
+
326 reduction_stride);
+
327 },
+
328 plan.shape,
+
329 plan.strides);
+
330 out_ptr += reduction_stride;
+
331 }
+
332 }
+
333 return;
+
334 }
+
335
+
336 if (plan.type == GeneralReduce) {
+
337 const T* x_ptr = x.data<T>();
+
338 U* out_ptr = out.data<U>();
+
339 std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
+
340 for (int i = 0; i < out.size(); i++, out_ptr++) {
+
341 int offset = elem_to_loc(i, shape, strides);
+
342 U val = init;
+
343 nd_loop(
+
344 [&](int extra_offset) { op(&val, *(x_ptr + offset + extra_offset)); },
+
345 plan.shape,
+
346 plan.strides);
+
347 *out_ptr = val;
+
348 }
+
349 }
+
350}
+
351
+
352template <typename T, typename U, typename Op>
+
353void reduction_op(
+
354 const array& x,
+
355 array& out,
+
356 const std::vector<int>& axes,
+
357 U init,
+
358 Op op) {
+
359 DefaultStridedReduce<T, U, Op> ops(op);
+
360 DefaultContiguousReduce<T, U, Op> opc(op);
+
361 reduction_op<T, U>(x, out, axes, init, ops, opc, op);
+
362}
+
363
+
364} // namespace
+
365
+
366} // namespace mlx::core
+ +
Op op
Definition binary.h:139
+
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
+
Buffer malloc_or_wait(size_t size)
+
Definition allocator.h:7
+
ReductionOpType
Definition reduce.h:9
+
@ GeneralReduce
Definition reduce.h:36
+
@ GeneralContiguousReduce
Definition reduce.h:25
+
@ ContiguousStridedReduce
Definition reduce.h:19
+
@ ContiguousReduce
Definition reduce.h:15
+
@ GeneralStridedReduce
Definition reduce.h:30
+
@ ContiguousAllReduce
Definition reduce.h:11
+
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
+
Definition reduce.h:39
+
ReductionOpType type
Definition reduce.h:40
+
ReductionPlan(ReductionOpType type_, std::vector< int > shape_, std::vector< size_t > strides_)
Definition reduce.h:44
+
std::vector< int > shape
Definition reduce.h:41
+
std::vector< size_t > strides
Definition reduce.h:42
+
ReductionPlan(ReductionOpType type_)
Definition reduce.h:49
+
+ + + + diff --git a/docs/build/html/common_2ternary_8h.html b/docs/build/html/common_2ternary_8h.html new file mode 100644 index 000000000..2bdd394cc --- /dev/null +++ b/docs/build/html/common_2ternary_8h.html @@ -0,0 +1,103 @@ + + + + + + + +MLX: mlx/backend/common/ternary.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Namespaces
+
ternary.h File Reference
+
+
+
#include "mlx/allocator.h"
+#include "mlx/array.h"
+#include "mlx/backend/common/ops.h"
+#include "mlx/backend/common/utils.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+
+ + + + diff --git a/docs/build/html/common_2ternary_8h_source.html b/docs/build/html/common_2ternary_8h_source.html new file mode 100644 index 000000000..edce1f51f --- /dev/null +++ b/docs/build/html/common_2ternary_8h_source.html @@ -0,0 +1,327 @@ + + + + + + + +MLX: mlx/backend/common/ternary.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
ternary.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4#include "mlx/allocator.h"
+
5#include "mlx/array.h"
+ + +
8namespace mlx::core {
+
9
+
10namespace {
+
11
+
12// TODO: Add support for more combinations of input types.
+
13enum class TernaryOpType {
+
14 ScalarScalarScalar,
+
15 General,
+
16};
+
17
+
18TernaryOpType
+
19get_ternary_op_type(const array& a, const array& b, const array& c) {
+
20 TernaryOpType topt;
+
21 if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
+
22 topt = TernaryOpType::ScalarScalarScalar;
+
23 } else {
+
24 topt = TernaryOpType::General;
+
25 }
+
26 return topt;
+
27}
+
28
+
29void set_ternary_op_output_data(
+
30 const array& a,
+
31 const array& b,
+
32 const array& c,
+
33 array& out,
+
34 TernaryOpType topt,
+
35 bool donate_with_move = false) {
+
36 switch (topt) {
+
37 case TernaryOpType::ScalarScalarScalar:
+
38 out.set_data(
+
39 allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
+
40 break;
+
41 case TernaryOpType::General:
+
42 out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
43 break;
+
44 }
+
45}
+
46
+
47template <typename T1, typename T2, typename T3, typename U, typename Op>
+
48void ternary_op_dims1(
+
49 const array& a,
+
50 const array& b,
+
51 const array& c,
+
52 array& out,
+
53 Op op) {
+
54 const T1* a_ptr = a.data<T1>();
+
55 const T2* b_ptr = b.data<T2>();
+
56 const T3* c_ptr = c.data<T3>();
+
57
+
58 U* dst = out.data<U>();
+
59 size_t a_idx = 0;
+
60 size_t b_idx = 0;
+
61 size_t c_idx = 0;
+
62 for (size_t i = 0; i < out.size(); ++i) {
+
63 dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+
64 a_idx += a.strides()[0];
+
65 b_idx += b.strides()[0];
+
66 c_idx += c.strides()[0];
+
67 }
+
68}
+
69
+
70template <typename T1, typename T2, typename T3, typename U, typename Op>
+
71void ternary_op_dims2(
+
72 const array& a,
+
73 const array& b,
+
74 const array& c,
+
75 array& out,
+
76 Op op) {
+
77 const T1* a_ptr = a.data<T1>();
+
78 const T2* b_ptr = b.data<T2>();
+
79 const T3* c_ptr = c.data<T3>();
+
80
+
81 U* dst = out.data<U>();
+
82 size_t a_idx = 0;
+
83 size_t b_idx = 0;
+
84 size_t c_idx = 0;
+
85 size_t out_idx = 0;
+
86 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
87 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
88 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+
89 a_idx += a.strides()[1];
+
90 b_idx += b.strides()[1];
+
91 c_idx += c.strides()[1];
+
92 }
+
93 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
94 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
95 c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
+
96 }
+
97}
+
98
+
99template <typename T1, typename T2, typename T3, typename U, typename Op>
+
100void ternary_op_dims3(
+
101 const array& a,
+
102 const array& b,
+
103 const array& c,
+
104 array& out,
+
105 Op op) {
+
106 const T1* a_ptr = a.data<T1>();
+
107 const T2* b_ptr = b.data<T2>();
+
108 const T3* c_ptr = c.data<T3>();
+
109 U* dst = out.data<U>();
+
110 size_t a_idx = 0;
+
111 size_t b_idx = 0;
+
112 size_t c_idx = 0;
+
113 size_t out_idx = 0;
+
114 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
115 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
116 for (size_t k = 0; k < a.shape()[2]; ++k) {
+
117 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+
118 a_idx += a.strides()[2];
+
119 b_idx += b.strides()[2];
+
120 c_idx += c.strides()[2];
+
121 }
+
122 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+
123 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+
124 c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
+
125 }
+
126 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
127 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
128 c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
+
129 }
+
130}
+
131
+
132template <typename T1, typename T2, typename T3, typename U, typename Op>
+
133void ternary_op_dims4(
+
134 const array& a,
+
135 const array& b,
+
136 const array& c,
+
137 array& out,
+
138 Op op) {
+
139 const T1* a_ptr = a.data<T1>();
+
140 const T2* b_ptr = b.data<T2>();
+
141 const T3* c_ptr = c.data<T3>();
+
142
+
143 U* dst = out.data<U>();
+
144 size_t a_idx = 0;
+
145 size_t b_idx = 0;
+
146 size_t c_idx = 0;
+
147 size_t out_idx = 0;
+
148 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
149 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
150 for (size_t k = 0; k < a.shape()[2]; ++k) {
+
151 for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
+
152 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+
153 a_idx += a.strides()[3];
+
154 b_idx += b.strides()[3];
+
155 c_idx += c.strides()[3];
+
156 }
+
157 a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
+
158 b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
+
159 c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
+
160 }
+
161 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+
162 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+
163 c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
+
164 }
+
165 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
166 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
167 c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
+
168 }
+
169}
+
170
+
171template <typename T1, typename T2, typename T3, typename U, typename Op>
+
172void ternary_op_dispatch_dims(
+
173 const array& a,
+
174 const array& b,
+
175 const array& c,
+
176 array& out,
+
177 Op op) {
+
178 switch (out.ndim()) {
+
179 case 1:
+
180 ternary_op_dims1<T1, T2, T3, U, Op>(a, b, c, out, op);
+
181 return;
+
182 case 2:
+
183 ternary_op_dims2<T1, T2, T3, U, Op>(a, b, c, out, op);
+
184 return;
+
185 case 3:
+
186 ternary_op_dims3<T1, T2, T3, U, Op>(a, b, c, out, op);
+
187 return;
+
188 case 4:
+
189 ternary_op_dims4<T1, T2, T3, U, Op>(a, b, c, out, op);
+
190 return;
+
191 }
+
192
+
193 const T1* a_ptr = a.data<T1>();
+
194 const T2* b_ptr = b.data<T2>();
+
195 const T3* c_ptr = c.data<T3>();
+
196 U* dst = out.data<U>();
+
197 for (size_t i = 0; i < out.size(); i++) {
+
198 int a_idx = elem_to_loc(i, a.shape(), a.strides());
+
199 int b_idx = elem_to_loc(i, b.shape(), b.strides());
+
200 int c_idx = elem_to_loc(i, c.shape(), c.strides());
+
201 dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+
202 }
+
203}
+
204
+
205template <typename T1, typename T2, typename T3, typename U, typename Op>
+
206void ternary_op(
+
207 const array& a,
+
208 const array& b,
+
209 const array& c,
+
210 array& out,
+
211 Op op) {
+
212 TernaryOpType topt = get_ternary_op_type(a, b, c);
+
213 set_ternary_op_output_data(a, b, c, out, topt);
+
214
+
215 // The full computation is scalar-scalar-scalar so we call the base op once.
+
216 if (topt == TernaryOpType::ScalarScalarScalar) {
+
217 *(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
+
218 return;
+
219 }
+
220
+
221 ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
+
222}
+
223
+
224} // namespace
+
225
+
226} // namespace mlx::core
+ + + + +
Op op
Definition binary.h:139
+
Buffer malloc_or_wait(size_t size)
+
Definition allocator.h:7
+
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
+ +
+ + + + diff --git a/docs/build/html/common_2unary_8h.html b/docs/build/html/common_2unary_8h.html new file mode 100644 index 000000000..c0017b1c4 --- /dev/null +++ b/docs/build/html/common_2unary_8h.html @@ -0,0 +1,103 @@ + + + + + + + +MLX: mlx/backend/common/unary.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Namespaces
+
unary.h File Reference
+
+
+
#include "mlx/allocator.h"
+#include "mlx/array.h"
+#include "mlx/backend/common/utils.h"
+#include "mlx/utils.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+
+ + + + diff --git a/docs/build/html/common_2unary_8h_source.html b/docs/build/html/common_2unary_8h_source.html new file mode 100644 index 000000000..1fa0f5e15 --- /dev/null +++ b/docs/build/html/common_2unary_8h_source.html @@ -0,0 +1,229 @@ + + + + + + + +MLX: mlx/backend/common/unary.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
unary.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include "mlx/allocator.h"
+
6#include "mlx/array.h"
+ +
8#include "mlx/utils.h"
+
9
+
10namespace mlx::core {
+
11
+
12namespace {
+
13
+
14void set_unary_output_data(const array& in, array& out) {
+
15 if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+
16 out.copy_shared_buffer(in);
+
17 } else {
+
18 auto size = in.data_size();
+
19 out.set_data(
+
20 allocator::malloc_or_wait(size * out.itemsize()),
+
21 size,
+
22 in.strides(),
+
23 in.flags());
+
24 }
+
25}
+
26
+
27template <typename T, typename Op>
+
28void unary_op(const array& a, array& out, Op op) {
+
29 const T* a_ptr = a.data<T>();
+
30 if (a.flags().contiguous) {
+
31 set_unary_output_data(a, out);
+
32 T* dst = out.data<T>();
+
33 for (size_t i = 0; i < a.data_size(); ++i) {
+
34 dst[i] = op(a_ptr[i]);
+
35 }
+
36 } else {
+
37 out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
38 T* dst = out.data<T>();
+
39 for (size_t i = 0; i < out.size(); ++i) {
+
40 // TODO this is super inefficient, need to fix.
+
41 int a_idx = elem_to_loc(i, a.shape(), a.strides());
+
42 dst[i] = op(a_ptr[a_idx]);
+
43 }
+
44 }
+
45}
+
46
+
47template <typename Op>
+
48void unary(const array& a, array& out, Op op) {
+
49 switch (out.dtype()) {
+
50 case bool_:
+
51 unary_op<bool>(a, out, op);
+
52 break;
+
53 case uint8:
+
54 unary_op<uint8_t>(a, out, op);
+
55 break;
+
56 case uint16:
+
57 unary_op<uint16_t>(a, out, op);
+
58 break;
+
59 case uint32:
+
60 unary_op<uint32_t>(a, out, op);
+
61 break;
+
62 case uint64:
+
63 unary_op<uint64_t>(a, out, op);
+
64 break;
+
65 case int8:
+
66 unary_op<int8_t>(a, out, op);
+
67 break;
+
68 case int16:
+
69 unary_op<int16_t>(a, out, op);
+
70 break;
+
71 case int32:
+
72 unary_op<int32_t>(a, out, op);
+
73 break;
+
74 case int64:
+
75 unary_op<int64_t>(a, out, op);
+
76 break;
+
77 case float16:
+
78 unary_op<float16_t>(a, out, op);
+
79 break;
+
80 case float32:
+
81 unary_op<float>(a, out, op);
+
82 break;
+
83 case bfloat16:
+
84 unary_op<bfloat16_t>(a, out, op);
+
85 break;
+
86 case complex64:
+
87 unary_op<complex64_t>(a, out, op);
+
88 break;
+
89 }
+
90}
+
91
+
92template <typename Op>
+
93void unary_fp(const array& a, array& out, Op op) {
+
94 switch (out.dtype()) {
+
95 case bfloat16:
+
96 unary_op<bfloat16_t>(a, out, op);
+
97 break;
+
98 case float16:
+
99 unary_op<float16_t>(a, out, op);
+
100 break;
+
101 case float32:
+
102 unary_op<float>(a, out, op);
+
103 break;
+
104 case complex64:
+
105 unary_op<complex64_t>(a, out, op);
+
106 break;
+
107 default:
+
108 std::ostringstream err;
+
109 err << "[unary_fp] Does not support " << out.dtype();
+
110 throw std::runtime_error(err.str());
+
111 }
+
112}
+
113
+
114} // namespace
+
115
+
116} // namespace mlx::core
+ + + +
Op op
Definition binary.h:139
+
Buffer malloc_or_wait(size_t size)
+
Definition allocator.h:7
+
constexpr Dtype bool_
Definition dtype.h:60
+
constexpr Dtype uint64
Definition dtype.h:65
+
constexpr Dtype uint16
Definition dtype.h:63
+
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
+
constexpr Dtype bfloat16
Definition dtype.h:74
+
constexpr Dtype int32
Definition dtype.h:69
+
constexpr Dtype float32
Definition dtype.h:73
+
constexpr Dtype int16
Definition dtype.h:68
+
constexpr Dtype int8
Definition dtype.h:67
+
constexpr Dtype int64
Definition dtype.h:70
+
constexpr Dtype uint8
Definition dtype.h:62
+
constexpr Dtype float16
Definition dtype.h:72
+
constexpr Dtype uint32
Definition dtype.h:64
+
constexpr Dtype complex64
Definition dtype.h:75
+ +
+ + + + diff --git a/docs/build/html/compile_8h.html b/docs/build/html/compile_8h.html new file mode 100644 index 000000000..3689b75f0 --- /dev/null +++ b/docs/build/html/compile_8h.html @@ -0,0 +1,130 @@ + + + + + + + +MLX: mlx/compile.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Namespaces | +Enumerations | +Functions | +Variables
+
compile.h File Reference
+
+
+
#include "mlx/array.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+ + + +

+Enumerations

enum class  mlx::core::CompileMode { mlx::core::disabled +, mlx::core::no_simplify +, mlx::core::no_fuse +, mlx::core::enabled + }
 
+ + + + + + + + + + +

+Functions

void mlx::core::disable_compile ()
 Globally disable compilation.
 
void mlx::core::enable_compile ()
 Globally enable compilation.
 
void mlx::core::set_compile_mode (CompileMode mode)
 Set the compiler mode to the given value.
 
+ + + + +

+Variables

std::function< std::vector< array >(const std::vector< array > &) mlx::core::compile )(const std::function< std::vector< array >(const std::vector< array > &)> &fun, bool shapeless=false)
 Compile takes a function and returns a compiled function.
 
+
+ + + + diff --git a/docs/build/html/compile_8h_source.html b/docs/build/html/compile_8h_source.html new file mode 100644 index 000000000..8960d1ddc --- /dev/null +++ b/docs/build/html/compile_8h_source.html @@ -0,0 +1,123 @@ + + + + + + + +MLX: mlx/compile.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
compile.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include "mlx/array.h"
+
6
+
7namespace mlx::core {
+
8
+ +
10
+
12std::function<std::vector<array>(const std::vector<array>&)> compile(
+
13 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+
14 bool shapeless = false);
+
15
+ +
21
+ +
26
+ +
29} // namespace mlx::core
+ +
Definition allocator.h:7
+
void enable_compile()
Globally enable compilation.
+
void set_compile_mode(CompileMode mode)
Set the compiler mode to the given value.
+
void disable_compile()
Globally disable compilation.
+
std::function< std::vector< array >(const std::vector< array > &) compile)(const std::function< std::vector< array >(const std::vector< array > &)> &fun, bool shapeless=false)
Compile takes a function and returns a compiled function.
+
CompileMode
Definition compile.h:9
+ + + + +
+ + + + diff --git a/docs/build/html/compile__impl_8h.html b/docs/build/html/compile__impl_8h.html new file mode 100644 index 000000000..cfe23b4a1 --- /dev/null +++ b/docs/build/html/compile__impl_8h.html @@ -0,0 +1,108 @@ + + + + + + + +MLX: mlx/compile_impl.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Namespaces | +Functions
+
compile_impl.h File Reference
+
+
+
#include "mlx/device.h"
+
+

Go to the source code of this file.

+ + + + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
namespace  mlx::core::detail
 
+ + + +

+Functions

bool mlx::core::detail::compile_available_for_device (const Device &device)
 
+
+ + + + diff --git a/docs/build/html/compile__impl_8h_source.html b/docs/build/html/compile__impl_8h_source.html new file mode 100644 index 000000000..83d84e37d --- /dev/null +++ b/docs/build/html/compile__impl_8h_source.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: mlx/compile_impl.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
compile_impl.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include "mlx/device.h"
+
6
+
7namespace mlx::core::detail {
+
8
+ +
10
+
11}
+ +
Definition ops.h:8
+
bool compile_available_for_device(const Device &device)
+
Definition device.h:7
+
+ + + + diff --git a/docs/build/html/compiled_8h.html b/docs/build/html/compiled_8h.html new file mode 100644 index 000000000..aae146e42 --- /dev/null +++ b/docs/build/html/compiled_8h.html @@ -0,0 +1,131 @@ + + + + + + + +MLX: mlx/backend/common/compiled.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Namespaces | +Functions
+
compiled.h File Reference
+
+
+
#include <iomanip>
+#include <sstream>
+#include <unordered_set>
+#include "mlx/array.h"
+#include "mlx/primitives.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+ + + + + + + + + + + + + + + + + + + + + + + + +

+Functions

bool mlx::core::is_static_cast (const Primitive &p)
 
std::string mlx::core::build_lib_name (const std::vector< array > &inputs, const std::vector< array > &outputs, const std::vector< array > &tape, const std::unordered_set< uintptr_t > &constant_ids)
 
std::string mlx::core::get_type_string (Dtype d)
 
template<typename T >
void mlx::core::print_float_constant (std::ostream &os, const array &x)
 
template<typename T >
void mlx::core::print_int_constant (std::ostream &os, const array &x)
 
template<typename T >
void mlx::core::print_complex_constant (std::ostream &os, const array &x)
 
void mlx::core::print_constant (std::ostream &os, const array &x)
 
bool mlx::core::is_scalar (const array &x)
 
bool mlx::core::compiled_check_contiguity (const std::vector< array > &inputs, const std::vector< int > &shape)
 
void mlx::core::compiled_allocate_outputs (const std::vector< array > &inputs, std::vector< array > &outputs, const std::vector< array > &inputs_, const std::unordered_set< uintptr_t > &constant_ids_, bool contiguous, bool move_buffers=false)
 
+
+ + + + diff --git a/docs/build/html/compiled_8h_source.html b/docs/build/html/compiled_8h_source.html new file mode 100644 index 000000000..6e37a563e --- /dev/null +++ b/docs/build/html/compiled_8h_source.html @@ -0,0 +1,195 @@ + + + + + + + +MLX: mlx/backend/common/compiled.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
compiled.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2#pragma once
+
3
+
4#include <iomanip>
+
5#include <sstream>
+
6#include <unordered_set>
+
7
+
8#include "mlx/array.h"
+
9#include "mlx/primitives.h"
+
10
+
11namespace mlx::core {
+
12
+
+
13inline bool is_static_cast(const Primitive& p) {
+
14 return (
+
15 typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
+
16 typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
+
17}
+
+
18
+
19std::string build_lib_name(
+
20 const std::vector<array>& inputs,
+
21 const std::vector<array>& outputs,
+
22 const std::vector<array>& tape,
+
23 const std::unordered_set<uintptr_t>& constant_ids);
+
24
+
25std::string get_type_string(Dtype d);
+
26
+
27template <typename T>
+
+
28void print_float_constant(std::ostream& os, const array& x) {
+
29 auto old_precision = os.precision();
+
30 os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
+
31 << x.item<T>() << std::setprecision(old_precision);
+
32}
+
+
33
+
34template <typename T>
+
+
35void print_int_constant(std::ostream& os, const array& x) {
+
36 os << x.item<T>();
+
37}
+
+
38
+
39template <typename T>
+
+
40void print_complex_constant(std::ostream& os, const array& x) {
+
41 auto old_precision = os.precision();
+
42 T constant = x.item<T>();
+
43
+
44 os << get_type_string(x.dtype()) << "("
+
45 << std::setprecision(std::numeric_limits<float>::digits10 + 1)
+
46 << constant.real() << ", " << constant.imag() << ")"
+
47 << std::setprecision(old_precision);
+
48}
+
+
49
+
50void print_constant(std::ostream& os, const array& x);
+
51
+
+
52inline bool is_scalar(const array& x) {
+
53 return x.ndim() == 0;
+
54}
+
+
55
+
56// Check if we can use a contiguous operation given inputs and the output shape
+ +
58 const std::vector<array>& inputs,
+
59 const std::vector<int>& shape);
+
60
+
61// Allocate space for the outputs possibly with input donation
+ +
63 const std::vector<array>& inputs,
+
64 std::vector<array>& outputs,
+
65 const std::vector<array>& inputs_,
+
66 const std::unordered_set<uintptr_t>& constant_ids_,
+
67 bool contiguous,
+
68 bool move_buffers = false);
+
69
+
70} // namespace mlx::core
+ +
Definition primitives.h:416
+
Definition primitives.h:525
+
Definition primitives.h:680
+
Definition primitives.h:48
+
Definition primitives.h:1919
+
Definition array.h:20
+
size_t ndim() const
The number of dimensions of the array.
Definition array.h:94
+
T item()
Get the value from a scalar array.
Definition array.h:489
+
Dtype dtype() const
Get the arrays data type.
Definition array.h:127
+
Definition allocator.h:7
+
void print_complex_constant(std::ostream &os, const array &x)
Definition compiled.h:40
+
bool compiled_check_contiguity(const std::vector< array > &inputs, const std::vector< int > &shape)
+
std::string build_lib_name(const std::vector< array > &inputs, const std::vector< array > &outputs, const std::vector< array > &tape, const std::unordered_set< uintptr_t > &constant_ids)
+
void print_constant(std::ostream &os, const array &x)
+
void print_float_constant(std::ostream &os, const array &x)
Definition compiled.h:28
+
void print_int_constant(std::ostream &os, const array &x)
Definition compiled.h:35
+
bool is_scalar(const array &x)
Definition compiled.h:52
+
void compiled_allocate_outputs(const std::vector< array > &inputs, std::vector< array > &outputs, const std::vector< array > &inputs_, const std::unordered_set< uintptr_t > &constant_ids_, bool contiguous, bool move_buffers=false)
+
std::string get_type_string(Dtype d)
+
bool is_static_cast(const Primitive &p)
Definition compiled.h:13
+ +
Definition dtype.h:15
+
+ + + + diff --git a/docs/build/html/conv_2loader_8h.html b/docs/build/html/conv_2loader_8h.html new file mode 100644 index 000000000..f5125528c --- /dev/null +++ b/docs/build/html/conv_2loader_8h.html @@ -0,0 +1,91 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/conv/loader.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
loader.h File Reference
+
+
+
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h"
+#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h"
+
+

Go to the source code of this file.

+
+ + + + diff --git a/docs/build/html/conv_2loader_8h_source.html b/docs/build/html/conv_2loader_8h_source.html new file mode 100644 index 000000000..b3bb84bc8 --- /dev/null +++ b/docs/build/html/conv_2loader_8h_source.html @@ -0,0 +1,100 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/conv/loader.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
loader.h
+
+
+Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
+
2
+
3#pragma once
+
4
+ + + + +
+ + + + diff --git a/docs/build/html/conv_2params_8h.html b/docs/build/html/conv_2params_8h.html new file mode 100644 index 000000000..72736b010 --- /dev/null +++ b/docs/build/html/conv_2params_8h.html @@ -0,0 +1,111 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/conv/params.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
+Classes | +Namespaces
+
params.h File Reference
+
+
+ +

Go to the source code of this file.

+ + + + + + + + + + +

+Classes

struct  MLXConvParams< NDIM >
 
struct  mlx::steel::ImplicitGemmConv2DParams
 
struct  mlx::steel::Conv2DGeneralJumpParams
 
struct  mlx::steel::Conv2DGeneralBaseInfo
 
+ + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::steel
 
+
+ + + + diff --git a/docs/build/html/conv_2params_8h_source.html b/docs/build/html/conv_2params_8h_source.html new file mode 100644 index 000000000..deb4bf9a4 --- /dev/null +++ b/docs/build/html/conv_2params_8h_source.html @@ -0,0 +1,202 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/conv/params.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
params.h
+
+
+Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5template <int NDIM>
+
+ +
7 const int N; // Batch size
+
8 const int C; // In channels
+
9 const int O; // Out channels
+
10 const int iS[NDIM]; // Input spatial dim
+
11 const int wS[NDIM]; // Weight spatial dim
+
12 const int oS[NDIM]; // Output spatial dim
+
13 const int str[NDIM]; // Kernel strides
+
14 const int pad[NDIM]; // Input padding
+
15 const int kdil[NDIM]; // Kernel dilation
+
16 const int idil[NDIM]; // Input dilation
+
17 const size_t in_strides[NDIM + 2]; // In strides
+
18 const size_t wt_strides[NDIM + 2]; // Wt strides
+
19 const size_t out_strides[NDIM + 2]; // Out strides
+
20 const int groups; // Input channel groups
+
21 const bool flip;
+
22};
+
+
23
+
24namespace mlx {
+
25namespace steel {
+
26
+
+ +
28 const int M;
+
29 const int N;
+
30 const int K;
+
31
+ +
33
+
34 const int inp_jump_w;
+
35 const int inp_jump_h;
+
36 const int inp_jump_c;
+
37
+
38 const int tiles_n;
+
39 const int tiles_m;
+
40 const int swizzle_log;
+
41};
+
+
42
+
+ +
44 const int f_wgt_jump_h;
+
45 const int f_wgt_jump_w;
+
46
+
47 const int f_out_jump_h;
+
48 const int f_out_jump_w;
+
49
+
50 const int adj_out_h;
+
51 const int adj_out_w;
+
52 const int adj_out_hw;
+
53 const int adj_implicit_m;
+
54};
+
+
55
+
+ + + +
59};
+
+
60
+
61} // namespace steel
+
62} // namespace mlx
+
Definition allocator.h:7
+
Definition params.h:6
+
const int C
Definition params.h:8
+
const size_t out_strides[NDIM+2]
Definition params.h:19
+
const int oS[NDIM]
Definition params.h:12
+
const int iS[NDIM]
Definition params.h:10
+
const int kdil[NDIM]
Definition params.h:15
+
const int str[NDIM]
Definition params.h:13
+
const size_t wt_strides[NDIM+2]
Definition params.h:18
+
const bool flip
Definition params.h:21
+
const size_t in_strides[NDIM+2]
Definition params.h:17
+
const int wS[NDIM]
Definition params.h:11
+
const int O
Definition params.h:9
+
const int N
Definition params.h:7
+
const int pad[NDIM]
Definition params.h:14
+
const int groups
Definition params.h:20
+
const int idil[NDIM]
Definition params.h:16
+
Definition params.h:56
+
int weight_base
Definition params.h:57
+
int weight_size
Definition params.h:58
+ +
const int f_out_jump_w
Definition params.h:48
+
const int f_wgt_jump_h
Definition params.h:44
+
const int f_wgt_jump_w
Definition params.h:45
+
const int adj_implicit_m
Definition params.h:53
+
const int f_out_jump_h
Definition params.h:47
+
const int adj_out_h
Definition params.h:50
+
const int adj_out_w
Definition params.h:51
+
const int adj_out_hw
Definition params.h:52
+ +
const int inp_jump_h
Definition params.h:35
+
const int M
Definition params.h:28
+
const int N
Definition params.h:29
+
const int tiles_m
Definition params.h:39
+
const int tiles_n
Definition params.h:38
+
const int inp_jump_c
Definition params.h:36
+
const int gemm_k_iterations
Definition params.h:32
+
const int inp_jump_w
Definition params.h:34
+
const int swizzle_log
Definition params.h:40
+
const int K
Definition params.h:30
+
+ + + + diff --git a/docs/build/html/conv_8h.html b/docs/build/html/conv_8h.html new file mode 100644 index 000000000..e62539e7c --- /dev/null +++ b/docs/build/html/conv_8h.html @@ -0,0 +1,92 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/conv/conv.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
conv.h File Reference
+
+
+
#include "mlx/backend/metal/kernels/steel/utils.h"
+#include "mlx/backend/metal/kernels/steel/conv/loader.h"
+#include "mlx/backend/metal/kernels/steel/conv/params.h"
+
+

Go to the source code of this file.

+
+ + + + diff --git a/docs/build/html/conv_8h_source.html b/docs/build/html/conv_8h_source.html new file mode 100644 index 000000000..a7219c852 --- /dev/null +++ b/docs/build/html/conv_8h_source.html @@ -0,0 +1,108 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/conv/conv.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
conv.h
+
+
+Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
+
2
+
3#pragma once
+
4
+ +
6
+ + +
9
+
10using namespace metal;
+
11using namespace mlx::steel;
+ + + +
Definition bf16.h:265
+
Definition loader_channel_l.h:14
+
+ + + + diff --git a/docs/build/html/cookie.js b/docs/build/html/cookie.js new file mode 100644 index 000000000..53ad21d98 --- /dev/null +++ b/docs/build/html/cookie.js @@ -0,0 +1,58 @@ +/*! + Cookie helper functions + Copyright (c) 2023 Dimitri van Heesch + Released under MIT license. +*/ +let Cookie = { + cookie_namespace: 'doxygen_', + + readSetting(cookie,defVal) { + if (window.chrome) { + const val = localStorage.getItem(this.cookie_namespace+cookie) || + sessionStorage.getItem(this.cookie_namespace+cookie); + if (val) return val; + } else { + let myCookie = this.cookie_namespace+cookie+"="; + if (document.cookie) { + const index = document.cookie.indexOf(myCookie); + if (index != -1) { + const valStart = index + myCookie.length; + let valEnd = document.cookie.indexOf(";", valStart); + if (valEnd == -1) { + valEnd = document.cookie.length; + } + return document.cookie.substring(valStart, valEnd); + } + } + } + return defVal; + }, + + writeSetting(cookie,val,days=10*365) { // default days='forever', 0=session cookie, -1=delete + if (window.chrome) { + if (days==0) { + sessionStorage.setItem(this.cookie_namespace+cookie,val); + } else { + localStorage.setItem(this.cookie_namespace+cookie,val); + } + } else { + let date = new Date(); + date.setTime(date.getTime()+(days*24*60*60*1000)); + const expiration = days!=0 ? "expires="+date.toGMTString()+";" : ""; + document.cookie = this.cookie_namespace + cookie + "=" + + val + "; SameSite=Lax;" + expiration + "path=/"; + } + }, + + eraseSetting(cookie) { + if (window.chrome) { + if (localStorage.getItem(this.cookie_namespace+cookie)) { + localStorage.removeItem(this.cookie_namespace+cookie); + } else if (sessionStorage.getItem(this.cookie_namespace+cookie)) { + sessionStorage.removeItem(this.cookie_namespace+cookie); + } + } else { + this.writeSetting(cookie,'',-1); + } + }, +} diff --git a/docs/build/html/cpp/ops.html b/docs/build/html/cpp/ops.html index 159c599d7..a43770828 100644 --- a/docs/build/html/cpp/ops.html +++ b/docs/build/html/cpp/ops.html @@ -8,7 +8,7 @@ - Operations — MLX 0.12.0 documentation + Operations — MLX 0.13.0 documentation @@ -36,7 +36,7 @@ - + @@ -44,7 +44,7 @@ - + @@ -131,8 +131,8 @@ - MLX 0.12.0 documentation - Home - + MLX 0.13.0 documentation - Home + @@ -255,6 +255,7 @@
  • mlx.core.arcsin
  • mlx.core.arcsinh
  • mlx.core.arctan
  • +
  • mlx.core.arctan2
  • mlx.core.arctanh
  • mlx.core.argmax
  • mlx.core.argmin
  • @@ -264,11 +265,17 @@
  • mlx.core.atleast_1d
  • mlx.core.atleast_2d
  • mlx.core.atleast_3d
  • -
  • mlx.core.broadcast_to
  • +
  • mlx.core.bitwise_and
  • +
  • mlx.core.bitwise_or
  • +
  • mlx.core.bitwise_xor
  • mlx.core.block_masked_mm
  • +
  • mlx.core.block_sparse_mm
  • +
  • mlx.core.broadcast_to
  • mlx.core.ceil
  • mlx.core.clip
  • mlx.core.concatenate
  • +
  • mlx.core.conj
  • +
  • mlx.core.conjugate
  • mlx.core.convolve
  • mlx.core.conv1d
  • mlx.core.conv2d
  • @@ -305,6 +312,7 @@
  • mlx.core.isnan
  • mlx.core.isneginf
  • mlx.core.isposinf
  • +
  • mlx.core.left_shift
  • mlx.core.less
  • mlx.core.less_equal
  • mlx.core.linspace
  • @@ -341,6 +349,7 @@
  • mlx.core.reciprocal
  • mlx.core.repeat
  • mlx.core.reshape
  • +
  • mlx.core.right_shift
  • mlx.core.round
  • mlx.core.rsqrt
  • mlx.core.save
  • @@ -436,8 +445,10 @@
  • Metal
  • +
  • mlx.optimizers.clip_grad_norm
  • Tree Utils
  • @@ -751,7 +764,9 @@ document.write(` `); - + @@ -767,6 +782,298 @@ document.write(`
    +
    +

    Contents

    +
    +
    @@ -778,6 +1085,1617 @@ document.write(`

    Operations#

    +
    +
    +array arange(double start, double stop, double step, Dtype dtype, StreamOrDevice s = {})#
    +

    A 1D array of numbers starting at start (optional), stopping at stop, stepping by step (optional).

    +
    + +
    +
    +array arange(double start, double stop, double step, StreamOrDevice s = {})#
    +
    + +
    +
    +array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {})#
    +
    + +
    +
    +array arange(double start, double stop, StreamOrDevice s = {})#
    +
    + +
    +
    +array arange(double stop, Dtype dtype, StreamOrDevice s = {})#
    +
    + +
    +
    +array arange(double stop, StreamOrDevice s = {})#
    +
    + +
    +
    +array arange(int start, int stop, int step, StreamOrDevice s = {})#
    +
    + +
    +
    +array arange(int start, int stop, StreamOrDevice s = {})#
    +
    + +
    +
    +array arange(int stop, StreamOrDevice s = {})#
    +
    + +
    +
    +array linspace(double start, double stop, int num = 50, Dtype dtype = float32, StreamOrDevice s = {})#
    +

    A 1D array of num evenly spaced numbers in the range [start, stop]

    +
    + +
    +
    +array astype(array a, Dtype dtype, StreamOrDevice s = {})#
    +

    Convert an array to the given data type.

    +
    + +
    +
    +array as_strided(array a, std::vector<int> shape, std::vector<size_t> strides, size_t offset, StreamOrDevice s = {})#
    +

    Create a view of an array with the given shape and strides.

    +
    + +
    +
    +array copy(array a, StreamOrDevice s = {})#
    +

    Copy another array.

    +
    + +
    +
    +array full(std::vector<int> shape, array vals, Dtype dtype, StreamOrDevice s = {})#
    +

    Fill an array of the given shape with the given value(s).

    +
    + +
    +
    +array full(std::vector<int> shape, array vals, StreamOrDevice s = {})#
    +
    + +
    +
    +template<typename T>
    array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {})#
    +
    + +
    +
    +template<typename T>
    array full(std::vector<int> shape, T val, StreamOrDevice s = {})#
    +
    + +
    +
    +array zeros(const std::vector<int> &shape, Dtype dtype, StreamOrDevice s = {})#
    +

    Fill an array of the given shape with zeros.

    +
    + +
    +
    +inline array zeros(const std::vector<int> &shape, StreamOrDevice s = {})#
    +
    + +
    +
    +array zeros_like(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array ones(const std::vector<int> &shape, Dtype dtype, StreamOrDevice s = {})#
    +

    Fill an array of the given shape with ones.

    +
    + +
    +
    +inline array ones(const std::vector<int> &shape, StreamOrDevice s = {})#
    +
    + +
    +
    +array ones_like(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {})#
    +

    Fill an array of the given shape (n,m) with ones in the specified diagonal k, and zeros everywhere else.

    +
    + +
    +
    +inline array eye(int n, Dtype dtype, StreamOrDevice s = {})#
    +
    + +
    +
    +inline array eye(int n, int m, StreamOrDevice s = {})#
    +
    + +
    +
    +inline array eye(int n, int m, int k, StreamOrDevice s = {})#
    +
    + +
    +
    +inline array eye(int n, StreamOrDevice s = {})#
    +
    + +
    +
    +array identity(int n, Dtype dtype, StreamOrDevice s = {})#
    +

    Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.

    +
    + +
    +
    +inline array identity(int n, StreamOrDevice s = {})#
    +
    + +
    +
    +array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {})#
    +
    + +
    +
    +inline array tri(int n, Dtype type, StreamOrDevice s = {})#
    +
    + +
    +
    +array tril(array x, int k = 0, StreamOrDevice s = {})#
    +
    + +
    +
    +array triu(array x, int k = 0, StreamOrDevice s = {})#
    +
    + +
    +
    +array reshape(const array &a, std::vector<int> shape, StreamOrDevice s = {})#
    +

    Reshape an array to the given shape.

    +
    + +
    +
    +array flatten(const array &a, int start_axis, int end_axis = -1, StreamOrDevice s = {})#
    +

    Flatten the dimensions in the range [start_axis, end_axis] .

    +
    + +
    +
    +array flatten(const array &a, StreamOrDevice s = {})#
    +

    Flatten the array to 1D.

    +
    + +
    +
    +array squeeze(const array &a, const std::vector<int> &axes, StreamOrDevice s = {})#
    +

    Remove singleton dimensions at the given axes.

    +
    + +
    +
    +inline array squeeze(const array &a, int axis, StreamOrDevice s = {})#
    +

    Remove singleton dimensions at the given axis.

    +
    + +
    +
    +array squeeze(const array &a, StreamOrDevice s = {})#
    +

    Remove all singleton dimensions.

    +
    + +
    +
    +array expand_dims(const array &a, const std::vector<int> &axes, StreamOrDevice s = {})#
    +

    Add a singleton dimension at the given axes.

    +
    + +
    +
    +array expand_dims(const array &a, int axis, StreamOrDevice s = {})#
    +

    Add a singleton dimension at the given axis.

    +
    + +
    +
    +array slice(const array &a, std::vector<int> start, std::vector<int> stop, std::vector<int> strides, StreamOrDevice s = {})#
    +

    Slice an array.

    +
    + +
    +
    +array slice(const array &a, const std::vector<int> &start, const std::vector<int> &stop, StreamOrDevice s = {})#
    +

    Slice an array with a stride of 1 in each dimension.

    +
    + +
    +
    +array slice_update(const array &src, const array &update, std::vector<int> start, std::vector<int> stop, std::vector<int> strides, StreamOrDevice s = {})#
    +

    Update a slice from the source array.

    +
    + +
    +
    +array slice_update(const array &src, const array &update, std::vector<int> start, std::vector<int> stop, StreamOrDevice s = {})#
    +

    Update a slice from the source array with stride 1 in each dimension.

    +
    + +
    +
    +std::vector<array> split(const array &a, int num_splits, int axis, StreamOrDevice s = {})#
    +

    Split an array into sub-arrays along a given axis.

    +
    + +
    +
    +std::vector<array> split(const array &a, int num_splits, StreamOrDevice s = {})#
    +
    + +
    +
    +std::vector<array> split(const array &a, const std::vector<int> &indices, int axis, StreamOrDevice s = {})#
    +
    + +
    +
    +std::vector<array> split(const array &a, const std::vector<int> &indices, StreamOrDevice s = {})#
    +
    + +
    +
    +std::vector<array> meshgrid(const std::vector<array> &arrays, bool sparse = false, std::string indexing = "xy", StreamOrDevice s = {})#
    +

    A vector of coordinate arrays from coordinate vectors.

    +
    + +
    +
    +array clip(const array &a, const std::optional<array> &a_min = std::nullopt, const std::optional<array> &a_max = std::nullopt, StreamOrDevice s = {})#
    +

    Clip (limit) the values in an array.

    +
    + +
    +
    +array concatenate(const std::vector<array> &arrays, int axis, StreamOrDevice s = {})#
    +

    Concatenate arrays along a given axis.

    +
    + +
    +
    +array concatenate(const std::vector<array> &arrays, StreamOrDevice s = {})#
    +
    + +
    +
    +array stack(const std::vector<array> &arrays, int axis, StreamOrDevice s = {})#
    +

    Stack arrays along a new axis.

    +
    + +
    +
    +array stack(const std::vector<array> &arrays, StreamOrDevice s = {})#
    +
    + +
    +
    +array repeat(const array &arr, int repeats, int axis, StreamOrDevice s = {})#
    +

    Repeat an array along an axis.

    +
    + +
    +
    +array repeat(const array &arr, int repeats, StreamOrDevice s = {})#
    +
    + +
    +
    +array tile(const array &arr, std::vector<int> reps, StreamOrDevice s = {})#
    +
    + +
    +
    +array transpose(const array &a, std::vector<int> axes, StreamOrDevice s = {})#
    +

    Permutes the dimensions according to the given axes.

    +
    + +
    +
    +inline array transpose(const array &a, std::initializer_list<int> axes, StreamOrDevice s = {})#
    +
    + +
    +
    +array swapaxes(const array &a, int axis1, int axis2, StreamOrDevice s = {})#
    +

    Swap two axes of an array.

    +
    + +
    +
    +array moveaxis(const array &a, int source, int destination, StreamOrDevice s = {})#
    +

    Move an axis of an array.

    +
    + +
    +
    +array pad(const array &a, const std::vector<int> &axes, const std::vector<int> &low_pad_size, const std::vector<int> &high_pad_size, const array &pad_value = array(0), StreamOrDevice s = {})#
    +

    Pad an array with a constant value.

    +
    + +
    +
    +array pad(const array &a, const std::vector<std::pair<int, int>> &pad_width, const array &pad_value = array(0), StreamOrDevice s = {})#
    +

    Pad an array with a constant value along all axes.

    +
    + +
    +
    +array pad(const array &a, const std::pair<int, int> &pad_width, const array &pad_value = array(0), StreamOrDevice s = {})#
    +
    + +
    +
    +array pad(const array &a, int pad_width, const array &pad_value = array(0), StreamOrDevice s = {})#
    +
    + +
    +
    +array transpose(const array &a, StreamOrDevice s = {})#
    +

    Permutes the dimensions in reverse order.

    +
    + +
    +
    +array broadcast_to(const array &a, const std::vector<int> &shape, StreamOrDevice s = {})#
    +

    Broadcast an array to a given shape.

    +
    + +
    +
    +std::vector<array> broadcast_arrays(const std::vector<array> &inputs, StreamOrDevice s = {})#
    +

    Broadcast a vector of arrays against one another.

    +
    + +
    +
    +array equal(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Returns the bool array with (a == b) element-wise.

    +
    + +
    +
    +inline array operator==(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator==(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator==(const array &a, T b)#
    +
    + +
    +
    +array not_equal(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Returns the bool array with (a != b) element-wise.

    +
    + +
    +
    +inline array operator!=(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator!=(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator!=(const array &a, T b)#
    +
    + +
    +
    +array greater(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Returns bool array with (a > b) element-wise.

    +
    + +
    +
    +inline array operator>(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator>(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator>(const array &a, T b)#
    +
    + +
    +
    +array greater_equal(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Returns bool array with (a >= b) element-wise.

    +
    + +
    +
    +inline array operator>=(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator>=(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator>=(const array &a, T b)#
    +
    + +
    +
    +array less(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Returns bool array with (a < b) element-wise.

    +
    + +
    +
    +inline array operator<(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator<(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator<(const array &a, T b)#
    +
    + +
    +
    +array less_equal(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Returns bool array with (a <= b) element-wise.

    +
    + +
    +
    +inline array operator<=(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator<=(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator<=(const array &a, T b)#
    +
    + +
    +
    +array array_equal(const array &a, const array &b, bool equal_nan, StreamOrDevice s = {})#
    +

    True if two arrays have the same shape and elements.

    +
    + +
    +
    +inline array array_equal(const array &a, const array &b, StreamOrDevice s = {})#
    +
    + +
    +
    +array isnan(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array isinf(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array isposinf(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array isneginf(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array where(const array &condition, const array &x, const array &y, StreamOrDevice s = {})#
    +

    Select from x or y depending on condition.

    +
    + +
    +
    +array all(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    True if all elements in the array are true (or non-zero).

    +
    + +
    +
    +inline array all(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array allclose(const array &a, const array &b, double rtol = 1e-5, double atol = 1e-8, bool equal_nan = false, StreamOrDevice s = {})#
    +

    True if the two arrays are equal within the specified tolerance.

    +
    + +
    +
    +array isclose(const array &a, const array &b, double rtol = 1e-5, double atol = 1e-8, bool equal_nan = false, StreamOrDevice s = {})#
    +

    Returns a boolean array where two arrays are element-wise equal within the specified tolerance.

    +
    + +
    +
    +array all(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
    +

    Reduces the input along the given axes.

    +

    An output value is true if all the corresponding inputs are true.

    +
    + +
    +
    +array all(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    Reduces the input along the given axis.

    +

    An output value is true if all the corresponding inputs are true.

    +
    + +
    +
    +array any(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    True if any elements in the array are true (or non-zero).

    +
    + +
    +
    +inline array any(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array any(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
    +

    Reduces the input along the given axes.

    +

    An output value is true if any of the corresponding inputs are true.

    +
    + +
    +
    +array any(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    Reduces the input along the given axis.

    +

    An output value is true if any of the corresponding inputs are true.

    +
    + +
    +
    +array sum(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    Sums the elements of an array.

    +
    + +
    +
    +inline array sum(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array sum(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
    +

    Sums the elements of an array along the given axes.

    +
    + +
    +
    +array sum(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    Sums the elements of an array along the given axis.

    +
    + +
    +
    +array mean(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    Computes the mean of the elements of an array.

    +
    + +
    +
    +inline array mean(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array mean(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
    +

    Computes the mean of the elements of an array along the given axes.

    +
    + +
    +
    +array mean(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    Computes the mean of the elements of an array along the given axis.

    +
    + +
    +
    +array var(const array &a, bool keepdims, int ddof = 0, StreamOrDevice s = {})#
    +

    Computes the variance of the elements of an array.

    +
    + +
    +
    +inline array var(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array var(const array &a, const std::vector<int> &axes, bool keepdims = false, int ddof = 0, StreamOrDevice s = {})#
    +

    Computes the variance of the elements of an array along the given axes.

    +
    + +
    +
    +array var(const array &a, int axis, bool keepdims = false, int ddof = 0, StreamOrDevice s = {})#
    +

    Computes the variance of the elements of an array along the given axis.

    +
    + +
    +
    +array std(const array &a, bool keepdims, int ddof = 0, StreamOrDevice s = {})#
    +

    Computes the standard deviation of the elements of an array.

    +
    + +
    +
    +inline array std(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array std(const array &a, const std::vector<int> &axes, bool keepdims = false, int ddof = 0, StreamOrDevice s = {})#
    +

    Computes the standard deviatoin of the elements of an array along the given axes.

    +
    + +
    +
    +array std(const array &a, int axis, bool keepdims = false, int ddof = 0, StreamOrDevice s = {})#
    +

    Computes the standard deviation of the elements of an array along the given axis.

    +
    + +
    +
    +array prod(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    The product of all elements of the array.

    +
    + +
    +
    +inline array prod(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array prod(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
    +

    The product of the elements of an array along the given axes.

    +
    + +
    +
    +array prod(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    The product of the elements of an array along the given axis.

    +
    + +
    +
    +array max(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    The maximum of all elements of the array.

    +
    + +
    +
    +inline array max(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array max(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
    +

    The maximum of the elements of an array along the given axes.

    +
    + +
    +
    +array max(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    The maximum of the elements of an array along the given axis.

    +
    + +
    +
    +array min(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    The minimum of all elements of the array.

    +
    + +
    +
    +inline array min(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array min(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
    +

    The minimum of the elements of an array along the given axes.

    +
    + +
    +
    +array min(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    The minimum of the elements of an array along the given axis.

    +
    + +
    +
    +array argmin(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    Returns the index of the minimum value in the array.

    +
    + +
    +
    +inline array argmin(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array argmin(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    Returns the indices of the minimum values along a given axis.

    +
    + +
    +
    +array argmax(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    Returns the index of the maximum value in the array.

    +
    + +
    +
    +inline array argmax(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array argmax(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    Returns the indices of the maximum values along a given axis.

    +
    + +
    +
    +array sort(const array &a, StreamOrDevice s = {})#
    +

    Returns a sorted copy of the flattened array.

    +
    + +
    +
    +array sort(const array &a, int axis, StreamOrDevice s = {})#
    +

    Returns a sorted copy of the array along a given axis.

    +
    + +
    +
    +array argsort(const array &a, StreamOrDevice s = {})#
    +

    Returns indices that sort the flattened array.

    +
    + +
    +
    +array argsort(const array &a, int axis, StreamOrDevice s = {})#
    +

    Returns indices that sort the array along a given axis.

    +
    + +
    +
    +array partition(const array &a, int kth, StreamOrDevice s = {})#
    +

    Returns a partitioned copy of the flattened array such that the smaller kth elements are first.

    +
    + +
    +
    +array partition(const array &a, int kth, int axis, StreamOrDevice s = {})#
    +

    Returns a partitioned copy of the array along a given axis such that the smaller kth elements are first.

    +
    + +
    +
    +array argpartition(const array &a, int kth, StreamOrDevice s = {})#
    +

    Returns indices that partition the flattened array such that the smaller kth elements are first.

    +
    + +
    +
    +array argpartition(const array &a, int kth, int axis, StreamOrDevice s = {})#
    +

    Returns indices that partition the array along a given axis such that the smaller kth elements are first.

    +
    + +
    +
    +array topk(const array &a, int k, StreamOrDevice s = {})#
    +

    Returns topk elements of the flattened array.

    +
    + +
    +
    +array topk(const array &a, int k, int axis, StreamOrDevice s = {})#
    +

    Returns topk elements of the array along a given axis.

    +
    + +
    +
    +array logsumexp(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    The logsumexp of all elements of the array.

    +
    + +
    +
    +inline array logsumexp(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array logsumexp(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
    +

    The logsumexp of the elements of an array along the given axes.

    +
    + +
    +
    +array logsumexp(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    The logsumexp of the elements of an array along the given axis.

    +
    + +
    +
    +array abs(const array &a, StreamOrDevice s = {})#
    +

    Absolute value of elements in an array.

    +
    + +
    +
    +array negative(const array &a, StreamOrDevice s = {})#
    +

    Negate an array.

    +
    + +
    +
    +array operator-(const array &a)#
    +
    + +
    +
    +array sign(const array &a, StreamOrDevice s = {})#
    +

    The sign of the elements in an array.

    +
    + +
    +
    +array logical_not(const array &a, StreamOrDevice s = {})#
    +

    Logical not of an array.

    +
    + +
    +
    +array logical_and(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Logical and of two arrays.

    +
    + +
    +
    +array operator&&(const array &a, const array &b)#
    +
    + +
    +
    +array logical_or(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Logical or of two arrays.

    +
    + +
    +
    +array operator||(const array &a, const array &b)#
    +
    + +
    +
    +array reciprocal(const array &a, StreamOrDevice s = {})#
    +

    The reciprocal (1/x) of the elements in an array.

    +
    + +
    +
    +array add(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Add two arrays.

    +
    + +
    +
    +array operator+(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator+(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator+(const array &a, T b)#
    +
    + +
    +
    +array subtract(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Subtract two arrays.

    +
    + +
    +
    +array operator-(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator-(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator-(const array &a, T b)#
    +
    + +
    +
    +array multiply(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Multiply two arrays.

    +
    + +
    +
    +array operator*(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator*(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator*(const array &a, T b)#
    +
    + +
    +
    +array divide(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Divide two arrays.

    +
    + +
    +
    +array operator/(const array &a, const array &b)#
    +
    + +
    +
    +array operator/(double a, const array &b)#
    +
    + +
    +
    +array operator/(const array &a, double b)#
    +
    + +
    +
    +std::vector<array> divmod(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Compute the element-wise quotient and remainder.

    +
    + +
    +
    +array floor_divide(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Compute integer division.

    +

    Equivalent to doing floor(a / x).

    +
    + +
    +
    +array remainder(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Compute the element-wise remainder of division.

    +
    + +
    +
    +array operator%(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator%(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator%(const array &a, T b)#
    +
    + +
    +
    +array maximum(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Element-wise maximum between two arrays.

    +
    + +
    +
    +array minimum(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Element-wise minimum between two arrays.

    +
    + +
    +
    +array floor(const array &a, StreamOrDevice s = {})#
    +

    Floor the element of an array.

    +
    + +
    +
    +array ceil(const array &a, StreamOrDevice s = {})#
    +

    Ceil the element of an array.

    +
    + +
    +
    +array square(const array &a, StreamOrDevice s = {})#
    +

    Square the elements of an array.

    +
    + +
    +
    +array exp(const array &a, StreamOrDevice s = {})#
    +

    Exponential of the elements of an array.

    +
    + +
    +
    +array sin(const array &a, StreamOrDevice s = {})#
    +

    Sine of the elements of an array.

    +
    + +
    +
    +array cos(const array &a, StreamOrDevice s = {})#
    +

    Cosine of the elements of an array.

    +
    + +
    +
    +array tan(const array &a, StreamOrDevice s = {})#
    +

    Tangent of the elements of an array.

    +
    + +
    +
    +array arcsin(const array &a, StreamOrDevice s = {})#
    +

    Arc Sine of the elements of an array.

    +
    + +
    +
    +array arccos(const array &a, StreamOrDevice s = {})#
    +

    Arc Cosine of the elements of an array.

    +
    + +
    +
    +array arctan(const array &a, StreamOrDevice s = {})#
    +

    Arc Tangent of the elements of an array.

    +
    + +
    +
    +array arctan2(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Inverse tangent of the ratio of two arrays.

    +
    + +
    +
    +array sinh(const array &a, StreamOrDevice s = {})#
    +

    Hyperbolic Sine of the elements of an array.

    +
    + +
    +
    +array cosh(const array &a, StreamOrDevice s = {})#
    +

    Hyperbolic Cosine of the elements of an array.

    +
    + +
    +
    +array tanh(const array &a, StreamOrDevice s = {})#
    +

    Hyperbolic Tangent of the elements of an array.

    +
    + +
    +
    +array arcsinh(const array &a, StreamOrDevice s = {})#
    +

    Inverse Hyperbolic Sine of the elements of an array.

    +
    + +
    +
    +array arccosh(const array &a, StreamOrDevice s = {})#
    +

    Inverse Hyperbolic Cosine of the elements of an array.

    +
    + +
    +
    +array arctanh(const array &a, StreamOrDevice s = {})#
    +

    Inverse Hyperbolic Tangent of the elements of an array.

    +
    + +
    +
    +array degrees(const array &a, StreamOrDevice s = {})#
    +

    Convert the elements of an array from Radians to Degrees.

    +
    + +
    +
    +array radians(const array &a, StreamOrDevice s = {})#
    +

    Convert the elements of an array from Degrees to Radians.

    +
    + +
    +
    +array log(const array &a, StreamOrDevice s = {})#
    +

    Natural logarithm of the elements of an array.

    +
    + +
    +
    +array log2(const array &a, StreamOrDevice s = {})#
    +

    Log base 2 of the elements of an array.

    +
    + +
    +
    +array log10(const array &a, StreamOrDevice s = {})#
    +

    Log base 10 of the elements of an array.

    +
    + +
    +
    +array log1p(const array &a, StreamOrDevice s = {})#
    +

    Natural logarithm of one plus elements in the array: log(1 + a).

    +
    + +
    +
    +array logaddexp(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Log-add-exp of one elements in the array: log(exp(a) + exp(b)).

    +
    + +
    +
    +array sigmoid(const array &a, StreamOrDevice s = {})#
    +

    Element-wise logistic sigmoid of the array: 1 / (1 + exp(-x).

    +
    + +
    +
    +array erf(const array &a, StreamOrDevice s = {})#
    +

    Computes the error function of the elements of an array.

    +
    + +
    +
    +array erfinv(const array &a, StreamOrDevice s = {})#
    +

    Computes the inverse error function of the elements of an array.

    +
    + +
    +
    +array expm1(const array &a, StreamOrDevice s = {})#
    +

    Computes the expm1 function of the elements of an array.

    +
    + +
    +
    +array stop_gradient(const array &a, StreamOrDevice s = {})#
    +

    Stop the flow of gradients.

    +
    + +
    +
    +array round(const array &a, int decimals, StreamOrDevice s = {})#
    +

    Round a floating point number.

    +
    + +
    +
    +inline array round(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array matmul(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Matrix-matrix multiplication.

    +
    + +
    +
    +array gather(const array &a, const std::vector<array> &indices, const std::vector<int> &axes, const std::vector<int> &slice_sizes, StreamOrDevice s = {})#
    +

    Gather array entries given indices and slices.

    +
    + +
    +
    +inline array gather(const array &a, const array &indices, int axis, const std::vector<int> &slice_sizes, StreamOrDevice s = {})#
    +
    + +
    +
    +array take(const array &a, const array &indices, int axis, StreamOrDevice s = {})#
    +

    Take array slices at the given indices of the specified axis.

    +
    + +
    +
    +array take(const array &a, const array &indices, StreamOrDevice s = {})#
    +

    Take array entries at the given indices treating the array as flattened.

    +
    + +
    +
    +array take_along_axis(const array &a, const array &indices, int axis, StreamOrDevice s = {})#
    +

    Take array entries given indices along the axis.

    +
    + +
    +
    +array scatter(const array &a, const std::vector<array> &indices, const array &updates, const std::vector<int> &axes, StreamOrDevice s = {})#
    +

    Scatter updates to given linear indices.

    +
    + +
    +
    +inline array scatter(const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s = {})#
    +
    + +
    +
    +array scatter_add(const array &a, const std::vector<array> &indices, const array &updates, const std::vector<int> &axes, StreamOrDevice s = {})#
    +

    Scatter and add updates to given indices.

    +
    + +
    +
    +inline array scatter_add(const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s = {})#
    +
    + +
    +
    +array scatter_prod(const array &a, const std::vector<array> &indices, const array &updates, const std::vector<int> &axes, StreamOrDevice s = {})#
    +

    Scatter and prod updates to given indices.

    +
    + +
    +
    +inline array scatter_prod(const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s = {})#
    +
    + +
    +
    +array scatter_max(const array &a, const std::vector<array> &indices, const array &updates, const std::vector<int> &axes, StreamOrDevice s = {})#
    +

    Scatter and max updates to given linear indices.

    +
    + +
    +
    +inline array scatter_max(const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s = {})#
    +
    + +
    +
    +array scatter_min(const array &a, const std::vector<array> &indices, const array &updates, const std::vector<int> &axes, StreamOrDevice s = {})#
    +

    Scatter and min updates to given linear indices.

    +
    + +
    +
    +inline array scatter_min(const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s = {})#
    +
    + +
    +
    +array sqrt(const array &a, StreamOrDevice s = {})#
    +

    Square root the elements of an array.

    +
    + +
    +
    +array rsqrt(const array &a, StreamOrDevice s = {})#
    +

    Square root and reciprocal the elements of an array.

    +
    + +
    +
    +array softmax(const array &a, const std::vector<int> &axes, bool precise = false, StreamOrDevice s = {})#
    +

    Softmax of an array.

    +
    + +
    +
    +array softmax(const array &a, bool precise = false, StreamOrDevice s = {})#
    +

    Softmax of an array.

    +
    + +
    +
    +inline array softmax(const array &a, int axis, bool precise = false, StreamOrDevice s = {})#
    +

    Softmax of an array.

    +
    + +
    +
    +array power(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Raise elements of a to the power of b element-wise.

    +
    + +
    +
    +array cumsum(const array &a, int axis, bool reverse = false, bool inclusive = true, StreamOrDevice s = {})#
    +

    Cumulative sum of an array.

    +
    + +
    +
    +array cumprod(const array &a, int axis, bool reverse = false, bool inclusive = true, StreamOrDevice s = {})#
    +

    Cumulative product of an array.

    +
    + +
    +
    +array cummax(const array &a, int axis, bool reverse = false, bool inclusive = true, StreamOrDevice s = {})#
    +

    Cumulative max of an array.

    +
    + +
    +
    +array cummin(const array &a, int axis, bool reverse = false, bool inclusive = true, StreamOrDevice s = {})#
    +

    Cumulative min of an array.

    +
    + +
    +
    +array conv_general(array input, array weight, std::vector<int> stride = {}, std::vector<int> padding_lo = {}, std::vector<int> padding_hi = {}, std::vector<int> kernel_dilation = {}, std::vector<int> input_dilation = {}, int groups = 1, bool flip = false, StreamOrDevice s = {})#
    +

    General convolution with a filter.

    +
    + +
    +
    +inline array conv_general(const array &input, const array &weight, std::vector<int> stride = {}, std::vector<int> padding = {}, std::vector<int> kernel_dilation = {}, std::vector<int> input_dilation = {}, int groups = 1, bool flip = false, StreamOrDevice s = {})#
    +

    General convolution with a filter.

    +
    + +
    +
    +array conv1d(const array &input, const array &weight, int stride = 1, int padding = 0, int dilation = 1, int groups = 1, StreamOrDevice s = {})#
    +

    1D convolution with a filter

    +
    + +
    +
    +array conv2d(const array &input, const array &weight, const std::pair<int, int> &stride = {1, 1}, const std::pair<int, int> &padding = {0, 0}, const std::pair<int, int> &dilation = {1, 1}, int groups = 1, StreamOrDevice s = {})#
    +

    2D convolution with a filter

    +
    + +
    +
    +array quantized_matmul(const array &x, const array &w, const array &scales, const array &biases, bool transpose = true, int group_size = 64, int bits = 4, StreamOrDevice s = {})#
    +

    Quantized matmul multiplies x with a quantized matrix w.

    +
    + +
    +
    +std::tuple<array, array, array> quantize(const array &w, int group_size = 64, int bits = 4, StreamOrDevice s = {})#
    +

    Quantize a matrix along its last axis.

    +
    + +
    +
    +array dequantize(const array &w, const array &scales, const array &biases, int group_size = 64, int bits = 4, StreamOrDevice s = {})#
    +

    Dequantize a matrix produced by quantize()

    +
    + +
    +
    +array tensordot(const array &a, const array &b, const int axis = 2, StreamOrDevice s = {})#
    +

    Returns a contraction of a and b over multiple dimensions.

    +
    + +
    +
    +array tensordot(const array &a, const array &b, const std::vector<int> &axes_a, const std::vector<int> &axes_b, StreamOrDevice s = {})#
    +
    + +
    +
    +array outer(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Compute the outer product of two vectors.

    +
    + +
    +
    +array inner(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Compute the inner product of two vectors.

    +
    + +
    +
    +array addmm(array c, array a, array b, const float &alpha = 1.f, const float &beta = 1.f, StreamOrDevice s = {})#
    +

    Compute D = beta * C + alpha * (A @ B)

    +
    + +
    +
    +array block_masked_mm(array a, array b, int block_size, std::optional<array> mask_out = std::nullopt, std::optional<array> mask_lhs = std::nullopt, std::optional<array> mask_rhs = std::nullopt, StreamOrDevice s = {})#
    +

    Compute matrix product with block masking.

    +
    + +
    +
    +array block_sparse_mm(array a, array b, std::optional<array> lhs_indices = std::nullopt, std::optional<array> rhs_indices = std::nullopt, StreamOrDevice s = {})#
    +

    Compute matrix product with matrix-level gather.

    +
    + +
    +
    +array diagonal(const array &a, int offset = 0, int axis1 = 0, int axis2 = 1, StreamOrDevice s = {})#
    +

    Extract a diagonal or construct a diagonal array.

    +
    + +
    +
    +array diag(const array &a, int k = 0, StreamOrDevice s = {})#
    +

    Extract diagonal from a 2d array or create a diagonal matrix.

    +
    + +
    +
    +std::vector<array> depends(const std::vector<array> &inputs, const std::vector<array> &dependencies)#
    +

    Implements the identity function but allows injecting dependencies to other arrays.

    +

    This ensures that these other arrays will have been computed when the outputs of this function are computed.

    +
    + +
    +
    +array atleast_1d(const array &a, StreamOrDevice s = {})#
    +

    convert an array to an atleast ndim array

    +
    + +
    +
    +std::vector<array> atleast_1d(const std::vector<array> &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array atleast_2d(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +std::vector<array> atleast_2d(const std::vector<array> &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array atleast_3d(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +std::vector<array> atleast_3d(const std::vector<array> &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array number_of_elements(const array &a, std::vector<int> axes, bool inverted, Dtype dtype = int32, StreamOrDevice s = {})#
    +

    Extract the number of elements along some axes as a scalar array.

    +

    Used to allow shape dependent shapeless compilation (pun intended).

    +
    + +
    +
    +array conjugate(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array bitwise_and(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Bitwise and.

    +
    + +
    +
    +array operator&(const array &a, const array &b)#
    +
    + +
    +
    +array bitwise_or(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Bitwise inclusive or.

    +
    + +
    +
    +array operator|(const array &a, const array &b)#
    +
    + +
    +
    +array bitwise_xor(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Bitwise exclusive or.

    +
    + +
    +
    +array operator^(const array &a, const array &b)#
    +
    + +
    +
    +array left_shift(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Shift bits to the left.

    +
    + +
    +
    +array operator<<(const array &a, const array &b)#
    +
    + +
    +
    +array right_shift(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Shift bits to the right.

    +
    + +
    +
    +array operator>>(const array &a, const array &b)#
    +
    +
    @@ -792,12 +2710,12 @@ document.write(`

    previous

    -

    mlx.utils.tree_map_with_path

    +

    mlx.utils.tree_reduce

    +