@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Abs::output_shapes |
+ std::vector< Shape > mlx::core::Abs::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_add-members.html b/docs/build/html/classmlx_1_1core_1_1_add-members.html
index 2b045f851..ce8ca14c6 100644
--- a/docs/build/html/classmlx_1_1core_1_1_add-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_add-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Add | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Add | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_add.html b/docs/build/html/classmlx_1_1core_1_1_add.html
index 81961d64d..9dd674b5b 100644
--- a/docs/build/html/classmlx_1_1core_1_1_add.html
+++ b/docs/build/html/classmlx_1_1core_1_1_add.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Add::output_shapes |
+ std::vector< Shape > mlx::core::Add::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_add_m_m-members.html b/docs/build/html/classmlx_1_1core_1_1_add_m_m-members.html
index 113fe22cc..cb669de52 100644
--- a/docs/build/html/classmlx_1_1core_1_1_add_m_m-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_add_m_m-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_add_m_m.html b/docs/build/html/classmlx_1_1core_1_1_add_m_m.html
index edf52820c..2a46960cb 100644
--- a/docs/build/html/classmlx_1_1core_1_1_add_m_m.html
+++ b/docs/build/html/classmlx_1_1core_1_1_add_m_m.html
@@ -158,9 +158,9 @@ Public Member Functions
virtual std::vector< array > | jvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) |
| The Jacobian-vector product.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_arange-members.html b/docs/build/html/classmlx_1_1core_1_1_arange-members.html
index c2f24fac4..c6db8da88 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arange-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arange-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Arange | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_arange.html b/docs/build/html/classmlx_1_1core_1_1_arange.html
index cbea55a54..4b20784c5 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arange.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arange.html
@@ -121,6 +121,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -158,9 +161,6 @@ Public Member Functions
virtual std::pair< std::vector< array >, std::vector< int > > | vmap (const std::vector< array > &inputs, const std::vector< int > &axes) |
| The primitive must know how to vectorize itself across the given axes.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
@@ -302,6 +302,36 @@ Public Member Functions
Reimplemented from mlx::core::Primitive.
+
+
+
+◆ output_shapes()
+
+
+
+
+
+
+
+
+ std::vector< Shape > mlx::core::Arange::output_shapes |
+ ( |
+ const std::vector< array > & | inputs | ) |
+ |
+
+
+ |
+
+overridevirtual |
+
+
+
+
+ Get the output shapes of the primitive.
+ This is not required to be implemented by derived classes, in which case it will throw.
+
+ Reimplemented from mlx::core::Primitive.
+
diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_cos-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_cos-members.html
index 2c1322299..b313ef0e8 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arc_cos-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arc_cos-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::ArcCos | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::ArcCos | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_cos.html b/docs/build/html/classmlx_1_1core_1_1_arc_cos.html
index d829fb87d..7f4d285d3 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arc_cos.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arc_cos.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::ArcCos::output_shapes |
+ std::vector< Shape > mlx::core::ArcCos::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_cosh-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_cosh-members.html
index 8f438f92d..1c130de7a 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arc_cosh-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arc_cosh-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::ArcCosh | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::ArcCosh | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_cosh.html b/docs/build/html/classmlx_1_1core_1_1_arc_cosh.html
index 35b069f94..0838fbc92 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arc_cosh.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arc_cosh.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::ArcCosh::output_shapes |
+ std::vector< Shape > mlx::core::ArcCosh::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_sin-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_sin-members.html
index 7459baa07..e8dc0ce13 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arc_sin-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arc_sin-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::ArcSin | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::ArcSin | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_sin.html b/docs/build/html/classmlx_1_1core_1_1_arc_sin.html
index 5940f522b..335eefc4c 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arc_sin.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arc_sin.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::ArcSin::output_shapes |
+ std::vector< Shape > mlx::core::ArcSin::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_sinh-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_sinh-members.html
index 816670a36..42b97c8f1 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arc_sinh-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arc_sinh-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::ArcSinh | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::ArcSinh | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_sinh.html b/docs/build/html/classmlx_1_1core_1_1_arc_sinh.html
index cf5a35ea2..ad83db2b4 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arc_sinh.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arc_sinh.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::ArcSinh::output_shapes |
+ std::vector< Shape > mlx::core::ArcSinh::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tan-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_tan-members.html
index c0796ec85..aa6b5150b 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arc_tan-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arc_tan-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::ArcTan | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::ArcTan | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tan.html b/docs/build/html/classmlx_1_1core_1_1_arc_tan.html
index 537cbc1d5..11ffb3f01 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arc_tan.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arc_tan.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::ArcTan::output_shapes |
+ std::vector< Shape > mlx::core::ArcTan::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tan2-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_tan2-members.html
index bcda7fddd..a6f32e13d 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arc_tan2-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arc_tan2-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::ArcTan2 | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::ArcTan2 | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tan2.html b/docs/build/html/classmlx_1_1core_1_1_arc_tan2.html
index ca7dc2bbb..d02e0b6c4 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arc_tan2.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arc_tan2.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::ArcTan2::output_shapes |
+ std::vector< Shape > mlx::core::ArcTan2::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tanh-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_tanh-members.html
index 08f644f5e..cff6999ed 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arc_tanh-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arc_tanh-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::ArcTanh | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::ArcTanh | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tanh.html b/docs/build/html/classmlx_1_1core_1_1_arc_tanh.html
index 084d4457d..adc5253c2 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arc_tanh.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arc_tanh.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::ArcTanh::output_shapes |
+ std::vector< Shape > mlx::core::ArcTanh::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_partition-members.html b/docs/build/html/classmlx_1_1core_1_1_arg_partition-members.html
index e7d6f328e..e708f7fc3 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arg_partition-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arg_partition-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::ArgPartition | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::ArgPartition | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_partition.html b/docs/build/html/classmlx_1_1core_1_1_arg_partition.html
index e1c91531d..04b2e40d4 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arg_partition.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arg_partition.html
@@ -127,9 +127,9 @@ Public Member Functions
void | print (std::ostream &os) override |
| Print the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
@@ -337,8 +337,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -347,7 +347,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::ArgPartition::output_shapes |
+ std::vector< Shape > mlx::core::ArgPartition::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -363,7 +363,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_reduce-members.html b/docs/build/html/classmlx_1_1core_1_1_arg_reduce-members.html
index 4d1c80263..a86a23fbd 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arg_reduce-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arg_reduce-members.html
@@ -108,7 +108,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::ArgReduce | virtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::ArgReduce | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_reduce.html b/docs/build/html/classmlx_1_1core_1_1_arg_reduce.html
index 2175da61e..d4115d622 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arg_reduce.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arg_reduce.html
@@ -138,9 +138,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -364,8 +364,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -374,7 +374,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::ArgReduce::output_shapes |
+ std::vector< Shape > mlx::core::ArgReduce::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -390,7 +390,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_sort-members.html b/docs/build/html/classmlx_1_1core_1_1_arg_sort-members.html
index 94f49dfe6..2eec3e178 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arg_sort-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arg_sort-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::ArgSort | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::ArgSort | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_sort.html b/docs/build/html/classmlx_1_1core_1_1_arg_sort.html
index e7ebb6d83..66b002033 100644
--- a/docs/build/html/classmlx_1_1core_1_1_arg_sort.html
+++ b/docs/build/html/classmlx_1_1core_1_1_arg_sort.html
@@ -121,9 +121,9 @@ Public Member Functions
void | print (std::ostream &os) override |
| Print the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
@@ -294,8 +294,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -304,7 +304,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::ArgSort::output_shapes |
+ std::vector< Shape > mlx::core::ArgSort::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -320,7 +320,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_as_strided-members.html b/docs/build/html/classmlx_1_1core_1_1_as_strided-members.html
index a7908d156..71ddfc428 100644
--- a/docs/build/html/classmlx_1_1core_1_1_as_strided-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_as_strided-members.html
@@ -94,7 +94,7 @@ $(function(){ initResizable(false); });
This is the complete list of members for mlx::core::AsStrided, including all inherited members.
- AsStrided(Stream stream, std::vector< int > shape, std::vector< size_t > strides, size_t offset) | mlx::core::AsStrided | inlineexplicit |
+ AsStrided(Stream stream, Shape shape, Strides strides, size_t offset) | mlx::core::AsStrided | inlineexplicit |
device() | mlx::core::Primitive | inline |
eval_cpu(const std::vector< array > &inputs, array &out) override | mlx::core::AsStrided | virtual |
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override | mlx::core::UnaryPrimitive | inlinevirtual |
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_as_strided.html b/docs/build/html/classmlx_1_1core_1_1_as_strided.html
index eafaa0fbd..d8e0f8306 100644
--- a/docs/build/html/classmlx_1_1core_1_1_as_strided.html
+++ b/docs/build/html/classmlx_1_1core_1_1_as_strided.html
@@ -109,8 +109,8 @@ Inheritance diagram for mlx::core::AsStrided:
|
- | AsStrided (Stream stream, std::vector< int > shape, std::vector< size_t > strides, size_t offset) |
- |
+ | AsStrided (Stream stream, Shape shape, Strides strides, size_t offset) |
+ |
void | eval_cpu (const std::vector< array > &inputs, array &out) override |
|
void | eval_gpu (const std::vector< array > &inputs, array &out) override |
@@ -158,9 +158,9 @@ Public Member Functions
virtual std::pair< std::vector< array >, std::vector< int > > | vmap (const std::vector< array > &inputs, const std::vector< int > &axes) |
| The primitive must know how to vectorize itself across the given axes.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
@@ -173,8 +173,8 @@ Public Member Functions
|
-
-◆ AsStrided()
+
+◆ AsStrided()
@@ -190,12 +190,12 @@ Public Member Functions
|
|
- std::vector< int > | shape, |
+ Shape | shape, |
|
|
- std::vector< size_t > | strides, |
+ Strides | strides, |
|
diff --git a/docs/build/html/classmlx_1_1core_1_1_as_type-members.html b/docs/build/html/classmlx_1_1core_1_1_as_type-members.html
index 428f204c9..037024891 100644
--- a/docs/build/html/classmlx_1_1core_1_1_as_type-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_as_type-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::AsType | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::AsType | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_as_type.html b/docs/build/html/classmlx_1_1core_1_1_as_type.html
index ac1d921cb..43af8c5ac 100644
--- a/docs/build/html/classmlx_1_1core_1_1_as_type.html
+++ b/docs/build/html/classmlx_1_1core_1_1_as_type.html
@@ -127,9 +127,9 @@ Public Member Functions
void | print (std::ostream &os) override |
| Print the primitive.
|
|
- std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+ std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
@@ -332,8 +332,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -342,7 +342,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::AsType::output_shapes |
+ std::vector< Shape > mlx::core::AsType::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -358,7 +358,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_bitwise_binary-members.html b/docs/build/html/classmlx_1_1core_1_1_bitwise_binary-members.html
index 34ef1905f..ff5e993be 100644
--- a/docs/build/html/classmlx_1_1core_1_1_bitwise_binary-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_bitwise_binary-members.html
@@ -110,7 +110,7 @@ $(function(){ initResizable(false); });
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
Or enum value | mlx::core::BitwiseBinary | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::BitwiseBinary | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::BitwiseBinary | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_bitwise_binary.html b/docs/build/html/classmlx_1_1core_1_1_bitwise_binary.html
index acec87e7a..829eb977b 100644
--- a/docs/build/html/classmlx_1_1core_1_1_bitwise_binary.html
+++ b/docs/build/html/classmlx_1_1core_1_1_bitwise_binary.html
@@ -144,9 +144,9 @@ Public Member Functions
void | print (std::ostream &os) override |
| Print the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -368,8 +368,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -378,7 +378,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::BitwiseBinary::output_shapes |
+ std::vector< Shape > mlx::core::BitwiseBinary::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -394,7 +394,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m-members.html b/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m-members.html
index 20ede52ba..3c4e236df 100644
--- a/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.html b/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.html
index 76732bae4..be2921ad0 100644
--- a/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.html
+++ b/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.html
@@ -158,9 +158,9 @@ Public Member Functions
virtual std::pair< std::vector< array >, std::vector< int > > | vmap (const std::vector< array > &inputs, const std::vector< int > &axes) |
| The primitive must know how to vectorize itself across the given axes.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_broadcast-members.html b/docs/build/html/classmlx_1_1core_1_1_broadcast-members.html
index 3fbf3541a..2abdd8ef9 100644
--- a/docs/build/html/classmlx_1_1core_1_1_broadcast-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_broadcast-members.html
@@ -94,7 +94,7 @@ $(function(){ initResizable(false); });
This is the complete list of members for mlx::core::Broadcast, including all inherited members.
- Broadcast(Stream stream, const std::vector< int > &shape) | mlx::core::Broadcast | inlineexplicit |
+ Broadcast(Stream stream, const Shape &shape) | mlx::core::Broadcast | inlineexplicit |
device() | mlx::core::Primitive | inline |
eval_cpu(const std::vector< array > &inputs, array &out) override | mlx::core::Broadcast | virtual |
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override | mlx::core::UnaryPrimitive | inlinevirtual |
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_broadcast.html b/docs/build/html/classmlx_1_1core_1_1_broadcast.html
index d9756bff4..76dcf5224 100644
--- a/docs/build/html/classmlx_1_1core_1_1_broadcast.html
+++ b/docs/build/html/classmlx_1_1core_1_1_broadcast.html
@@ -109,8 +109,8 @@ Inheritance diagram for mlx::core::Broadcast:
-
-◆ Broadcast()
+
+◆ Broadcast()
@@ -190,7 +190,7 @@ Public Member Functions
|
|
- const std::vector< int > & | shape ) |
+ const Shape & | shape ) |
diff --git a/docs/build/html/classmlx_1_1core_1_1_ceil-members.html b/docs/build/html/classmlx_1_1core_1_1_ceil-members.html
index 174b5f7f0..1081c0aa0 100644
--- a/docs/build/html/classmlx_1_1core_1_1_ceil-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_ceil-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Ceil | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Ceil | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_ceil.html b/docs/build/html/classmlx_1_1core_1_1_ceil.html
index dbada590b..8b5cfdb35 100644
--- a/docs/build/html/classmlx_1_1core_1_1_ceil.html
+++ b/docs/build/html/classmlx_1_1core_1_1_ceil.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Ceil::output_shapes |
+ std::vector< Shape > mlx::core::Ceil::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_cholesky-members.html b/docs/build/html/classmlx_1_1core_1_1_cholesky-members.html
index 9125afcce..21d347f76 100644
--- a/docs/build/html/classmlx_1_1core_1_1_cholesky-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_cholesky-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_cholesky.html b/docs/build/html/classmlx_1_1core_1_1_cholesky.html
index 9315c92a2..aa973813e 100644
--- a/docs/build/html/classmlx_1_1core_1_1_cholesky.html
+++ b/docs/build/html/classmlx_1_1core_1_1_cholesky.html
@@ -158,9 +158,9 @@ Public Member Functions
virtual bool | is_equivalent (const Primitive &other) const |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_compiled-members.html b/docs/build/html/classmlx_1_1core_1_1_compiled-members.html
index 9386ad7e0..4e1d6330e 100644
--- a/docs/build/html/classmlx_1_1core_1_1_compiled-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_compiled-members.html
@@ -103,7 +103,7 @@ $(function(){ initResizable(false); });
lib_name() const | mlx::core::Compiled | inline |
operator=(const Primitive &other)=delete | mlx::core::Primitive | |
operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Compiled | virtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Compiled | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_compiled.html b/docs/build/html/classmlx_1_1core_1_1_compiled.html
index be9d80ce4..36ee259c0 100644
--- a/docs/build/html/classmlx_1_1core_1_1_compiled.html
+++ b/docs/build/html/classmlx_1_1core_1_1_compiled.html
@@ -124,9 +124,9 @@ Public Member Functions
std::vector< array > | vjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override |
| The vector-Jacobian product.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
void | print (std::ostream &os) override |
| Print the primitive.
|
|
@@ -358,8 +358,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -368,7 +368,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Compiled::output_shapes |
+ std::vector< Shape > mlx::core::Compiled::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -384,7 +384,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_concatenate-members.html b/docs/build/html/classmlx_1_1core_1_1_concatenate-members.html
index 5ab1f0673..3a7e68413 100644
--- a/docs/build/html/classmlx_1_1core_1_1_concatenate-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_concatenate-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Concatenate | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_concatenate.html b/docs/build/html/classmlx_1_1core_1_1_concatenate.html
index a16807eb7..c919179f3 100644
--- a/docs/build/html/classmlx_1_1core_1_1_concatenate.html
+++ b/docs/build/html/classmlx_1_1core_1_1_concatenate.html
@@ -130,6 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -158,9 +161,6 @@ Public Member Functions
const Stream & | stream () |
| The stream the primitive will run on.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
@@ -330,6 +330,36 @@ Public Member Functions
Reimplemented from mlx::core::Primitive.
+
+
+
+◆ output_shapes()
+
+
+
+
+
+
+
+
+ std::vector< Shape > mlx::core::Concatenate::output_shapes |
+ ( |
+ const std::vector< array > & | inputs | ) |
+ |
+
+
+ |
+
+overridevirtual |
+
+
+
+
+ Get the output shapes of the primitive.
+ This is not required to be implemented by derived classes, in which case it will throw.
+
+ Reimplemented from mlx::core::Primitive.
+
diff --git a/docs/build/html/classmlx_1_1core_1_1_conjugate-members.html b/docs/build/html/classmlx_1_1core_1_1_conjugate-members.html
index 577733e4f..f7924d36f 100644
--- a/docs/build/html/classmlx_1_1core_1_1_conjugate-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_conjugate-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Conjugate | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Conjugate | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_conjugate.html b/docs/build/html/classmlx_1_1core_1_1_conjugate.html
index b4f43cb4f..2546f35eb 100644
--- a/docs/build/html/classmlx_1_1core_1_1_conjugate.html
+++ b/docs/build/html/classmlx_1_1core_1_1_conjugate.html
@@ -124,9 +124,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -290,8 +290,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -300,7 +300,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Conjugate::output_shapes |
+ std::vector< Shape > mlx::core::Conjugate::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -316,7 +316,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_contiguous-members.html b/docs/build/html/classmlx_1_1core_1_1_contiguous-members.html
index 49a90d86e..060223ede 100644
--- a/docs/build/html/classmlx_1_1core_1_1_contiguous-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_contiguous-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Contiguous | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Contiguous | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_contiguous.html b/docs/build/html/classmlx_1_1core_1_1_contiguous.html
index 98a3d0d20..cbbab66f1 100644
--- a/docs/build/html/classmlx_1_1core_1_1_contiguous.html
+++ b/docs/build/html/classmlx_1_1core_1_1_contiguous.html
@@ -127,9 +127,9 @@ Public Member Functions
void | print (std::ostream &os) override |
| Print the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
@@ -332,8 +332,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -342,7 +342,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Contiguous::output_shapes |
+ std::vector< Shape > mlx::core::Contiguous::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -358,7 +358,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_convolution-members.html b/docs/build/html/classmlx_1_1core_1_1_convolution-members.html
index 8780055f3..9d7d28745 100644
--- a/docs/build/html/classmlx_1_1core_1_1_convolution-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_convolution-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_convolution.html b/docs/build/html/classmlx_1_1core_1_1_convolution.html
index 81d1be71a..72fd0bdfd 100644
--- a/docs/build/html/classmlx_1_1core_1_1_convolution.html
+++ b/docs/build/html/classmlx_1_1core_1_1_convolution.html
@@ -158,9 +158,9 @@ Public Member Functions
virtual std::pair< std::vector< array >, std::vector< int > > | vmap (const std::vector< array > &inputs, const std::vector< int > &axes) |
| The primitive must know how to vectorize itself across the given axes.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_copy-members.html b/docs/build/html/classmlx_1_1core_1_1_copy-members.html
index 1606aa066..f497b5bb8 100644
--- a/docs/build/html/classmlx_1_1core_1_1_copy-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_copy-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Copy | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Copy | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_copy.html b/docs/build/html/classmlx_1_1core_1_1_copy.html
index 4041537bb..2ae401c26 100644
--- a/docs/build/html/classmlx_1_1core_1_1_copy.html
+++ b/docs/build/html/classmlx_1_1core_1_1_copy.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Copy::output_shapes |
+ std::vector< Shape > mlx::core::Copy::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_cos-members.html b/docs/build/html/classmlx_1_1core_1_1_cos-members.html
index 88b752b14..b6cee517e 100644
--- a/docs/build/html/classmlx_1_1core_1_1_cos-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_cos-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Cos | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Cos | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_cos.html b/docs/build/html/classmlx_1_1core_1_1_cos.html
index 58e722c8d..f19e939c4 100644
--- a/docs/build/html/classmlx_1_1core_1_1_cos.html
+++ b/docs/build/html/classmlx_1_1core_1_1_cos.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Cos::output_shapes |
+ std::vector< Shape > mlx::core::Cos::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_cosh-members.html b/docs/build/html/classmlx_1_1core_1_1_cosh-members.html
index 16260ef99..42a8079a2 100644
--- a/docs/build/html/classmlx_1_1core_1_1_cosh-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_cosh-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Cosh | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Cosh | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_cosh.html b/docs/build/html/classmlx_1_1core_1_1_cosh.html
index e02141c43..a7f2caccb 100644
--- a/docs/build/html/classmlx_1_1core_1_1_cosh.html
+++ b/docs/build/html/classmlx_1_1core_1_1_cosh.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Cosh::output_shapes |
+ std::vector< Shape > mlx::core::Cosh::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_custom_transforms-members.html b/docs/build/html/classmlx_1_1core_1_1_custom_transforms-members.html
index 360e1eae2..219d833d5 100644
--- a/docs/build/html/classmlx_1_1core_1_1_custom_transforms-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_custom_transforms-members.html
@@ -102,7 +102,7 @@ $(function(){ initResizable(false); });
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override | mlx::core::CustomTransforms | virtual |
operator=(const Primitive &other)=delete | mlx::core::Primitive | |
operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_custom_transforms.html b/docs/build/html/classmlx_1_1core_1_1_custom_transforms.html
index a1d55bd8e..34b499238 100644
--- a/docs/build/html/classmlx_1_1core_1_1_custom_transforms.html
+++ b/docs/build/html/classmlx_1_1core_1_1_custom_transforms.html
@@ -139,9 +139,9 @@ Public Member Functions
virtual bool | is_equivalent (const Primitive &other) const |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_depends-members.html b/docs/build/html/classmlx_1_1core_1_1_depends-members.html
index f74ee3a0d..0759ea3c4 100644
--- a/docs/build/html/classmlx_1_1core_1_1_depends-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_depends-members.html
@@ -102,7 +102,7 @@ $(function(){ initResizable(false); });
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) | mlx::core::Primitive | virtual |
operator=(const Primitive &other)=delete | mlx::core::Primitive | |
operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_depends.html b/docs/build/html/classmlx_1_1core_1_1_depends.html
index dcd4ddd18..d509451cd 100644
--- a/docs/build/html/classmlx_1_1core_1_1_depends.html
+++ b/docs/build/html/classmlx_1_1core_1_1_depends.html
@@ -139,9 +139,9 @@ Public Member Functions
virtual bool | is_equivalent (const Primitive &other) const |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_div_mod-members.html b/docs/build/html/classmlx_1_1core_1_1_div_mod-members.html
index 8b8e2684d..f1df14e3d 100644
--- a/docs/build/html/classmlx_1_1core_1_1_div_mod-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_div_mod-members.html
@@ -102,7 +102,7 @@ $(function(){ initResizable(false); });
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override | mlx::core::DivMod | virtual |
operator=(const Primitive &other)=delete | mlx::core::Primitive | |
operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::DivMod | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::DivMod | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_div_mod.html b/docs/build/html/classmlx_1_1core_1_1_div_mod.html
index fbe8e3427..d92944b0b 100644
--- a/docs/build/html/classmlx_1_1core_1_1_div_mod.html
+++ b/docs/build/html/classmlx_1_1core_1_1_div_mod.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| Primitive (Stream stream) |
|
@@ -312,8 +312,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -322,7 +322,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::DivMod::output_shapes |
+ std::vector< Shape > mlx::core::DivMod::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -338,7 +338,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_divide-members.html b/docs/build/html/classmlx_1_1core_1_1_divide-members.html
index 45d34065f..214a0cd11 100644
--- a/docs/build/html/classmlx_1_1core_1_1_divide-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_divide-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Divide | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Divide | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_divide.html b/docs/build/html/classmlx_1_1core_1_1_divide.html
index 070127575..b7ee8ee63 100644
--- a/docs/build/html/classmlx_1_1core_1_1_divide.html
+++ b/docs/build/html/classmlx_1_1core_1_1_divide.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Divide::output_shapes |
+ std::vector< Shape > mlx::core::Divide::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_eigh-members.html b/docs/build/html/classmlx_1_1core_1_1_eigh-members.html
index ecad9877f..c7768d819 100644
--- a/docs/build/html/classmlx_1_1core_1_1_eigh-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_eigh-members.html
@@ -98,11 +98,11 @@ $(function(){ initResizable(false); });
Eigh(Stream stream, std::string uplo, bool compute_eigenvectors) | mlx::core::Eigh | inlineexplicit |
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override | mlx::core::Eigh | virtual |
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override | mlx::core::Eigh | virtual |
- is_equivalent(const Primitive &other) const override | mlx::core::Eigh | inlinevirtual |
+ is_equivalent(const Primitive &other) const override | mlx::core::Eigh | virtual |
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) | mlx::core::Primitive | virtual |
operator=(const Primitive &other)=delete | mlx::core::Primitive | |
operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Eigh | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Eigh | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_eigh.html b/docs/build/html/classmlx_1_1core_1_1_eigh.html
index 4fbbd9537..2bd389b74 100644
--- a/docs/build/html/classmlx_1_1core_1_1_eigh.html
+++ b/docs/build/html/classmlx_1_1core_1_1_eigh.html
@@ -121,9 +121,9 @@ Public Member Functions
void | print (std::ostream &os) override |
| Print the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
@@ -272,7 +272,7 @@ Public Member Functions
|
-inlineoverridevirtual |
+ overridevirtual
@@ -283,8 +283,8 @@ Public Member Functions |
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -293,7 +293,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Eigh::output_shapes |
+ std::vector< Shape > mlx::core::Eigh::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -301,7 +301,7 @@ Public Member Functions
|
-inlineoverridevirtual |
+ overridevirtual
@@ -309,7 +309,7 @@ Public Member Functions |
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
- Reimplemented from mlx::core::Primitive.
+ Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_equal-members.html b/docs/build/html/classmlx_1_1core_1_1_equal-members.html
index 271bf1c3c..16ccc7c09 100644
--- a/docs/build/html/classmlx_1_1core_1_1_equal-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_equal-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Equal | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Equal | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_equal.html b/docs/build/html/classmlx_1_1core_1_1_equal.html
index 3d6c7990e..ab1164fc0 100644
--- a/docs/build/html/classmlx_1_1core_1_1_equal.html
+++ b/docs/build/html/classmlx_1_1core_1_1_equal.html
@@ -127,9 +127,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
void | print (std::ostream &os) override |
| Print the primitive.
|
|
@@ -332,8 +332,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -342,7 +342,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Equal::output_shapes |
+ std::vector< Shape > mlx::core::Equal::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -358,7 +358,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_erf-members.html b/docs/build/html/classmlx_1_1core_1_1_erf-members.html
index 923e8e41d..2913b6186 100644
--- a/docs/build/html/classmlx_1_1core_1_1_erf-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_erf-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Erf | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Erf | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_erf.html b/docs/build/html/classmlx_1_1core_1_1_erf.html
index 3c41200a9..95b797e6b 100644
--- a/docs/build/html/classmlx_1_1core_1_1_erf.html
+++ b/docs/build/html/classmlx_1_1core_1_1_erf.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Erf::output_shapes |
+ std::vector< Shape > mlx::core::Erf::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_erf_inv-members.html b/docs/build/html/classmlx_1_1core_1_1_erf_inv-members.html
index 60913ee17..96fa3aa9b 100644
--- a/docs/build/html/classmlx_1_1core_1_1_erf_inv-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_erf_inv-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::ErfInv | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::ErfInv | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_erf_inv.html b/docs/build/html/classmlx_1_1core_1_1_erf_inv.html
index b5d0b14f9..dca528218 100644
--- a/docs/build/html/classmlx_1_1core_1_1_erf_inv.html
+++ b/docs/build/html/classmlx_1_1core_1_1_erf_inv.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::ErfInv::output_shapes |
+ std::vector< Shape > mlx::core::ErfInv::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_exp-members.html b/docs/build/html/classmlx_1_1core_1_1_exp-members.html
index 8c0a4d8c2..de08e5ea1 100644
--- a/docs/build/html/classmlx_1_1core_1_1_exp-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_exp-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Exp | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Exp | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_exp.html b/docs/build/html/classmlx_1_1core_1_1_exp.html
index d52e9d693..a9ea24023 100644
--- a/docs/build/html/classmlx_1_1core_1_1_exp.html
+++ b/docs/build/html/classmlx_1_1core_1_1_exp.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Exp::output_shapes |
+ std::vector< Shape > mlx::core::Exp::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_expm1-members.html b/docs/build/html/classmlx_1_1core_1_1_expm1-members.html
index c1093fe36..254f446a0 100644
--- a/docs/build/html/classmlx_1_1core_1_1_expm1-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_expm1-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Expm1 | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Expm1 | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_expm1.html b/docs/build/html/classmlx_1_1core_1_1_expm1.html
index ff23e46da..1512dbe27 100644
--- a/docs/build/html/classmlx_1_1core_1_1_expm1.html
+++ b/docs/build/html/classmlx_1_1core_1_1_expm1.html
@@ -127,9 +127,9 @@ Public Member Functions
void | print (std::ostream &os) override |
| Print the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -299,8 +299,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -309,7 +309,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Expm1::output_shapes |
+ std::vector< Shape > mlx::core::Expm1::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -325,7 +325,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_f_f_t-members.html b/docs/build/html/classmlx_1_1core_1_1_f_f_t-members.html
index 07d53ba7b..15d37e8d7 100644
--- a/docs/build/html/classmlx_1_1core_1_1_f_f_t-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_f_f_t-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_f_f_t.html b/docs/build/html/classmlx_1_1core_1_1_f_f_t.html
index 15cbb62c2..2e9f2db09 100644
--- a/docs/build/html/classmlx_1_1core_1_1_f_f_t.html
+++ b/docs/build/html/classmlx_1_1core_1_1_f_f_t.html
@@ -158,9 +158,9 @@ Public Member Functions
const Stream & | stream () |
| The stream the primitive will run on.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_floor-members.html b/docs/build/html/classmlx_1_1core_1_1_floor-members.html
index 1c340e0d8..597575b8b 100644
--- a/docs/build/html/classmlx_1_1core_1_1_floor-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_floor-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Floor | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Floor | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_floor.html b/docs/build/html/classmlx_1_1core_1_1_floor.html
index 7a0f3d3e4..158b57a6e 100644
--- a/docs/build/html/classmlx_1_1core_1_1_floor.html
+++ b/docs/build/html/classmlx_1_1core_1_1_floor.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Floor::output_shapes |
+ std::vector< Shape > mlx::core::Floor::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_full-members.html b/docs/build/html/classmlx_1_1core_1_1_full-members.html
index 2613f2546..2135a5e95 100644
--- a/docs/build/html/classmlx_1_1core_1_1_full-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_full-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_full.html b/docs/build/html/classmlx_1_1core_1_1_full.html
index ca2f7d7dd..3b516acaf 100644
--- a/docs/build/html/classmlx_1_1core_1_1_full.html
+++ b/docs/build/html/classmlx_1_1core_1_1_full.html
@@ -158,9 +158,9 @@ Public Member Functions
const Stream & | stream () |
| The stream the primitive will run on.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_gather-members.html b/docs/build/html/classmlx_1_1core_1_1_gather-members.html
index 85513cc97..e6733eee4 100644
--- a/docs/build/html/classmlx_1_1core_1_1_gather-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_gather-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Gather | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_gather.html b/docs/build/html/classmlx_1_1core_1_1_gather.html
index 503359e7b..d5cb92af5 100644
--- a/docs/build/html/classmlx_1_1core_1_1_gather.html
+++ b/docs/build/html/classmlx_1_1core_1_1_gather.html
@@ -130,6 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -158,9 +161,6 @@ Public Member Functions
const Stream & | stream () |
| The stream the primitive will run on.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
@@ -335,6 +335,36 @@ Public Member Functions
Reimplemented from mlx::core::Primitive.
+
+
+
+◆ output_shapes()
+
+
+
+
+
+
+
+
+ std::vector< Shape > mlx::core::Gather::output_shapes |
+ ( |
+ const std::vector< array > & | inputs | ) |
+ |
+
+
+ |
+
+overridevirtual |
+
+
+
+
+ Get the output shapes of the primitive.
+ This is not required to be implemented by derived classes, in which case it will throw.
+
+ Reimplemented from mlx::core::Primitive.
+
diff --git a/docs/build/html/classmlx_1_1core_1_1_gather_m_m-members.html b/docs/build/html/classmlx_1_1core_1_1_gather_m_m-members.html
index c33e774af..2b8f921d5 100644
--- a/docs/build/html/classmlx_1_1core_1_1_gather_m_m-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_gather_m_m-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_gather_m_m.html b/docs/build/html/classmlx_1_1core_1_1_gather_m_m.html
index bc46d8c13..98227cd62 100644
--- a/docs/build/html/classmlx_1_1core_1_1_gather_m_m.html
+++ b/docs/build/html/classmlx_1_1core_1_1_gather_m_m.html
@@ -158,9 +158,9 @@ Public Member Functions
virtual std::pair< std::vector< array >, std::vector< int > > | vmap (const std::vector< array > &inputs, const std::vector< int > &axes) |
| The primitive must know how to vectorize itself across the given axes.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_gather_q_m_m-members.html b/docs/build/html/classmlx_1_1core_1_1_gather_q_m_m-members.html
index 0d7fe677d..e14dadcd3 100644
--- a/docs/build/html/classmlx_1_1core_1_1_gather_q_m_m-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_gather_q_m_m-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_gather_q_m_m.html b/docs/build/html/classmlx_1_1core_1_1_gather_q_m_m.html
index c278e2aa5..f4ce14c36 100644
--- a/docs/build/html/classmlx_1_1core_1_1_gather_q_m_m.html
+++ b/docs/build/html/classmlx_1_1core_1_1_gather_q_m_m.html
@@ -158,9 +158,9 @@ Public Member Functions
const Stream & | stream () |
| The stream the primitive will run on.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_greater-members.html b/docs/build/html/classmlx_1_1core_1_1_greater-members.html
index 7aeb96b02..393efc68c 100644
--- a/docs/build/html/classmlx_1_1core_1_1_greater-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_greater-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Greater | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Greater | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_greater.html b/docs/build/html/classmlx_1_1core_1_1_greater.html
index 53c15ff9d..416043993 100644
--- a/docs/build/html/classmlx_1_1core_1_1_greater.html
+++ b/docs/build/html/classmlx_1_1core_1_1_greater.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Greater::output_shapes |
+ std::vector< Shape > mlx::core::Greater::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_greater_equal-members.html b/docs/build/html/classmlx_1_1core_1_1_greater_equal-members.html
index 37e2e68dc..7a53eeac2 100644
--- a/docs/build/html/classmlx_1_1core_1_1_greater_equal-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_greater_equal-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::GreaterEqual | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::GreaterEqual | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_greater_equal.html b/docs/build/html/classmlx_1_1core_1_1_greater_equal.html
index 48e105066..91a285a34 100644
--- a/docs/build/html/classmlx_1_1core_1_1_greater_equal.html
+++ b/docs/build/html/classmlx_1_1core_1_1_greater_equal.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::GreaterEqual::output_shapes |
+ std::vector< Shape > mlx::core::GreaterEqual::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_hadamard-members.html b/docs/build/html/classmlx_1_1core_1_1_hadamard-members.html
index 9e9152c3e..7ea5ac4a9 100644
--- a/docs/build/html/classmlx_1_1core_1_1_hadamard-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_hadamard-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Hadamard | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Hadamard | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_hadamard.html b/docs/build/html/classmlx_1_1core_1_1_hadamard.html
index c3c7d602c..862612c8d 100644
--- a/docs/build/html/classmlx_1_1core_1_1_hadamard.html
+++ b/docs/build/html/classmlx_1_1core_1_1_hadamard.html
@@ -127,9 +127,9 @@ Public Member Functions
void | print (std::ostream &os) override |
| Print the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
@@ -332,8 +332,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -342,7 +342,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Hadamard::output_shapes |
+ std::vector< Shape > mlx::core::Hadamard::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -358,7 +358,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_imag-members.html b/docs/build/html/classmlx_1_1core_1_1_imag-members.html
index b03925f76..faa681b6b 100644
--- a/docs/build/html/classmlx_1_1core_1_1_imag-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_imag-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Imag | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Imag | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_imag.html b/docs/build/html/classmlx_1_1core_1_1_imag.html
index 4b9c3bd78..cece5f6fa 100644
--- a/docs/build/html/classmlx_1_1core_1_1_imag.html
+++ b/docs/build/html/classmlx_1_1core_1_1_imag.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Imag::output_shapes |
+ std::vector< Shape > mlx::core::Imag::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_inverse-members.html b/docs/build/html/classmlx_1_1core_1_1_inverse-members.html
index 38102ac47..95a0d738b 100644
--- a/docs/build/html/classmlx_1_1core_1_1_inverse-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_inverse-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_inverse.html b/docs/build/html/classmlx_1_1core_1_1_inverse.html
index e7e514433..e1653bf32 100644
--- a/docs/build/html/classmlx_1_1core_1_1_inverse.html
+++ b/docs/build/html/classmlx_1_1core_1_1_inverse.html
@@ -158,9 +158,9 @@ Public Member Functions
virtual bool | is_equivalent (const Primitive &other) const |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_less-members.html b/docs/build/html/classmlx_1_1core_1_1_less-members.html
index 10222f842..04fcd75f4 100644
--- a/docs/build/html/classmlx_1_1core_1_1_less-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_less-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Less | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Less | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_less.html b/docs/build/html/classmlx_1_1core_1_1_less.html
index e8dea1bca..eed868815 100644
--- a/docs/build/html/classmlx_1_1core_1_1_less.html
+++ b/docs/build/html/classmlx_1_1core_1_1_less.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Less::output_shapes |
+ std::vector< Shape > mlx::core::Less::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_less_equal-members.html b/docs/build/html/classmlx_1_1core_1_1_less_equal-members.html
index 166da95ed..819db5b3e 100644
--- a/docs/build/html/classmlx_1_1core_1_1_less_equal-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_less_equal-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::LessEqual | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::LessEqual | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_less_equal.html b/docs/build/html/classmlx_1_1core_1_1_less_equal.html
index 4cab4e85b..4db799314 100644
--- a/docs/build/html/classmlx_1_1core_1_1_less_equal.html
+++ b/docs/build/html/classmlx_1_1core_1_1_less_equal.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::LessEqual::output_shapes |
+ std::vector< Shape > mlx::core::LessEqual::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_load-members.html b/docs/build/html/classmlx_1_1core_1_1_load-members.html
index 2c3523da0..6732a8c53 100644
--- a/docs/build/html/classmlx_1_1core_1_1_load-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_load-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_load.html b/docs/build/html/classmlx_1_1core_1_1_load.html
index ba626bff2..0f7acf0f1 100644
--- a/docs/build/html/classmlx_1_1core_1_1_load.html
+++ b/docs/build/html/classmlx_1_1core_1_1_load.html
@@ -158,9 +158,9 @@ Public Member Functions
virtual bool | is_equivalent (const Primitive &other) const |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_log-members.html b/docs/build/html/classmlx_1_1core_1_1_log-members.html
index a33c29be1..cf2377ed5 100644
--- a/docs/build/html/classmlx_1_1core_1_1_log-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_log-members.html
@@ -108,7 +108,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Log | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Log | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_log.html b/docs/build/html/classmlx_1_1core_1_1_log.html
index ede6be216..8363184f5 100644
--- a/docs/build/html/classmlx_1_1core_1_1_log.html
+++ b/docs/build/html/classmlx_1_1core_1_1_log.html
@@ -136,9 +136,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
void | print (std::ostream &os) override |
| Print the primitive.
|
|
@@ -361,8 +361,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -371,7 +371,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Log::output_shapes |
+ std::vector< Shape > mlx::core::Log::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -387,7 +387,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_log1p-members.html b/docs/build/html/classmlx_1_1core_1_1_log1p-members.html
index 01f0ee214..a24e6f8f9 100644
--- a/docs/build/html/classmlx_1_1core_1_1_log1p-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_log1p-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Log1p | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Log1p | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_log1p.html b/docs/build/html/classmlx_1_1core_1_1_log1p.html
index d9933c906..5421330ff 100644
--- a/docs/build/html/classmlx_1_1core_1_1_log1p.html
+++ b/docs/build/html/classmlx_1_1core_1_1_log1p.html
@@ -127,9 +127,9 @@ Public Member Functions
void | print (std::ostream &os) override |
| Print the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -299,8 +299,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -309,7 +309,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Log1p::output_shapes |
+ std::vector< Shape > mlx::core::Log1p::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -325,7 +325,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_log_add_exp-members.html b/docs/build/html/classmlx_1_1core_1_1_log_add_exp-members.html
index 78c33876a..33e2637bc 100644
--- a/docs/build/html/classmlx_1_1core_1_1_log_add_exp-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_log_add_exp-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::LogAddExp | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::LogAddExp | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_log_add_exp.html b/docs/build/html/classmlx_1_1core_1_1_log_add_exp.html
index 383af1b6d..b1e92fbf4 100644
--- a/docs/build/html/classmlx_1_1core_1_1_log_add_exp.html
+++ b/docs/build/html/classmlx_1_1core_1_1_log_add_exp.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::LogAddExp::output_shapes |
+ std::vector< Shape > mlx::core::LogAddExp::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_and-members.html b/docs/build/html/classmlx_1_1core_1_1_logical_and-members.html
index 777a2db7d..dece0f0b4 100644
--- a/docs/build/html/classmlx_1_1core_1_1_logical_and-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_logical_and-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::LogicalAnd | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::LogicalAnd | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_and.html b/docs/build/html/classmlx_1_1core_1_1_logical_and.html
index 4023bef47..b7878e49c 100644
--- a/docs/build/html/classmlx_1_1core_1_1_logical_and.html
+++ b/docs/build/html/classmlx_1_1core_1_1_logical_and.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::LogicalAnd::output_shapes |
+ std::vector< Shape > mlx::core::LogicalAnd::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_not-members.html b/docs/build/html/classmlx_1_1core_1_1_logical_not-members.html
index 87c9c7c15..88a5f9b08 100644
--- a/docs/build/html/classmlx_1_1core_1_1_logical_not-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_logical_not-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::LogicalNot | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::LogicalNot | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_not.html b/docs/build/html/classmlx_1_1core_1_1_logical_not.html
index 5c2465743..2390bf350 100644
--- a/docs/build/html/classmlx_1_1core_1_1_logical_not.html
+++ b/docs/build/html/classmlx_1_1core_1_1_logical_not.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::LogicalNot::output_shapes |
+ std::vector< Shape > mlx::core::LogicalNot::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_or-members.html b/docs/build/html/classmlx_1_1core_1_1_logical_or-members.html
index f5042d369..6d808d6e3 100644
--- a/docs/build/html/classmlx_1_1core_1_1_logical_or-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_logical_or-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::LogicalOr | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::LogicalOr | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_or.html b/docs/build/html/classmlx_1_1core_1_1_logical_or.html
index 760504e4a..c931a5265 100644
--- a/docs/build/html/classmlx_1_1core_1_1_logical_or.html
+++ b/docs/build/html/classmlx_1_1core_1_1_logical_or.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::LogicalOr::output_shapes |
+ std::vector< Shape > mlx::core::LogicalOr::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_matmul-members.html b/docs/build/html/classmlx_1_1core_1_1_matmul-members.html
index 28431d211..3fccc14b4 100644
--- a/docs/build/html/classmlx_1_1core_1_1_matmul-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_matmul-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Matmul | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_matmul.html b/docs/build/html/classmlx_1_1core_1_1_matmul.html
index f879a5d59..9c0590ed8 100644
--- a/docs/build/html/classmlx_1_1core_1_1_matmul.html
+++ b/docs/build/html/classmlx_1_1core_1_1_matmul.html
@@ -127,6 +127,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -158,9 +161,6 @@ Public Member Functions
virtual std::vector< array > | jvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) |
| The Jacobian-vector product.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
@@ -288,6 +288,36 @@ Public Member Functions
Reimplemented from mlx::core::Primitive.
+
+
+
+◆ output_shapes()
+
+
+
+
+
+
+
+
+ std::vector< Shape > mlx::core::Matmul::output_shapes |
+ ( |
+ const std::vector< array > & | inputs | ) |
+ |
+
+
+ |
+
+overridevirtual |
+
+
+
+
+ Get the output shapes of the primitive.
+ This is not required to be implemented by derived classes, in which case it will throw.
+
+ Reimplemented from mlx::core::Primitive.
+
diff --git a/docs/build/html/classmlx_1_1core_1_1_maximum-members.html b/docs/build/html/classmlx_1_1core_1_1_maximum-members.html
index 9f0c58ef8..4016565ff 100644
--- a/docs/build/html/classmlx_1_1core_1_1_maximum-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_maximum-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Maximum | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Maximum | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_maximum.html b/docs/build/html/classmlx_1_1core_1_1_maximum.html
index 198b947ca..3171731aa 100644
--- a/docs/build/html/classmlx_1_1core_1_1_maximum.html
+++ b/docs/build/html/classmlx_1_1core_1_1_maximum.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Maximum::output_shapes |
+ std::vector< Shape > mlx::core::Maximum::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_minimum-members.html b/docs/build/html/classmlx_1_1core_1_1_minimum-members.html
index 20a161177..5dbe55d50 100644
--- a/docs/build/html/classmlx_1_1core_1_1_minimum-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_minimum-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Minimum | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Minimum | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_minimum.html b/docs/build/html/classmlx_1_1core_1_1_minimum.html
index ae43d36e0..115df15ca 100644
--- a/docs/build/html/classmlx_1_1core_1_1_minimum.html
+++ b/docs/build/html/classmlx_1_1core_1_1_minimum.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Minimum::output_shapes |
+ std::vector< Shape > mlx::core::Minimum::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_multiply-members.html b/docs/build/html/classmlx_1_1core_1_1_multiply-members.html
index 0e75527c1..3bd012ccb 100644
--- a/docs/build/html/classmlx_1_1core_1_1_multiply-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_multiply-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Multiply | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Multiply | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_multiply.html b/docs/build/html/classmlx_1_1core_1_1_multiply.html
index 840332607..8463dd354 100644
--- a/docs/build/html/classmlx_1_1core_1_1_multiply.html
+++ b/docs/build/html/classmlx_1_1core_1_1_multiply.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Multiply::output_shapes |
+ std::vector< Shape > mlx::core::Multiply::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_negative-members.html b/docs/build/html/classmlx_1_1core_1_1_negative-members.html
index bf3207169..439342b71 100644
--- a/docs/build/html/classmlx_1_1core_1_1_negative-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_negative-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Negative | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Negative | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_negative.html b/docs/build/html/classmlx_1_1core_1_1_negative.html
index d1b9c8dff..5ccf45862 100644
--- a/docs/build/html/classmlx_1_1core_1_1_negative.html
+++ b/docs/build/html/classmlx_1_1core_1_1_negative.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Negative::output_shapes |
+ std::vector< Shape > mlx::core::Negative::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_not_equal-members.html b/docs/build/html/classmlx_1_1core_1_1_not_equal-members.html
index fe17b758f..fde404ae4 100644
--- a/docs/build/html/classmlx_1_1core_1_1_not_equal-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_not_equal-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::NotEqual | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::NotEqual | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_not_equal.html b/docs/build/html/classmlx_1_1core_1_1_not_equal.html
index 616169afd..d7e087adf 100644
--- a/docs/build/html/classmlx_1_1core_1_1_not_equal.html
+++ b/docs/build/html/classmlx_1_1core_1_1_not_equal.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::NotEqual::output_shapes |
+ std::vector< Shape > mlx::core::NotEqual::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_number_of_elements-members.html b/docs/build/html/classmlx_1_1core_1_1_number_of_elements-members.html
index 5d72d22da..dd47e78c7 100644
--- a/docs/build/html/classmlx_1_1core_1_1_number_of_elements-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_number_of_elements-members.html
@@ -106,7 +106,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::NumberOfElements | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::NumberOfElements | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_number_of_elements.html b/docs/build/html/classmlx_1_1core_1_1_number_of_elements.html
index f0ee82aa3..c61a1ede1 100644
--- a/docs/build/html/classmlx_1_1core_1_1_number_of_elements.html
+++ b/docs/build/html/classmlx_1_1core_1_1_number_of_elements.html
@@ -124,9 +124,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -304,8 +304,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -314,7 +314,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::NumberOfElements::output_shapes |
+ std::vector< Shape > mlx::core::NumberOfElements::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -330,7 +330,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_pad-members.html b/docs/build/html/classmlx_1_1core_1_1_pad-members.html
index 45d80c55b..5c56d4d49 100644
--- a/docs/build/html/classmlx_1_1core_1_1_pad-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_pad-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Pad(Stream stream, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size) | mlx::core::Pad | inlineexplicit |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_pad.html b/docs/build/html/classmlx_1_1core_1_1_pad.html
index dfb4cdaf1..1497128f7 100644
--- a/docs/build/html/classmlx_1_1core_1_1_pad.html
+++ b/docs/build/html/classmlx_1_1core_1_1_pad.html
@@ -158,9 +158,9 @@ Public Member Functions
const Stream & | stream () |
| The stream the primitive will run on.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_partition-members.html b/docs/build/html/classmlx_1_1core_1_1_partition-members.html
index 4ba991643..47532df95 100644
--- a/docs/build/html/classmlx_1_1core_1_1_partition-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_partition-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Partition | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Partition | inlinevirtual |
Partition(Stream stream, int kth, int axis) | mlx::core::Partition | inlineexplicit |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_partition.html b/docs/build/html/classmlx_1_1core_1_1_partition.html
index 3740f0193..3c7715ab6 100644
--- a/docs/build/html/classmlx_1_1core_1_1_partition.html
+++ b/docs/build/html/classmlx_1_1core_1_1_partition.html
@@ -127,9 +127,9 @@ Public Member Functions
void | print (std::ostream &os) override |
| Print the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
@@ -337,8 +337,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -347,7 +347,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Partition::output_shapes |
+ std::vector< Shape > mlx::core::Partition::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -363,7 +363,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_power-members.html b/docs/build/html/classmlx_1_1core_1_1_power-members.html
index d51e93c3d..b24f7d182 100644
--- a/docs/build/html/classmlx_1_1core_1_1_power-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_power-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Power | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Power | inlinevirtual |
Power(Stream stream) | mlx::core::Power | inlineexplicit |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_power.html b/docs/build/html/classmlx_1_1core_1_1_power.html
index 0df30f7c0..f998a6261 100644
--- a/docs/build/html/classmlx_1_1core_1_1_power.html
+++ b/docs/build/html/classmlx_1_1core_1_1_power.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Power::output_shapes |
+ std::vector< Shape > mlx::core::Power::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_primitive-members.html b/docs/build/html/classmlx_1_1core_1_1_primitive-members.html
index 590e519a2..75bb28941 100644
--- a/docs/build/html/classmlx_1_1core_1_1_primitive-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_primitive-members.html
@@ -101,7 +101,7 @@ $(function(){ initResizable(false); });
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) | mlx::core::Primitive | virtual |
operator=(const Primitive &other)=delete | mlx::core::Primitive | |
operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_primitive.html b/docs/build/html/classmlx_1_1core_1_1_primitive.html
index 0a127787f..1c9b786bb 100644
--- a/docs/build/html/classmlx_1_1core_1_1_primitive.html
+++ b/docs/build/html/classmlx_1_1core_1_1_primitive.html
@@ -147,9 +147,9 @@ Public Member Functions
virtual bool | is_equivalent (const Primitive &other) const |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
@@ -380,7 +380,7 @@ Public Member Functions
Equivalence check defaults to false unless overridden by the primitive.
-Reimplemented in mlx::core::Abs, mlx::core::Add, mlx::core::AddMM, mlx::core::Arange, mlx::core::ArcCos, mlx::core::ArcCosh, mlx::core::ArcSin, mlx::core::ArcSinh, mlx::core::ArcTan2, mlx::core::ArcTan, mlx::core::ArcTanh, mlx::core::ArgPartition, mlx::core::ArgReduce, mlx::core::ArgSort, mlx::core::AsStrided, mlx::core::AsType, mlx::core::BitwiseBinary, mlx::core::BlockMaskedMM, mlx::core::Broadcast, mlx::core::Ceil, mlx::core::Compiled, mlx::core::Concatenate, mlx::core::Conjugate, mlx::core::Contiguous, mlx::core::Convolution, mlx::core::Copy, mlx::core::Cos, mlx::core::Cosh, mlx::core::Divide, mlx::core::DivMod, mlx::core::Eigh, mlx::core::Equal, mlx::core::Erf, mlx::core::ErfInv, mlx::core::Exp, mlx::core::fast::ScaledDotProductAttention, mlx::core::FFT, mlx::core::Floor, mlx::core::Full, mlx::core::Gather, mlx::core::GatherMM, mlx::core::GatherQMM, mlx::core::Greater, mlx::core::GreaterEqual, mlx::core::Hadamard, mlx::core::Imag, mlx::core::Less, mlx::core::LessEqual, mlx::core::Log, mlx::core::LogAddExp, mlx::core::LogicalAnd, mlx::core::LogicalNot, mlx::core::LogicalOr, mlx::core::Matmul, mlx::core::Maximum, mlx::core::Minimum, mlx::core::Multiply, mlx::core::Negative, mlx::core::NotEqual, mlx::core::NumberOfElements, mlx::core::Pad, mlx::core::Partition, mlx::core::Power, mlx::core::QuantizedMatmul, mlx::core::RandomBits, mlx::core::Real, mlx::core::Reduce, mlx::core::Remainder, mlx::core::Reshape, mlx::core::Round, mlx::core::Scan, mlx::core::Scatter, mlx::core::Select, mlx::core::Sigmoid, mlx::core::Sign, mlx::core::Sin, mlx::core::Sinh, mlx::core::Slice, mlx::core::SliceUpdate, mlx::core::Softmax, mlx::core::Sort, mlx::core::Split, mlx::core::Sqrt, mlx::core::Square, mlx::core::StopGradient, mlx::core::Subtract, mlx::core::Tan, mlx::core::Tanh, mlx::core::Transpose, mlx::core::Uniform, and mlx::core::View.
+Reimplemented in mlx::core::Abs, mlx::core::Add, mlx::core::AddMM, mlx::core::Arange, mlx::core::ArcCos, mlx::core::ArcCosh, mlx::core::ArcSin, mlx::core::ArcSinh, mlx::core::ArcTan2, mlx::core::ArcTan, mlx::core::ArcTanh, mlx::core::ArgPartition, mlx::core::ArgReduce, mlx::core::ArgSort, mlx::core::AsStrided, mlx::core::AsType, mlx::core::BitwiseBinary, mlx::core::BlockMaskedMM, mlx::core::Broadcast, mlx::core::Ceil, mlx::core::Compiled, mlx::core::Concatenate, mlx::core::Conjugate, mlx::core::Contiguous, mlx::core::Convolution, mlx::core::Copy, mlx::core::Cos, mlx::core::Cosh, mlx::core::Divide, mlx::core::DivMod, mlx::core::Eigh, mlx::core::Equal, mlx::core::Erf, mlx::core::ErfInv, mlx::core::Exp, mlx::core::fast::AffineQuantize, mlx::core::fast::ScaledDotProductAttention, mlx::core::FFT, mlx::core::Floor, mlx::core::Full, mlx::core::Gather, mlx::core::GatherMM, mlx::core::GatherQMM, mlx::core::Greater, mlx::core::GreaterEqual, mlx::core::Hadamard, mlx::core::Imag, mlx::core::Less, mlx::core::LessEqual, mlx::core::Log, mlx::core::LogAddExp, mlx::core::LogicalAnd, mlx::core::LogicalNot, mlx::core::LogicalOr, mlx::core::Matmul, mlx::core::Maximum, mlx::core::Minimum, mlx::core::Multiply, mlx::core::Negative, mlx::core::NotEqual, mlx::core::NumberOfElements, mlx::core::Pad, mlx::core::Partition, mlx::core::Power, mlx::core::QuantizedMatmul, mlx::core::RandomBits, mlx::core::Real, mlx::core::Reduce, mlx::core::Remainder, mlx::core::Reshape, mlx::core::Round, mlx::core::Scan, mlx::core::Scatter, mlx::core::Select, mlx::core::Sigmoid, mlx::core::Sign, mlx::core::Sin, mlx::core::Sinh, mlx::core::Slice, mlx::core::SliceUpdate, mlx::core::Softmax, mlx::core::Sort, mlx::core::Split, mlx::core::Sqrt, mlx::core::Square, mlx::core::StopGradient, mlx::core::Subtract, mlx::core::Tan, mlx::core::Tanh, mlx::core::Transpose, mlx::core::Uniform, and mlx::core::View.
@@ -472,8 +472,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -482,7 +482,7 @@ Public Member Functions
- virtual std::vector< std::vector< int > > mlx::core::Primitive::output_shapes |
+ virtual std::vector< Shape > mlx::core::Primitive::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -498,7 +498,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented in mlx::core::Abs, mlx::core::Add, mlx::core::ArcCos, mlx::core::ArcCosh, mlx::core::ArcSin, mlx::core::ArcSinh, mlx::core::ArcTan2, mlx::core::ArcTan, mlx::core::ArcTanh, mlx::core::ArgPartition, mlx::core::ArgReduce, mlx::core::ArgSort, mlx::core::AsType, mlx::core::BitwiseBinary, mlx::core::Ceil, mlx::core::Compiled, mlx::core::Conjugate, mlx::core::Contiguous, mlx::core::Copy, mlx::core::Cos, mlx::core::Cosh, mlx::core::Divide, mlx::core::DivMod, mlx::core::Eigh, mlx::core::Equal, mlx::core::Erf, mlx::core::ErfInv, mlx::core::Exp, mlx::core::Expm1, mlx::core::Floor, mlx::core::Greater, mlx::core::GreaterEqual, mlx::core::Hadamard, mlx::core::Imag, mlx::core::Less, mlx::core::LessEqual, mlx::core::Log1p, mlx::core::Log, mlx::core::LogAddExp, mlx::core::LogicalAnd, mlx::core::LogicalNot, mlx::core::LogicalOr, mlx::core::Maximum, mlx::core::Minimum, mlx::core::Multiply, mlx::core::Negative, mlx::core::NotEqual, mlx::core::NumberOfElements, mlx::core::Partition, mlx::core::Power, mlx::core::Real, mlx::core::Reduce, mlx::core::Remainder, mlx::core::Round, mlx::core::Select, mlx::core::Sigmoid, mlx::core::Sign, mlx::core::Sin, mlx::core::Sinh, mlx::core::Softmax, mlx::core::Sort, mlx::core::Sqrt, mlx::core::Square, mlx::core::StopGradient, mlx::core::Subtract, mlx::core::Tan, and mlx::core::Tanh.
+Reimplemented in mlx::core::Abs, mlx::core::Add, mlx::core::Arange, mlx::core::ArcCos, mlx::core::ArcCosh, mlx::core::ArcSin, mlx::core::ArcSinh, mlx::core::ArcTan2, mlx::core::ArcTan, mlx::core::ArcTanh, mlx::core::ArgPartition, mlx::core::ArgReduce, mlx::core::ArgSort, mlx::core::AsType, mlx::core::BitwiseBinary, mlx::core::Ceil, mlx::core::Compiled, mlx::core::Concatenate, mlx::core::Conjugate, mlx::core::Contiguous, mlx::core::Copy, mlx::core::Cos, mlx::core::Cosh, mlx::core::Divide, mlx::core::DivMod, mlx::core::Eigh, mlx::core::Equal, mlx::core::Erf, mlx::core::ErfInv, mlx::core::Exp, mlx::core::Expm1, mlx::core::fast::AffineQuantize, mlx::core::Floor, mlx::core::Gather, mlx::core::Greater, mlx::core::GreaterEqual, mlx::core::Hadamard, mlx::core::Imag, mlx::core::Less, mlx::core::LessEqual, mlx::core::Log1p, mlx::core::Log, mlx::core::LogAddExp, mlx::core::LogicalAnd, mlx::core::LogicalNot, mlx::core::LogicalOr, mlx::core::Matmul, mlx::core::Maximum, mlx::core::Minimum, mlx::core::Multiply, mlx::core::Negative, mlx::core::NotEqual, mlx::core::NumberOfElements, mlx::core::Partition, mlx::core::Power, mlx::core::QuantizedMatmul, mlx::core::Real, mlx::core::Reduce, mlx::core::Remainder, mlx::core::Round, mlx::core::Select, mlx::core::Sigmoid, mlx::core::Sign, mlx::core::Sin, mlx::core::Sinh, mlx::core::Softmax, mlx::core::Sort, mlx::core::Sqrt, mlx::core::Square, mlx::core::StopGradient, mlx::core::Subtract, mlx::core::Tan, mlx::core::Tanh, and mlx::core::Transpose.
diff --git a/docs/build/html/classmlx_1_1core_1_1_q_r_f-members.html b/docs/build/html/classmlx_1_1core_1_1_q_r_f-members.html
index 99dc16a0c..61ffc3c0c 100644
--- a/docs/build/html/classmlx_1_1core_1_1_q_r_f-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_q_r_f-members.html
@@ -101,7 +101,7 @@ $(function(){ initResizable(false); });
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) | mlx::core::Primitive | virtual |
operator=(const Primitive &other)=delete | mlx::core::Primitive | |
operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_q_r_f.html b/docs/build/html/classmlx_1_1core_1_1_q_r_f.html
index 7216ece43..305e43098 100644
--- a/docs/build/html/classmlx_1_1core_1_1_q_r_f.html
+++ b/docs/build/html/classmlx_1_1core_1_1_q_r_f.html
@@ -139,9 +139,9 @@ Public Member Functions
virtual bool | is_equivalent (const Primitive &other) const |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_quantized_matmul-members.html b/docs/build/html/classmlx_1_1core_1_1_quantized_matmul-members.html
index 437c22b9c..2a9dcf5e9 100644
--- a/docs/build/html/classmlx_1_1core_1_1_quantized_matmul-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_quantized_matmul-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::QuantizedMatmul | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_quantized_matmul.html b/docs/build/html/classmlx_1_1core_1_1_quantized_matmul.html
index be86e4e1a..4f158d206 100644
--- a/docs/build/html/classmlx_1_1core_1_1_quantized_matmul.html
+++ b/docs/build/html/classmlx_1_1core_1_1_quantized_matmul.html
@@ -130,6 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -158,9 +161,6 @@ Public Member Functions
const Stream & | stream () |
| The stream the primitive will run on.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
@@ -340,6 +340,36 @@ Public Member Functions
Reimplemented from mlx::core::Primitive.
+
+
+
+◆ output_shapes()
+
+
+
+
+
+
+
+
+ std::vector< Shape > mlx::core::QuantizedMatmul::output_shapes |
+ ( |
+ const std::vector< array > & | inputs | ) |
+ |
+
+
+ |
+
+overridevirtual |
+
+
+
+
+ Get the output shapes of the primitive.
+ This is not required to be implemented by derived classes, in which case it will throw.
+
+ Reimplemented from mlx::core::Primitive.
+
diff --git a/docs/build/html/classmlx_1_1core_1_1_random_bits-members.html b/docs/build/html/classmlx_1_1core_1_1_random_bits-members.html
index e945f90ef..96ef570f6 100644
--- a/docs/build/html/classmlx_1_1core_1_1_random_bits-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_random_bits-members.html
@@ -105,12 +105,12 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
print(std::ostream &os) override | mlx::core::RandomBits | inlinevirtual |
- RandomBits(Stream stream, const std::vector< int > &shape, int width) | mlx::core::RandomBits | inlineexplicit |
+ RandomBits(Stream stream, const Shape &shape, int width) | mlx::core::RandomBits | inlineexplicit |
stream() | mlx::core::Primitive | inline |
UnaryPrimitive(Stream stream) | mlx::core::UnaryPrimitive | inlineexplicit |
UnaryPrimitive(const UnaryPrimitive &other)=delete | mlx::core::UnaryPrimitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_random_bits.html b/docs/build/html/classmlx_1_1core_1_1_random_bits.html
index ccc19d1f2..44356a2d6 100644
--- a/docs/build/html/classmlx_1_1core_1_1_random_bits.html
+++ b/docs/build/html/classmlx_1_1core_1_1_random_bits.html
@@ -109,8 +109,8 @@ Inheritance diagram for mlx::core::RandomBits:
|
- | RandomBits (Stream stream, const std::vector< int > &shape, int width) |
- |
+ | RandomBits (Stream stream, const Shape &shape, int width) |
+ |
void | eval_cpu (const std::vector< array > &inputs, array &out) override |
|
void | eval_gpu (const std::vector< array > &inputs, array &out) override |
@@ -158,9 +158,9 @@ Public Member Functions
virtual std::vector< array > | vjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) |
| The vector-Jacobian product.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
@@ -173,8 +173,8 @@ Public Member Functions
|
-
-◆ RandomBits()
+
+◆ RandomBits()
@@ -190,7 +190,7 @@ Public Member Functions
|
|
- const std::vector< int > & | shape, |
+ const Shape & | shape, |
|
diff --git a/docs/build/html/classmlx_1_1core_1_1_real-members.html b/docs/build/html/classmlx_1_1core_1_1_real-members.html
index 54ddb3223..292ea563e 100644
--- a/docs/build/html/classmlx_1_1core_1_1_real-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_real-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Real | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Real | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_real.html b/docs/build/html/classmlx_1_1core_1_1_real.html
index f606b5fa0..80bc40945 100644
--- a/docs/build/html/classmlx_1_1core_1_1_real.html
+++ b/docs/build/html/classmlx_1_1core_1_1_real.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
- std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+ std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Real::output_shapes |
+ std::vector< Shape > mlx::core::Real::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_reduce-members.html b/docs/build/html/classmlx_1_1core_1_1_reduce-members.html
index 218805a22..ab8f9908c 100644
--- a/docs/build/html/classmlx_1_1core_1_1_reduce-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_reduce-members.html
@@ -109,7 +109,7 @@ $(function(){ initResizable(false); });
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
Or enum value | mlx::core::Reduce | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Reduce | virtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Reduce | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_reduce.html b/docs/build/html/classmlx_1_1core_1_1_reduce.html
index efbd9a2b4..89103f3cf 100644
--- a/docs/build/html/classmlx_1_1core_1_1_reduce.html
+++ b/docs/build/html/classmlx_1_1core_1_1_reduce.html
@@ -136,9 +136,9 @@ Public Member Functions
std::vector< array > | vjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override |
| The vector-Jacobian product.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
void | print (std::ostream &os) override |
| Print the primitive.
|
|
@@ -337,8 +337,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -347,7 +347,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Reduce::output_shapes |
+ std::vector< Shape > mlx::core::Reduce::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -363,7 +363,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_remainder-members.html b/docs/build/html/classmlx_1_1core_1_1_remainder-members.html
index 70ae04508..0f4e998d3 100644
--- a/docs/build/html/classmlx_1_1core_1_1_remainder-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_remainder-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Remainder | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Remainder | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_remainder.html b/docs/build/html/classmlx_1_1core_1_1_remainder.html
index 958cf55e2..300ebb9a6 100644
--- a/docs/build/html/classmlx_1_1core_1_1_remainder.html
+++ b/docs/build/html/classmlx_1_1core_1_1_remainder.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Remainder::output_shapes |
+ std::vector< Shape > mlx::core::Remainder::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_reshape-members.html b/docs/build/html/classmlx_1_1core_1_1_reshape-members.html
index bc66daf6e..b33864bd7 100644
--- a/docs/build/html/classmlx_1_1core_1_1_reshape-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_reshape-members.html
@@ -105,12 +105,12 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
print(std::ostream &os) override | mlx::core::Reshape | inlinevirtual |
- Reshape(Stream stream, const std::vector< int > &shape) | mlx::core::Reshape | inlineexplicit |
+ Reshape(Stream stream, const Shape &shape) | mlx::core::Reshape | inlineexplicit |
stream() | mlx::core::Primitive | inline |
UnaryPrimitive(Stream stream) | mlx::core::UnaryPrimitive | inlineexplicit |
UnaryPrimitive(const UnaryPrimitive &other)=delete | mlx::core::UnaryPrimitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_reshape.html b/docs/build/html/classmlx_1_1core_1_1_reshape.html
index 1267c423b..78e332e23 100644
--- a/docs/build/html/classmlx_1_1core_1_1_reshape.html
+++ b/docs/build/html/classmlx_1_1core_1_1_reshape.html
@@ -109,8 +109,8 @@ Inheritance diagram for mlx::core::Reshape:
-
-◆ Reshape()
+
+◆ Reshape()
@@ -190,7 +190,7 @@ Public Member Functions
|
|
- const std::vector< int > & | shape ) |
+ const Shape & | shape ) |
|
diff --git a/docs/build/html/classmlx_1_1core_1_1_round-members.html b/docs/build/html/classmlx_1_1core_1_1_round-members.html
index 10158de60..b4769bbf0 100644
--- a/docs/build/html/classmlx_1_1core_1_1_round-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_round-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Round | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Round | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_round.html b/docs/build/html/classmlx_1_1core_1_1_round.html
index a20ccca09..18f0b0290 100644
--- a/docs/build/html/classmlx_1_1core_1_1_round.html
+++ b/docs/build/html/classmlx_1_1core_1_1_round.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
- std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+ std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Round::output_shapes |
+ std::vector< Shape > mlx::core::Round::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_s_v_d-members.html b/docs/build/html/classmlx_1_1core_1_1_s_v_d-members.html
index 68383862c..801c6d5bf 100644
--- a/docs/build/html/classmlx_1_1core_1_1_s_v_d-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_s_v_d-members.html
@@ -101,7 +101,7 @@ $(function(){ initResizable(false); });
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) | mlx::core::Primitive | virtual |
operator=(const Primitive &other)=delete | mlx::core::Primitive | |
operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_s_v_d.html b/docs/build/html/classmlx_1_1core_1_1_s_v_d.html
index de5c20b04..fb2a5246d 100644
--- a/docs/build/html/classmlx_1_1core_1_1_s_v_d.html
+++ b/docs/build/html/classmlx_1_1core_1_1_s_v_d.html
@@ -139,9 +139,9 @@ Public Member Functions
virtual bool | is_equivalent (const Primitive &other) const |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_scan-members.html b/docs/build/html/classmlx_1_1core_1_1_scan-members.html
index a503f7add..76d2c4bef 100644
--- a/docs/build/html/classmlx_1_1core_1_1_scan-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_scan-members.html
@@ -107,7 +107,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_scan.html b/docs/build/html/classmlx_1_1core_1_1_scan.html
index 258337007..6e6d0ac4a 100644
--- a/docs/build/html/classmlx_1_1core_1_1_scan.html
+++ b/docs/build/html/classmlx_1_1core_1_1_scan.html
@@ -168,9 +168,9 @@ Public Member Functions
const Stream & | stream () |
| The stream the primitive will run on.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_scatter-members.html b/docs/build/html/classmlx_1_1core_1_1_scatter-members.html
index db0486125..219a54b35 100644
--- a/docs/build/html/classmlx_1_1core_1_1_scatter-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_scatter-members.html
@@ -108,7 +108,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_scatter.html b/docs/build/html/classmlx_1_1core_1_1_scatter.html
index e1afd50b8..4ac1f0e21 100644
--- a/docs/build/html/classmlx_1_1core_1_1_scatter.html
+++ b/docs/build/html/classmlx_1_1core_1_1_scatter.html
@@ -172,9 +172,9 @@ Public Member Functions
const Stream & | stream () |
| The stream the primitive will run on.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_select-members.html b/docs/build/html/classmlx_1_1core_1_1_select-members.html
index 3c2bf40bb..0ed101f42 100644
--- a/docs/build/html/classmlx_1_1core_1_1_select-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_select-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Select | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Select | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_select.html b/docs/build/html/classmlx_1_1core_1_1_select.html
index 7716071ed..e956423f4 100644
--- a/docs/build/html/classmlx_1_1core_1_1_select.html
+++ b/docs/build/html/classmlx_1_1core_1_1_select.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Select::output_shapes |
+ std::vector< Shape > mlx::core::Select::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_sigmoid-members.html b/docs/build/html/classmlx_1_1core_1_1_sigmoid-members.html
index afa786550..a77284f7d 100644
--- a/docs/build/html/classmlx_1_1core_1_1_sigmoid-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_sigmoid-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Sigmoid | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Sigmoid | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_sigmoid.html b/docs/build/html/classmlx_1_1core_1_1_sigmoid.html
index 36e71c356..a94b87126 100644
--- a/docs/build/html/classmlx_1_1core_1_1_sigmoid.html
+++ b/docs/build/html/classmlx_1_1core_1_1_sigmoid.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Sigmoid::output_shapes |
+ std::vector< Shape > mlx::core::Sigmoid::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_sign-members.html b/docs/build/html/classmlx_1_1core_1_1_sign-members.html
index 1aa12f83a..fc123c1f1 100644
--- a/docs/build/html/classmlx_1_1core_1_1_sign-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_sign-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Sign | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Sign | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_sign.html b/docs/build/html/classmlx_1_1core_1_1_sign.html
index 337de5e62..48ab7a6ef 100644
--- a/docs/build/html/classmlx_1_1core_1_1_sign.html
+++ b/docs/build/html/classmlx_1_1core_1_1_sign.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Sign::output_shapes |
+ std::vector< Shape > mlx::core::Sign::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_sin-members.html b/docs/build/html/classmlx_1_1core_1_1_sin-members.html
index f92e56387..d48713e3e 100644
--- a/docs/build/html/classmlx_1_1core_1_1_sin-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_sin-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Sin | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Sin | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_sin.html b/docs/build/html/classmlx_1_1core_1_1_sin.html
index 3cbd09fb3..8e8fb2776 100644
--- a/docs/build/html/classmlx_1_1core_1_1_sin.html
+++ b/docs/build/html/classmlx_1_1core_1_1_sin.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Sin::output_shapes |
+ std::vector< Shape > mlx::core::Sin::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_sinh-members.html b/docs/build/html/classmlx_1_1core_1_1_sinh-members.html
index ab735256e..8e10feb48 100644
--- a/docs/build/html/classmlx_1_1core_1_1_sinh-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_sinh-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Sinh | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Sinh | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_sinh.html b/docs/build/html/classmlx_1_1core_1_1_sinh.html
index 3fc11e1c1..8e73177b7 100644
--- a/docs/build/html/classmlx_1_1core_1_1_sinh.html
+++ b/docs/build/html/classmlx_1_1core_1_1_sinh.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Sinh::output_shapes |
+ std::vector< Shape > mlx::core::Sinh::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_slice-members.html b/docs/build/html/classmlx_1_1core_1_1_slice-members.html
index 904efc554..66386048f 100644
--- a/docs/build/html/classmlx_1_1core_1_1_slice-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_slice-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_slice.html b/docs/build/html/classmlx_1_1core_1_1_slice.html
index 87ba23ce8..995205fb5 100644
--- a/docs/build/html/classmlx_1_1core_1_1_slice.html
+++ b/docs/build/html/classmlx_1_1core_1_1_slice.html
@@ -158,9 +158,9 @@ Public Member Functions
const Stream & | stream () |
| The stream the primitive will run on.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_slice_update-members.html b/docs/build/html/classmlx_1_1core_1_1_slice_update-members.html
index 47ed3f8a8..7e1c8ac74 100644
--- a/docs/build/html/classmlx_1_1core_1_1_slice_update-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_slice_update-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_slice_update.html b/docs/build/html/classmlx_1_1core_1_1_slice_update.html
index e9ccff1fb..5397cc5b3 100644
--- a/docs/build/html/classmlx_1_1core_1_1_slice_update.html
+++ b/docs/build/html/classmlx_1_1core_1_1_slice_update.html
@@ -158,9 +158,9 @@ Public Member Functions
const Stream & | stream () |
| The stream the primitive will run on.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_softmax-members.html b/docs/build/html/classmlx_1_1core_1_1_softmax-members.html
index 6d179f2d5..8f7fa10c0 100644
--- a/docs/build/html/classmlx_1_1core_1_1_softmax-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_softmax-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Softmax | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Softmax | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_softmax.html b/docs/build/html/classmlx_1_1core_1_1_softmax.html
index d3ae3a842..7b293ae7a 100644
--- a/docs/build/html/classmlx_1_1core_1_1_softmax.html
+++ b/docs/build/html/classmlx_1_1core_1_1_softmax.html
@@ -127,9 +127,9 @@ Public Member Functions
void | print (std::ostream &os) override |
| Print the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
@@ -332,8 +332,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -342,7 +342,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Softmax::output_shapes |
+ std::vector< Shape > mlx::core::Softmax::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -358,7 +358,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_sort-members.html b/docs/build/html/classmlx_1_1core_1_1_sort-members.html
index 97728560d..5af387bd3 100644
--- a/docs/build/html/classmlx_1_1core_1_1_sort-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_sort-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Sort | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Sort | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_sort.html b/docs/build/html/classmlx_1_1core_1_1_sort.html
index ad76c6386..cd0f38691 100644
--- a/docs/build/html/classmlx_1_1core_1_1_sort.html
+++ b/docs/build/html/classmlx_1_1core_1_1_sort.html
@@ -127,9 +127,9 @@ Public Member Functions
void | print (std::ostream &os) override |
| Print the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
@@ -332,8 +332,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -342,7 +342,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Sort::output_shapes |
+ std::vector< Shape > mlx::core::Sort::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -358,7 +358,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_split-members.html b/docs/build/html/classmlx_1_1core_1_1_split-members.html
index a5eec811d..a7334dab7 100644
--- a/docs/build/html/classmlx_1_1core_1_1_split-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_split-members.html
@@ -101,7 +101,7 @@ $(function(){ initResizable(false); });
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override | mlx::core::Split | virtual |
operator=(const Primitive &other)=delete | mlx::core::Primitive | |
operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
+ output_shapes(const std::vector< array > &inputs) | mlx::core::Primitive | virtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_split.html b/docs/build/html/classmlx_1_1core_1_1_split.html
index 36ea497af..215d94f5e 100644
--- a/docs/build/html/classmlx_1_1core_1_1_split.html
+++ b/docs/build/html/classmlx_1_1core_1_1_split.html
@@ -139,9 +139,9 @@ Public Member Functions
const Stream & | stream () |
| The stream the primitive will run on.
|
|
-virtual std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) |
- | Get the output shapes of the primitive.
|
- |
+virtual std::vector< Shape > | output_shapes (const std::vector< array > &inputs) |
+ | Get the output shapes of the primitive.
|
+ |
virtual | ~Primitive ()=default |
|
| Primitive (const Primitive &other)=delete |
diff --git a/docs/build/html/classmlx_1_1core_1_1_sqrt-members.html b/docs/build/html/classmlx_1_1core_1_1_sqrt-members.html
index d7dce968d..4a89dcbaf 100644
--- a/docs/build/html/classmlx_1_1core_1_1_sqrt-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_sqrt-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Sqrt | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Sqrt | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_sqrt.html b/docs/build/html/classmlx_1_1core_1_1_sqrt.html
index ae7665543..1757f7073 100644
--- a/docs/build/html/classmlx_1_1core_1_1_sqrt.html
+++ b/docs/build/html/classmlx_1_1core_1_1_sqrt.html
@@ -124,9 +124,9 @@ Public Member Functions
std::vector< array > | vjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override |
| The vector-Jacobian product.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
@@ -332,8 +332,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -342,7 +342,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Sqrt::output_shapes |
+ std::vector< Shape > mlx::core::Sqrt::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -358,7 +358,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_square-members.html b/docs/build/html/classmlx_1_1core_1_1_square-members.html
index afb218df6..66ce10674 100644
--- a/docs/build/html/classmlx_1_1core_1_1_square-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_square-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Square | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Square | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_square.html b/docs/build/html/classmlx_1_1core_1_1_square.html
index 8b32b08ce..2a432c233 100644
--- a/docs/build/html/classmlx_1_1core_1_1_square.html
+++ b/docs/build/html/classmlx_1_1core_1_1_square.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Square::output_shapes |
+ std::vector< Shape > mlx::core::Square::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_stop_gradient-members.html b/docs/build/html/classmlx_1_1core_1_1_stop_gradient-members.html
index 68c7c526d..1d9b8f9a7 100644
--- a/docs/build/html/classmlx_1_1core_1_1_stop_gradient-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_stop_gradient-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::StopGradient | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::StopGradient | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_stop_gradient.html b/docs/build/html/classmlx_1_1core_1_1_stop_gradient.html
index b7f20b8f4..d571be713 100644
--- a/docs/build/html/classmlx_1_1core_1_1_stop_gradient.html
+++ b/docs/build/html/classmlx_1_1core_1_1_stop_gradient.html
@@ -124,9 +124,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -290,8 +290,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -300,7 +300,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::StopGradient::output_shapes |
+ std::vector< Shape > mlx::core::StopGradient::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -316,7 +316,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_subtract-members.html b/docs/build/html/classmlx_1_1core_1_1_subtract-members.html
index e8ba277a4..8bf2cd270 100644
--- a/docs/build/html/classmlx_1_1core_1_1_subtract-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_subtract-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Subtract | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Subtract | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_subtract.html b/docs/build/html/classmlx_1_1core_1_1_subtract.html
index b723a566f..c597379c1 100644
--- a/docs/build/html/classmlx_1_1core_1_1_subtract.html
+++ b/docs/build/html/classmlx_1_1core_1_1_subtract.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Subtract::output_shapes |
+ std::vector< Shape > mlx::core::Subtract::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_tan-members.html b/docs/build/html/classmlx_1_1core_1_1_tan-members.html
index 87794bd5e..1b4643339 100644
--- a/docs/build/html/classmlx_1_1core_1_1_tan-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_tan-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Tan | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Tan | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_tan.html b/docs/build/html/classmlx_1_1core_1_1_tan.html
index d5e926a88..7bb932a84 100644
--- a/docs/build/html/classmlx_1_1core_1_1_tan.html
+++ b/docs/build/html/classmlx_1_1core_1_1_tan.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
- std::vector< std::vector< int > > mlx::core::Tan::output_shapes |
+ std::vector< Shape > mlx::core::Tan::output_shapes |
( |
const std::vector< array > & | inputs | ) |
|
@@ -354,7 +354,7 @@ Public Member Functions
Get the output shapes of the primitive.
This is not required to be implemented by derived classes, in which case it will throw.
-Reimplemented from mlx::core::Primitive.
+Reimplemented from mlx::core::Primitive.
diff --git a/docs/build/html/classmlx_1_1core_1_1_tanh-members.html b/docs/build/html/classmlx_1_1core_1_1_tanh-members.html
index f20e714d8..2cac2f935 100644
--- a/docs/build/html/classmlx_1_1core_1_1_tanh-members.html
+++ b/docs/build/html/classmlx_1_1core_1_1_tanh-members.html
@@ -105,7 +105,7 @@ $(function(){ initResizable(false); });
operator=(UnaryPrimitive &&other)=delete | mlx::core::UnaryPrimitive | |
mlx::core::Primitive::operator=(const Primitive &other)=delete | mlx::core::Primitive | |
mlx::core::Primitive::operator=(Primitive &&other)=delete | mlx::core::Primitive | |
- output_shapes(const std::vector< array > &inputs) override | mlx::core::Tanh | inlinevirtual |
+ output_shapes(const std::vector< array > &inputs) override | mlx::core::Tanh | inlinevirtual |
Primitive(Stream stream) | mlx::core::Primitive | inlineexplicit |
Primitive(const Primitive &other)=delete | mlx::core::Primitive | |
Primitive(Primitive &&other)=delete | mlx::core::Primitive | |
diff --git a/docs/build/html/classmlx_1_1core_1_1_tanh.html b/docs/build/html/classmlx_1_1core_1_1_tanh.html
index df0a20367..ca65580d1 100644
--- a/docs/build/html/classmlx_1_1core_1_1_tanh.html
+++ b/docs/build/html/classmlx_1_1core_1_1_tanh.html
@@ -130,9 +130,9 @@ Public Member Functions
bool | is_equivalent (const Primitive &other) const override |
| Equivalence check defaults to false unless overridden by the primitive.
|
|
-std::vector< std::vector< int > > | output_shapes (const std::vector< array > &inputs) override |
- | Get the output shapes of the primitive.
|
- |
+std::vector< Shape > | output_shapes (const std::vector< array > &inputs) override |
+ | Get the output shapes of the primitive.
|
+ |
| UnaryPrimitive (Stream stream) |
| An abstract base class for a primitive with a single output.
|
@@ -328,8 +328,8 @@ Public Member Functions
-
-◆ output_shapes()
+
+◆ output_shapes()
@@ -338,7 +338,7 @@ Public Member Functions
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |