More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -79,7 +79,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"fft2",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@@ -115,7 +115,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"ifft2",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@@ -151,7 +151,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"fftn",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@@ -188,7 +188,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"ifftn",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@@ -294,7 +294,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"rfft2",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@@ -336,7 +336,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"irfft2",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@@ -378,7 +378,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"rfftn",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@@ -420,7 +420,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"irfftn",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {