mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
awni's commit files
This commit is contained in:
32
python/src/CMakeLists.txt
Normal file
32
python/src/CMakeLists.txt
Normal file
@@ -0,0 +1,32 @@
|
||||
pybind11_add_module(
|
||||
core
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
||||
)
|
||||
|
||||
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
|
||||
set(MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
|
||||
endif()
|
||||
|
||||
set_target_properties(
|
||||
core
|
||||
PROPERTIES
|
||||
LIBRARY_OUTPUT_DIRECTORY
|
||||
${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY}
|
||||
)
|
||||
|
||||
target_link_libraries(core PRIVATE mlx)
|
||||
target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})
|
||||
|
||||
if(BUILD_SHARED_LIBS)
|
||||
target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib)
|
||||
endif()
|
468
python/src/fft.cpp
Normal file
468
python/src/fft.cpp
Normal file
@@ -0,0 +1,468 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "python/src/utils.h"
|
||||
|
||||
#include "mlx/fft.h"
|
||||
#include "mlx/ops.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_fft(py::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"fft", "mlx.core.fft: Fast Fourier Transforms.");
|
||||
m.def(
|
||||
"fft",
|
||||
[](const array& a,
|
||||
const std::optional<int>& n,
|
||||
int axis,
|
||||
StreamOrDevice s) {
|
||||
if (n.has_value()) {
|
||||
return fft::fft(a, n.value(), axis, s);
|
||||
} else {
|
||||
return fft::fft(a, axis, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"n"_a = none,
|
||||
"axis"_a = -1,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
One dimensional discrete Fourier Transform.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
n (int, optional): Size of the transformed axis. The
|
||||
corresponding axis in the input is truncated or padded with
|
||||
zeros to match ``n``. The default value is ``a.shape[axis]``.
|
||||
axis (int, optional): Axis along which to perform the FFT. The
|
||||
default is ``-1``.
|
||||
|
||||
Returns:
|
||||
array: The DFT of the input along the given axis.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"ifft",
|
||||
[](const array& a,
|
||||
const std::optional<int>& n,
|
||||
int axis,
|
||||
StreamOrDevice s) {
|
||||
if (n.has_value()) {
|
||||
return fft::ifft(a, n.value(), axis, s);
|
||||
} else {
|
||||
return fft::ifft(a, axis, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"n"_a = none,
|
||||
"axis"_a = -1,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
One dimensional inverse discrete Fourier Transform.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
n (int, optional): Size of the transformed axis. The
|
||||
corresponding axis in the input is truncated or padded with
|
||||
zeros to match ``n``. The default value is ``a.shape[axis]``.
|
||||
axis (int, optional): Axis along which to perform the FFT. The
|
||||
default is ``-1``.
|
||||
|
||||
Returns:
|
||||
array: The inverse DFT of the input along the given axis.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"fft2",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::fftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::fftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::fftn(a, n.value(), axes_, s);
|
||||
} else {
|
||||
return fft::fftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = std::vector<int>{-2, -1},
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Two dimensional discrete Fourier Transform.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
s (list(int), optional): Sizes of the transformed axes. The
|
||||
corresponding axes in the input are truncated or padded with
|
||||
zeros to match the sizes in ``s``. The default value is the
|
||||
sizes of ``a`` along ``axes``.
|
||||
axes (list(int), optional): Axes along which to perform the FFT.
|
||||
The default is ``[-2, -1]``.
|
||||
|
||||
Returns:
|
||||
array: The DFT of the input along the given axes.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"ifft2",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::ifftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::ifftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::ifftn(a, n.value(), axes_, s);
|
||||
} else {
|
||||
return fft::ifftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = std::vector<int>{-2, -1},
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Two dimensional inverse discrete Fourier Transform.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
s (list(int), optional): Sizes of the transformed axes. The
|
||||
corresponding axes in the input are truncated or padded with
|
||||
zeros to match the sizes in ``s``. The default value is the
|
||||
sizes of ``a`` along ``axes``.
|
||||
axes (list(int), optional): Axes along which to perform the FFT.
|
||||
The default is ``[-2, -1]``.
|
||||
|
||||
Returns:
|
||||
array: The inverse DFT of the input along the given axes.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"fftn",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::fftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::fftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::fftn(a, n.value(), axes_, s);
|
||||
} else {
|
||||
return fft::fftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
n-dimensional discrete Fourier Transform.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
s (list(int), optional): Sizes of the transformed axes. The
|
||||
corresponding axes in the input are truncated or padded with
|
||||
zeros to match the sizes in ``s``. The default value is the
|
||||
sizes of ``a`` along ``axes``.
|
||||
axes (list(int), optional): Axes along which to perform the FFT.
|
||||
The default is ``None`` in which case the FFT is over the last
|
||||
``len(s)`` axes are or all axes if ``s`` is also ``None``.
|
||||
|
||||
Returns:
|
||||
array: The DFT of the input along the given axes.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"ifftn",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::ifftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::ifftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::ifftn(a, n.value(), axes_, s);
|
||||
} else {
|
||||
return fft::ifftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
n-dimensional inverse discrete Fourier Transform.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
s (list(int), optional): Sizes of the transformed axes. The
|
||||
corresponding axes in the input are truncated or padded with
|
||||
zeros to match the sizes in ``s``. The default value is the
|
||||
sizes of ``a`` along ``axes``.
|
||||
axes (list(int), optional): Axes along which to perform the FFT.
|
||||
The default is ``None`` in which case the FFT is over the last
|
||||
``len(s)`` axes or all axes if ``s`` is also ``None``.
|
||||
|
||||
Returns:
|
||||
array: The inverse DFT of the input along the given axes.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"rfft",
|
||||
[](const array& a,
|
||||
const std::optional<int>& n,
|
||||
int axis,
|
||||
StreamOrDevice s) {
|
||||
if (n.has_value()) {
|
||||
return fft::rfft(a, n.value(), axis, s);
|
||||
} else {
|
||||
return fft::rfft(a, axis, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"n"_a = none,
|
||||
"axis"_a = -1,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
One dimensional discrete Fourier Transform on a real input.
|
||||
|
||||
The output has the same shape as the input except along ``axis`` in
|
||||
which case it has size ``n // 2 + 1``.
|
||||
|
||||
Args:
|
||||
a (array): The input array. If the array is complex it will be silently
|
||||
cast to a real type.
|
||||
n (int, optional): Size of the transformed axis. The
|
||||
corresponding axis in the input is truncated or padded with
|
||||
zeros to match ``n``. The default value is ``a.shape[axis]``.
|
||||
axis (int, optional): Axis along which to perform the FFT. The
|
||||
default is ``-1``.
|
||||
|
||||
Returns:
|
||||
array: The DFT of the input along the given axis. The output
|
||||
data type will be complex.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"irfft",
|
||||
[](const array& a,
|
||||
const std::optional<int>& n,
|
||||
int axis,
|
||||
StreamOrDevice s) {
|
||||
if (n.has_value()) {
|
||||
return fft::irfft(a, n.value(), axis, s);
|
||||
} else {
|
||||
return fft::irfft(a, axis, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"n"_a = none,
|
||||
"axis"_a = -1,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
The inverse of :func:`rfft`.
|
||||
|
||||
The output has the same shape as the input except along ``axis`` in
|
||||
which case it has size ``n``.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
n (int, optional): Size of the transformed axis. The
|
||||
corresponding axis in the input is truncated or padded with
|
||||
zeros to match ``n // 2 + 1``. The default value is
|
||||
``a.shape[axis] // 2 + 1``.
|
||||
axis (int, optional): Axis along which to perform the FFT. The
|
||||
default is ``-1``.
|
||||
|
||||
Returns:
|
||||
array: The real array containing the inverse of :func:`rfft`.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"rfft2",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::rfftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::rfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::rfftn(a, n.value(), axes_, s);
|
||||
} else {
|
||||
return fft::rfftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = std::vector<int>{-2, -1},
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Two dimensional real discrete Fourier Transform.
|
||||
|
||||
The output has the same shape as the input except along the dimensions in
|
||||
``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is
|
||||
treated as the real axis and will have size ``s[-1] // 2 + 1``.
|
||||
|
||||
Args:
|
||||
a (array): The input array. If the array is complex it will be silently
|
||||
cast to a real type.
|
||||
s (list(int), optional): Sizes of the transformed axes. The
|
||||
corresponding axes in the input are truncated or padded with
|
||||
zeros to match the sizes in ``s``. The default value is the
|
||||
sizes of ``a`` along ``axes``.
|
||||
axes (list(int), optional): Axes along which to perform the FFT.
|
||||
The default is ``[-2, -1]``.
|
||||
|
||||
Returns:
|
||||
array: The real DFT of the input along the given axes. The output
|
||||
data type will be complex.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"irfft2",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::irfftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::irfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::irfftn(a, n.value(), axes_, s);
|
||||
} else {
|
||||
return fft::irfftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = std::vector<int>{-2, -1},
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
The inverse of :func:`rfft2`.
|
||||
|
||||
Note the input is generally complex. The dimensions of the input
|
||||
specified in ``axes`` are padded or truncated to match the sizes
|
||||
from ``s``. The last axis in ``axes`` is treated as the real axis
|
||||
and will have size ``s[-1] // 2 + 1``.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
s (list(int), optional): Sizes of the transformed axes. The
|
||||
corresponding axes in the input are truncated or padded with
|
||||
zeros to match the sizes in ``s`` except for the last axis
|
||||
which has size ``s[-1] // 2 + 1``. The default value is the
|
||||
sizes of ``a`` along ``axes``.
|
||||
axes (list(int), optional): Axes along which to perform the FFT.
|
||||
The default is ``[-2, -1]``.
|
||||
|
||||
Returns:
|
||||
array: The real array containing the inverse of :func:`rfft2`.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"rfftn",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::rfftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::rfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::rfftn(a, n.value(), axes_, s);
|
||||
} else {
|
||||
return fft::rfftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
n-dimensional real discrete Fourier Transform.
|
||||
|
||||
The output has the same shape as the input except along the dimensions in
|
||||
``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is
|
||||
treated as the real axis and will have size ``s[-1] // 2 + 1``.
|
||||
|
||||
Args:
|
||||
a (array): The input array. If the array is complex it will be silently
|
||||
cast to a real type.
|
||||
s (list(int), optional): Sizes of the transformed axes. The
|
||||
corresponding axes in the input are truncated or padded with
|
||||
zeros to match the sizes in ``s``. The default value is the
|
||||
sizes of ``a`` along ``axes``.
|
||||
axes (list(int), optional): Axes along which to perform the FFT.
|
||||
The default is ``None`` in which case the FFT is over the last
|
||||
``len(s)`` axes or all axes if ``s`` is also ``None``.
|
||||
|
||||
Returns:
|
||||
array: The real DFT of the input along the given axes. The output
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"irfftn",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::irfftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::irfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::irfftn(a, n.value(), axes_, s);
|
||||
} else {
|
||||
return fft::irfftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
The inverse of :func:`rfftn`.
|
||||
|
||||
Note the input is generally complex. The dimensions of the input
|
||||
specified in ``axes`` are padded or truncated to match the sizes
|
||||
from ``s``. The last axis in ``axes`` is treated as the real axis
|
||||
and will have size ``s[-1] // 2 + 1``.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
s (list(int), optional): Sizes of the transformed axes. The
|
||||
corresponding axes in the input are truncated or padded with
|
||||
zeros to match the sizes in ``s``. The default value is the
|
||||
sizes of ``a`` along ``axes``.
|
||||
axes (list(int), optional): Axes along which to perform the FFT.
|
||||
The default is ``None`` in which case the FFT is over the last
|
||||
``len(s)`` axes or all axes if ``s`` is also ``None``.
|
||||
|
||||
Returns:
|
||||
array: The real array containing the inverse of :func:`rfftn`.
|
||||
)pbdoc");
|
||||
}
|
635
python/src/indexing.cpp
Normal file
635
python/src/indexing.cpp
Normal file
@@ -0,0 +1,635 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "python/src/indexing.h"
|
||||
|
||||
#include "mlx/ops.h"
|
||||
|
||||
bool is_none_slice(const py::slice& in_slice) {
|
||||
return (
|
||||
py::getattr(in_slice, "start").is_none() &&
|
||||
py::getattr(in_slice, "stop").is_none() &&
|
||||
py::getattr(in_slice, "step").is_none());
|
||||
}
|
||||
|
||||
int get_slice_int(py::object obj, int default_val) {
|
||||
if (!obj.is_none()) {
|
||||
if (!py::isinstance<py::int_>(obj)) {
|
||||
throw std::invalid_argument("Slice indices must be integers or None.");
|
||||
}
|
||||
return py::cast<int>(py::cast<py::int_>(obj));
|
||||
}
|
||||
return default_val;
|
||||
}
|
||||
|
||||
void get_slice_params(
|
||||
int& starts,
|
||||
int& ends,
|
||||
int& strides,
|
||||
const py::slice& in_slice,
|
||||
int axis_size) {
|
||||
// Following numpy's convention
|
||||
// Assume n is the number of elements in the dimension being sliced.
|
||||
// Then, if i is not given it defaults to 0 for k > 0 and n - 1 for
|
||||
// k < 0 . If j is not given it defaults to n for k > 0 and -n-1 for
|
||||
// k < 0 . If k is not given it defaults to 1
|
||||
|
||||
strides = get_slice_int(py::getattr(in_slice, "step"), 1);
|
||||
starts = get_slice_int(
|
||||
py::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
|
||||
ends = get_slice_int(
|
||||
py::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
|
||||
|
||||
// starts = (starts < 0) ? starts + axis_size : starts;
|
||||
// ends = (ends < 0) ? ends + axis_size : ends;
|
||||
}
|
||||
|
||||
array get_int_index(py::object idx, int axis_size) {
|
||||
int idx_ = py::cast<int>(idx);
|
||||
idx_ = (idx_ < 0) ? idx_ + axis_size : idx_;
|
||||
|
||||
return array(idx_, uint32);
|
||||
}
|
||||
|
||||
bool is_valid_index_type(const py::object& obj) {
|
||||
return py::isinstance<py::slice>(obj) || py::isinstance<py::int_>(obj) ||
|
||||
py::isinstance<array>(obj) || obj.is_none() || py::ellipsis().is(obj);
|
||||
}
|
||||
|
||||
array mlx_get_item_slice(const array& src, const py::slice& in_slice) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
}
|
||||
|
||||
// Return a copy of the array if none slice is request
|
||||
if (is_none_slice(in_slice)) {
|
||||
return src;
|
||||
}
|
||||
|
||||
std::vector<int> starts(src.ndim(), 0);
|
||||
std::vector<int> ends = src.shape();
|
||||
std::vector<int> strides(src.ndim(), 1);
|
||||
|
||||
// Check and update slice params
|
||||
get_slice_params(starts[0], ends[0], strides[0], in_slice, ends[0]);
|
||||
return slice(src, starts, ends, strides);
|
||||
}
|
||||
|
||||
array mlx_get_item_array(const array& src, const array& indices) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
}
|
||||
|
||||
if (indices.dtype() == bool_) {
|
||||
throw std::invalid_argument("boolean indices are not yet supported");
|
||||
}
|
||||
|
||||
// If only one input array is mentioned, we set axis=0 in take
|
||||
// for parity with np
|
||||
return take(src, indices, 0);
|
||||
}
|
||||
|
||||
array mlx_get_item_int(const array& src, const py::int_& idx) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
}
|
||||
|
||||
// If only one input idx is mentioned, we set axis=0 in take
|
||||
// for parity with np
|
||||
return take(src, get_int_index(idx, src.shape(0)), 0);
|
||||
}
|
||||
|
||||
array mlx_gather_nd(
|
||||
array src,
|
||||
const std::vector<py::object>& indices,
|
||||
bool gather_first,
|
||||
int& max_dims) {
|
||||
max_dims = 0;
|
||||
std::vector<array> gather_indices;
|
||||
std::vector<bool> is_slice(indices.size(), false);
|
||||
int num_slices = 0;
|
||||
// gather all the arrays
|
||||
for (int i = 0; i < indices.size(); i++) {
|
||||
auto& idx = indices[i];
|
||||
|
||||
if (py::isinstance<py::slice>(idx)) {
|
||||
int start, end, stride;
|
||||
get_slice_params(start, end, stride, idx, src.shape(i));
|
||||
gather_indices.push_back(arange(start, end, stride, uint32));
|
||||
num_slices++;
|
||||
is_slice[i] = true;
|
||||
} else if (py::isinstance<py::int_>(idx)) {
|
||||
gather_indices.push_back(get_int_index(idx, src.shape(i)));
|
||||
} else if (py::isinstance<array>(idx)) {
|
||||
auto arr = py::cast<array>(idx);
|
||||
max_dims = std::max(static_cast<int>(arr.ndim()), max_dims);
|
||||
gather_indices.push_back(arr);
|
||||
}
|
||||
}
|
||||
|
||||
// reshape them so that the int/array indices are first
|
||||
if (gather_first) {
|
||||
int slice_index = 0;
|
||||
for (int i = 0; i < gather_indices.size(); i++) {
|
||||
if (is_slice[i]) {
|
||||
std::vector<int> index_shape(max_dims + num_slices, 1);
|
||||
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
|
||||
gather_indices[i] = reshape(gather_indices[i], index_shape);
|
||||
slice_index++;
|
||||
} else {
|
||||
std::vector<int> index_shape = gather_indices[i].shape();
|
||||
index_shape.insert(index_shape.end(), num_slices, 1);
|
||||
gather_indices[i] = reshape(gather_indices[i], index_shape);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// reshape them so that the int/array indices are last
|
||||
for (int i = 0; i < gather_indices.size(); i++) {
|
||||
if (i < num_slices) {
|
||||
std::vector<int> index_shape(max_dims + num_slices, 1);
|
||||
index_shape[i] = gather_indices[i].shape(0);
|
||||
gather_indices[i] = reshape(gather_indices[i], index_shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Do the gather
|
||||
std::vector<int> axes(indices.size());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
std::vector<int> slice_sizes = src.shape();
|
||||
std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
|
||||
src = gather(src, gather_indices, axes, slice_sizes);
|
||||
|
||||
// Squeeze the dims
|
||||
std::vector<int> out_shape;
|
||||
out_shape.insert(
|
||||
out_shape.end(),
|
||||
src.shape().begin(),
|
||||
src.shape().begin() + max_dims + num_slices);
|
||||
out_shape.insert(
|
||||
out_shape.end(),
|
||||
src.shape().begin() + max_dims + num_slices + indices.size(),
|
||||
src.shape().end());
|
||||
src = reshape(src, out_shape);
|
||||
|
||||
return src;
|
||||
}
|
||||
|
||||
array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
// No indices make this a noop
|
||||
if (entries.size() == 0) {
|
||||
return src;
|
||||
}
|
||||
|
||||
// The plan is as follows:
|
||||
// 1. Replace the ellipsis with a series of slice(None)
|
||||
// 2. Loop over the indices and calculate the gather indices
|
||||
// 3. Calculate the remaining slices and reshapes
|
||||
|
||||
// Ellipsis handling
|
||||
std::vector<py::object> indices;
|
||||
{
|
||||
int non_none_indices_before = 0;
|
||||
int non_none_indices_after = 0;
|
||||
std::vector<py::object> r_indices;
|
||||
int i = 0;
|
||||
for (; i < entries.size(); i++) {
|
||||
auto idx = entries[i];
|
||||
if (!is_valid_index_type(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
}
|
||||
if (!py::ellipsis().is(idx)) {
|
||||
indices.push_back(idx);
|
||||
non_none_indices_before += !idx.is_none();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (int j = entries.size() - 1; j > i; j--) {
|
||||
auto idx = entries[j];
|
||||
if (!is_valid_index_type(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
}
|
||||
if (py::ellipsis().is(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"An index can only have a single ellipsis (...)");
|
||||
}
|
||||
r_indices.push_back(idx);
|
||||
non_none_indices_after += !idx.is_none();
|
||||
}
|
||||
for (int axis = non_none_indices_before;
|
||||
axis < src.ndim() - non_none_indices_after;
|
||||
axis++) {
|
||||
indices.push_back(py::slice(0, src.shape(axis), 1));
|
||||
}
|
||||
indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend());
|
||||
}
|
||||
|
||||
// Check for the number of indices passed
|
||||
{
|
||||
int cnt = src.ndim();
|
||||
for (auto& idx : indices) {
|
||||
if (!idx.is_none()) {
|
||||
cnt--;
|
||||
}
|
||||
}
|
||||
if (cnt < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
// Gather handling
|
||||
//
|
||||
// Check whether we have arrays or integer indices and delegate to gather_nd
|
||||
// after removing the slices at the end and all Nones.
|
||||
std::vector<py::object> remaining_indices;
|
||||
bool have_array = false;
|
||||
{
|
||||
// First check whether the results of gather are going to be 1st or
|
||||
// normally in between.
|
||||
bool have_non_array = false;
|
||||
bool gather_first = false;
|
||||
for (auto& idx : indices) {
|
||||
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
|
||||
if (have_array && have_non_array) {
|
||||
gather_first = true;
|
||||
break;
|
||||
}
|
||||
have_array = true;
|
||||
} else {
|
||||
have_non_array |= have_array;
|
||||
}
|
||||
}
|
||||
|
||||
if (have_array) {
|
||||
int last_array;
|
||||
// Then find the last array
|
||||
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
|
||||
auto& idx = indices[last_array];
|
||||
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<py::object> gather_indices;
|
||||
for (int i = 0; i <= last_array; i++) {
|
||||
auto& idx = indices[i];
|
||||
if (!idx.is_none()) {
|
||||
gather_indices.push_back(idx);
|
||||
}
|
||||
}
|
||||
int max_dims;
|
||||
src = mlx_gather_nd(src, gather_indices, gather_first, max_dims);
|
||||
|
||||
// Reassemble the indices for the slicing or reshaping if there are any
|
||||
if (gather_first) {
|
||||
for (int i = 0; i < max_dims; i++) {
|
||||
remaining_indices.push_back(
|
||||
py::slice(py::none(), py::none(), py::none()));
|
||||
}
|
||||
for (int i = 0; i < last_array; i++) {
|
||||
auto& idx = indices[i];
|
||||
if (idx.is_none()) {
|
||||
remaining_indices.push_back(indices[i]);
|
||||
} else if (py::isinstance<py::slice>(idx)) {
|
||||
remaining_indices.push_back(
|
||||
py::slice(py::none(), py::none(), py::none()));
|
||||
}
|
||||
}
|
||||
for (int i = last_array + 1; i < indices.size(); i++) {
|
||||
remaining_indices.push_back(indices[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < indices.size(); i++) {
|
||||
auto& idx = indices[i];
|
||||
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
|
||||
break;
|
||||
} else if (idx.is_none()) {
|
||||
remaining_indices.push_back(idx);
|
||||
} else {
|
||||
remaining_indices.push_back(
|
||||
py::slice(py::none(), py::none(), py::none()));
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < max_dims; i++) {
|
||||
remaining_indices.push_back(
|
||||
py::slice(py::none(), py::none(), py::none()));
|
||||
}
|
||||
for (int i = last_array + 1; i < indices.size(); i++) {
|
||||
remaining_indices.push_back(indices[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (have_array && remaining_indices.empty()) {
|
||||
return src;
|
||||
}
|
||||
if (remaining_indices.empty()) {
|
||||
remaining_indices = indices;
|
||||
}
|
||||
|
||||
// Slice handling
|
||||
{
|
||||
std::vector<int> starts(src.ndim(), 0);
|
||||
std::vector<int> ends = src.shape();
|
||||
std::vector<int> strides(src.ndim(), 1);
|
||||
int axis = 0;
|
||||
for (auto& idx : remaining_indices) {
|
||||
if (!idx.is_none()) {
|
||||
get_slice_params(
|
||||
starts[axis], ends[axis], strides[axis], idx, ends[axis]);
|
||||
axis++;
|
||||
}
|
||||
}
|
||||
src = slice(src, starts, ends, strides);
|
||||
}
|
||||
|
||||
// Unsqueeze handling
|
||||
if (remaining_indices.size() > src.ndim()) {
|
||||
std::vector<int> out_shape;
|
||||
int axis = 0;
|
||||
for (auto& idx : remaining_indices) {
|
||||
if (idx.is_none()) {
|
||||
out_shape.push_back(1);
|
||||
} else {
|
||||
out_shape.push_back(src.shape(axis++));
|
||||
}
|
||||
}
|
||||
src = reshape(src, out_shape);
|
||||
}
|
||||
|
||||
return src;
|
||||
}
|
||||
|
||||
array mlx_get_item(const array& src, const py::object& obj) {
|
||||
if (py::isinstance<py::slice>(obj)) {
|
||||
return mlx_get_item_slice(src, obj);
|
||||
} else if (py::isinstance<array>(obj)) {
|
||||
return mlx_get_item_array(src, py::cast<array>(obj));
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
return mlx_get_item_int(src, obj);
|
||||
} else if (py::isinstance<py::tuple>(obj)) {
|
||||
return mlx_get_item_nd(src, obj);
|
||||
} else if (obj.is_none()) {
|
||||
std::vector<int> s(1, 1);
|
||||
s.insert(s.end(), src.shape().begin(), src.shape().end());
|
||||
return reshape(src, s);
|
||||
}
|
||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||
}
|
||||
|
||||
array mlx_set_item_int(
|
||||
const array& src,
|
||||
const py::int_& idx,
|
||||
const array& update) {
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
}
|
||||
|
||||
// Remove any leading singleton dimensions from the update
|
||||
// and then broadcast update to shape of src[0, ...]
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++)
|
||||
;
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
auto shape = src.shape();
|
||||
shape[0] = 1;
|
||||
return scatter(
|
||||
src,
|
||||
get_int_index(idx, src.shape(0)),
|
||||
broadcast_to(reshape(update, up_shape), shape),
|
||||
0);
|
||||
}
|
||||
|
||||
array mlx_set_item_array(
|
||||
const array& src,
|
||||
const array& indices,
|
||||
const array& update) {
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
}
|
||||
|
||||
// Remove any leading singleton dimensions from the update
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++)
|
||||
;
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
auto up = reshape(update, up_shape);
|
||||
|
||||
// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
|
||||
up_shape = indices.shape();
|
||||
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
|
||||
up = broadcast_to(up, up_shape);
|
||||
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
|
||||
up = reshape(up, up_shape);
|
||||
|
||||
return scatter(src, indices, up, 0);
|
||||
}
|
||||
|
||||
array mlx_set_item_slice(
|
||||
const array& src,
|
||||
const py::slice& in_slice,
|
||||
const array& update) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
}
|
||||
|
||||
// If none slice is requested broadcast the update
|
||||
// to the src size and return it.
|
||||
if (is_none_slice(in_slice)) {
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++)
|
||||
;
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
return broadcast_to(reshape(update, up_shape), src.shape());
|
||||
}
|
||||
|
||||
int start = 0;
|
||||
int end = src.shape(0);
|
||||
int stride = 1;
|
||||
|
||||
// Check and update slice params
|
||||
get_slice_params(start, end, stride, in_slice, end);
|
||||
|
||||
return mlx_set_item_array(src, arange(start, end, stride, uint32), update);
|
||||
}
|
||||
|
||||
array mlx_set_item_nd(
|
||||
const array& src,
|
||||
const py::tuple& entries,
|
||||
const array& update) {
|
||||
std::vector<py::object> indices;
|
||||
int non_none_indices = 0;
|
||||
|
||||
// Expand ellipses into a series of ':' slices
|
||||
{
|
||||
int non_none_indices_before = 0;
|
||||
int non_none_indices_after = 0;
|
||||
bool has_ellipsis = false;
|
||||
int indices_before = 0;
|
||||
for (int i = 0; i < entries.size(); ++i) {
|
||||
auto idx = entries[i];
|
||||
if (!is_valid_index_type(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
} else if (!py::ellipsis().is(idx)) {
|
||||
if (!has_ellipsis) {
|
||||
indices_before++;
|
||||
non_none_indices_before += !idx.is_none();
|
||||
} else {
|
||||
non_none_indices_after += !idx.is_none();
|
||||
}
|
||||
indices.push_back(idx);
|
||||
} else if (has_ellipsis) {
|
||||
throw std::invalid_argument(
|
||||
"An index can only have a single ellipsis (...)");
|
||||
} else {
|
||||
has_ellipsis = true;
|
||||
}
|
||||
}
|
||||
if (has_ellipsis) {
|
||||
for (int axis = non_none_indices_before;
|
||||
axis < src.ndim() - non_none_indices_after;
|
||||
axis++) {
|
||||
indices.insert(
|
||||
indices.begin() + indices_before, py::slice(0, src.shape(axis), 1));
|
||||
}
|
||||
non_none_indices = src.ndim();
|
||||
} else {
|
||||
non_none_indices = non_none_indices_before + non_none_indices_after;
|
||||
}
|
||||
}
|
||||
|
||||
if (non_none_indices > src.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Remove leading singletons dimensions from the update
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++) {
|
||||
};
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
auto up = reshape(update, up_shape);
|
||||
|
||||
// If no non-None indices return the broadcasted update
|
||||
if (non_none_indices == 0) {
|
||||
return broadcast_to(up, src.shape());
|
||||
}
|
||||
|
||||
unsigned long max_dim = 0;
|
||||
bool arrays_first = false;
|
||||
int num_slices = 0;
|
||||
int num_arrays = 0;
|
||||
{
|
||||
bool have_array = false;
|
||||
bool have_non_array = false;
|
||||
for (auto& idx : indices) {
|
||||
if (py::isinstance<py::slice>(idx) || idx.is_none()) {
|
||||
have_non_array = have_array;
|
||||
num_slices++;
|
||||
} else if (py::isinstance<array>(idx)) {
|
||||
have_array = true;
|
||||
if (have_array && have_non_array) {
|
||||
arrays_first = true;
|
||||
}
|
||||
max_dim = std::max(py::cast<array>(idx).ndim(), max_dim);
|
||||
num_arrays++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> arr_indices;
|
||||
int slice_num = 0;
|
||||
int array_num = 0;
|
||||
int ax = 0;
|
||||
for (int i = 0; i < indices.size(); ++i) {
|
||||
auto& pyidx = indices[i];
|
||||
if (py::isinstance<py::slice>(pyidx)) {
|
||||
int start, end, stride;
|
||||
get_slice_params(start, end, stride, pyidx, src.shape(ax++));
|
||||
auto idx = arange(start, end, stride, uint32);
|
||||
std::vector<int> idx_shape(max_dim + num_slices, 1);
|
||||
auto loc = slice_num + (arrays_first ? max_dim : 0);
|
||||
slice_num++;
|
||||
idx_shape[loc] = idx.size();
|
||||
arr_indices.push_back(reshape(idx, idx_shape));
|
||||
} else if (py::isinstance<py::int_>(pyidx)) {
|
||||
arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));
|
||||
} else if (pyidx.is_none()) {
|
||||
slice_num++;
|
||||
} else if (py::isinstance<array>(pyidx)) {
|
||||
ax++;
|
||||
auto idx = py::cast<array>(pyidx);
|
||||
std::vector<int> idx_shape;
|
||||
if (!arrays_first) {
|
||||
idx_shape.insert(idx_shape.end(), slice_num, 1);
|
||||
}
|
||||
idx_shape.insert(idx_shape.end(), max_dim - idx.ndim(), 1);
|
||||
idx_shape.insert(idx_shape.end(), idx.shape().begin(), idx.shape().end());
|
||||
idx_shape.insert(
|
||||
idx_shape.end(), num_slices - (arrays_first ? 0 : slice_num), 1);
|
||||
arr_indices.push_back(reshape(idx, idx_shape));
|
||||
if (!arrays_first && ++array_num == num_arrays) {
|
||||
slice_num += max_dim;
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
}
|
||||
}
|
||||
|
||||
arr_indices = broadcast_arrays(arr_indices);
|
||||
up_shape = arr_indices[0].shape();
|
||||
up_shape.insert(
|
||||
up_shape.end(),
|
||||
src.shape().begin() + non_none_indices,
|
||||
src.shape().end());
|
||||
up = broadcast_to(up, up_shape);
|
||||
up_shape.insert(
|
||||
up_shape.begin() + arr_indices[0].ndim(), non_none_indices, 1);
|
||||
up = reshape(up, up_shape);
|
||||
|
||||
std::vector<int> axes(arr_indices.size(), 0);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
return scatter(src, arr_indices, up, axes);
|
||||
}
|
||||
|
||||
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
|
||||
auto vals = to_array(v, src.dtype());
|
||||
auto impl = [&src, &obj, &vals]() {
|
||||
if (py::isinstance<py::slice>(obj)) {
|
||||
return mlx_set_item_slice(src, obj, vals);
|
||||
} else if (py::isinstance<array>(obj)) {
|
||||
return mlx_set_item_array(src, py::cast<array>(obj), vals);
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
return mlx_set_item_int(src, obj, vals);
|
||||
} else if (py::isinstance<py::tuple>(obj)) {
|
||||
return mlx_set_item_nd(src, obj, vals);
|
||||
} else if (obj.is_none()) {
|
||||
return broadcast_to(vals, src.shape());
|
||||
}
|
||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||
};
|
||||
auto out = impl();
|
||||
src.overwrite_descriptor(out);
|
||||
}
|
12
python/src/indexing.h
Normal file
12
python/src/indexing.h
Normal file
@@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mlx::core;
|
||||
|
||||
array mlx_get_item(const array& src, const py::object& obj);
|
||||
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v);
|
290
python/src/load.cpp
Normal file
290
python/src/load.cpp
Normal file
@@ -0,0 +1,290 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/load.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/load.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Helpers
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
bool is_istream_object(const py::object& file) {
|
||||
return py::hasattr(file, "read") && py::hasattr(file, "seek") &&
|
||||
py::hasattr(file, "tell") && py::hasattr(file, "closed");
|
||||
}
|
||||
|
||||
bool is_ostream_object(const py::object& file) {
|
||||
return py::hasattr(file, "write") && py::hasattr(file, "seek") &&
|
||||
py::hasattr(file, "tell") && py::hasattr(file, "closed");
|
||||
}
|
||||
|
||||
bool is_zip_file(const py::module_& zipfile, const py::object& file) {
|
||||
if (is_istream_object(file)) {
|
||||
auto st_pos = file.attr("tell")();
|
||||
bool r = (zipfile.attr("is_zipfile")(file)).cast<bool>();
|
||||
file.attr("seek")(st_pos, 0);
|
||||
return r;
|
||||
}
|
||||
return zipfile.attr("is_zipfile")(file).cast<bool>();
|
||||
}
|
||||
|
||||
class ZipFileWrapper {
|
||||
public:
|
||||
ZipFileWrapper(
|
||||
const py::module_& zipfile,
|
||||
const py::object& file,
|
||||
char mode = 'r',
|
||||
int compression = 0)
|
||||
: zipfile_module_(zipfile),
|
||||
zipfile_object_(zipfile.attr("ZipFile")(
|
||||
file,
|
||||
"mode"_a = mode,
|
||||
"compression"_a = compression,
|
||||
"allowZip64"_a = true)),
|
||||
files_list_(zipfile_object_.attr("namelist")()),
|
||||
open_func_(zipfile_object_.attr("open")),
|
||||
read_func_(zipfile_object_.attr("read")),
|
||||
close_func_(zipfile_object_.attr("close")) {}
|
||||
|
||||
std::vector<std::string> namelist() const {
|
||||
return files_list_.cast<std::vector<std::string>>();
|
||||
}
|
||||
|
||||
py::object open(const std::string& key, char mode = 'r') {
|
||||
// Following numpy :
|
||||
// https://github.com/numpy/numpy/blob/db4f43983cb938f12c311e1f5b7165e270c393b4/numpy/lib/npyio.py#L742C36-L742C47
|
||||
if (mode == 'w') {
|
||||
return open_func_(key, "mode"_a = mode, "force_zip64"_a = true);
|
||||
}
|
||||
return open_func_(key, "mode"_a = mode);
|
||||
}
|
||||
|
||||
private:
|
||||
py::module_ zipfile_module_;
|
||||
py::object zipfile_object_;
|
||||
py::list files_list_;
|
||||
py::object open_func_;
|
||||
py::object read_func_;
|
||||
py::object close_func_;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Loading
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
class PyFileReader : public io::Reader {
|
||||
public:
|
||||
PyFileReader(py::object file)
|
||||
: pyistream_(file),
|
||||
readinto_func_(file.attr("readinto")),
|
||||
seek_func_(file.attr("seek")),
|
||||
tell_func_(file.attr("tell")) {}
|
||||
|
||||
bool is_open() const override {
|
||||
return !pyistream_.attr("closed").cast<bool>();
|
||||
}
|
||||
|
||||
bool good() const override {
|
||||
return !pyistream_.is_none();
|
||||
}
|
||||
|
||||
size_t tell() const override {
|
||||
return tell_func_().cast<size_t>();
|
||||
}
|
||||
|
||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||
override {
|
||||
seek_func_(off, (int)way);
|
||||
}
|
||||
|
||||
void read(char* data, size_t n) override {
|
||||
py::object bytes_read =
|
||||
readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
|
||||
if (bytes_read.is_none() || py::cast<size_t>(bytes_read) < n) {
|
||||
throw std::runtime_error("[load] Failed to read from python stream");
|
||||
}
|
||||
}
|
||||
|
||||
std::string label() const override {
|
||||
return "python file object";
|
||||
}
|
||||
|
||||
private:
|
||||
py::object pyistream_;
|
||||
py::object readinto_func_;
|
||||
py::object seek_func_;
|
||||
py::object tell_func_;
|
||||
};
|
||||
|
||||
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
|
||||
py::module_ zipfile = py::module_::import("zipfile");
|
||||
|
||||
// Assume .npz file if it is zipped
|
||||
if (is_zip_file(zipfile, file)) {
|
||||
// Output dictionary filename in zip -> loaded array
|
||||
std::unordered_map<std::string, array> array_dict;
|
||||
|
||||
// Create python ZipFile object
|
||||
ZipFileWrapper zipfile_object(zipfile, file);
|
||||
for (const std::string& st : zipfile_object.namelist()) {
|
||||
// Open zip file as a python file stream
|
||||
py::object sub_file = zipfile_object.open(st);
|
||||
|
||||
// Create array from python fille stream
|
||||
auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
|
||||
|
||||
// Remove .npy from file if it is there
|
||||
auto key = st;
|
||||
if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy")
|
||||
key = st.substr(0, st.length() - 4);
|
||||
|
||||
// Add array to dict
|
||||
array_dict.insert({key, arr});
|
||||
}
|
||||
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
for (auto& [key, arr] : array_dict) {
|
||||
arr.eval();
|
||||
}
|
||||
|
||||
return {array_dict};
|
||||
} else if (py::isinstance<py::str>(file)) { // Assume .npy file path string
|
||||
return {load(py::cast<std::string>(file), s)};
|
||||
} else if (is_istream_object(file)) {
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
auto arr = load(std::make_shared<PyFileReader>(file), s);
|
||||
arr.eval();
|
||||
return {arr};
|
||||
}
|
||||
|
||||
throw std::invalid_argument(
|
||||
"[load] Input must be a file-like object, string, or pathlib.Path");
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Saving
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
class PyFileWriter : public io::Writer {
|
||||
public:
|
||||
PyFileWriter(py::object file)
|
||||
: pyostream_(file),
|
||||
write_func_(file.attr("write")),
|
||||
seek_func_(file.attr("seek")),
|
||||
tell_func_(file.attr("tell")) {}
|
||||
|
||||
bool is_open() const override {
|
||||
return !pyostream_.attr("closed").cast<bool>();
|
||||
}
|
||||
|
||||
bool good() const override {
|
||||
return !pyostream_.is_none();
|
||||
}
|
||||
|
||||
size_t tell() const override {
|
||||
return tell_func_().cast<size_t>();
|
||||
}
|
||||
|
||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||
override {
|
||||
seek_func_(off, (int)way);
|
||||
}
|
||||
|
||||
void write(const char* data, size_t n) override {
|
||||
py::object bytes_written =
|
||||
write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
|
||||
if (bytes_written.is_none() || py::cast<size_t>(bytes_written) < n) {
|
||||
throw std::runtime_error("[load] Failed to write to python stream");
|
||||
}
|
||||
}
|
||||
|
||||
std::string label() const override {
|
||||
return "python file object";
|
||||
}
|
||||
|
||||
private:
|
||||
py::object pyostream_;
|
||||
py::object write_func_;
|
||||
py::object seek_func_;
|
||||
py::object tell_func_;
|
||||
};
|
||||
|
||||
void mlx_save_helper(py::object file, array a, bool retain_graph) {
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
save(py::cast<std::string>(file), a, retain_graph);
|
||||
return;
|
||||
} else if (is_ostream_object(file)) {
|
||||
save(std::make_shared<PyFileWriter>(file), a, retain_graph);
|
||||
return;
|
||||
}
|
||||
|
||||
throw std::invalid_argument(
|
||||
"[save] Input must be a file-like object, string, or pathlib.Path");
|
||||
}
|
||||
|
||||
void mlx_savez_helper(
|
||||
py::object file_,
|
||||
py::args args,
|
||||
const py::kwargs& kwargs,
|
||||
bool compressed) {
|
||||
// Add .npz to the end of the filename if not already there
|
||||
py::object file = file_;
|
||||
|
||||
if (py::isinstance<py::str>(file_)) {
|
||||
std::string fname = file_.cast<std::string>();
|
||||
|
||||
// Add .npz to file name if it is not there
|
||||
if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz")
|
||||
fname += ".npz";
|
||||
|
||||
file = py::str(fname);
|
||||
}
|
||||
|
||||
// Collect args and kwargs
|
||||
auto arrays_dict = kwargs.cast<std::unordered_map<std::string, array>>();
|
||||
auto arrays_list = args.cast<std::vector<array>>();
|
||||
|
||||
for (int i = 0; i < arrays_list.size(); i++) {
|
||||
std::string arr_name = "arr_" + std::to_string(i);
|
||||
|
||||
if (arrays_dict.count(arr_name) > 0) {
|
||||
throw std::invalid_argument(
|
||||
"[savez] Cannot use un-named variables and keyword " + arr_name);
|
||||
}
|
||||
|
||||
arrays_dict.insert({arr_name, arrays_list[i]});
|
||||
}
|
||||
|
||||
// Create python ZipFile object depending on compression
|
||||
py::module_ zipfile = py::module_::import("zipfile");
|
||||
int compression = compressed ? zipfile.attr("ZIP_DEFLATED").cast<int>()
|
||||
: zipfile.attr("ZIP_STORED").cast<int>();
|
||||
char mode = 'w';
|
||||
ZipFileWrapper zipfile_object(zipfile, file, mode, compression);
|
||||
|
||||
// Save each array
|
||||
for (auto [k, a] : arrays_dict) {
|
||||
std::string fname = k + ".npy";
|
||||
auto py_ostream = zipfile_object.open(fname, 'w');
|
||||
save(std::make_shared<PyFileWriter>(py_ostream), a);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
2422
python/src/ops.cpp
Normal file
2422
python/src/ops.cpp
Normal file
File diff suppressed because it is too large
Load Diff
289
python/src/random.cpp
Normal file
289
python/src/random.cpp
Normal file
@@ -0,0 +1,289 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "python/src/utils.h"
|
||||
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/random.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::random;
|
||||
|
||||
void init_random(py::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"random",
|
||||
"mlx.core.random: functionality related to random number generation");
|
||||
m.def(
|
||||
"seed",
|
||||
&seed,
|
||||
"seed"_a,
|
||||
R"pbdoc(
|
||||
Seed the global PRNG.
|
||||
|
||||
Args:
|
||||
seed (int): Seed for the global PRNG.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"key",
|
||||
&key,
|
||||
"seed"_a,
|
||||
R"pbdoc(
|
||||
Get a PRNG key from a seed.
|
||||
|
||||
Args:
|
||||
seed (int): Seed for the PRNG.
|
||||
|
||||
Returns:
|
||||
array: The PRNG key array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"split",
|
||||
py::overload_cast<const array&, int, StreamOrDevice>(&random::split),
|
||||
"key"_a,
|
||||
"num"_a = 2,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Split a PRNG key into sub keys.
|
||||
|
||||
Args:
|
||||
key (array): Input key to split.
|
||||
num (int, optional): Number of sub keys. Default is 2.
|
||||
|
||||
Returns:
|
||||
array: The array of sub keys with ``num`` as its first dimension.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"uniform",
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
Dtype type,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
return uniform(to_array(low), to_array(high), shape, type, key, s);
|
||||
},
|
||||
"low"_a = 0,
|
||||
"high"_a = 1,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = float32,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Generate uniformly distributed random numbers.
|
||||
|
||||
The values are sampled uniformly in the half-open interval ``[low, high)``.
|
||||
The lower and upper bound can be scalars or arrays and must be
|
||||
broadcastable to ``shape``.
|
||||
|
||||
Args:
|
||||
low (scalar or array, optional): Lower bound of the distribution. Default is ``0``.
|
||||
high (scalar or array, optional): Upper bound of the distribution. Default is ``1``.
|
||||
shape (list(int), optional): Shape of the output. Default is ``()``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
dtype (Dtype, optional): Type of the output. Default is ``float32``.
|
||||
|
||||
Returns:
|
||||
array: The output array random values.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"normal",
|
||||
[](const std::vector<int>& shape,
|
||||
Dtype type,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) { return normal(shape, type, key, s); },
|
||||
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = float32,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Generate normally distributed random numbers.
|
||||
|
||||
Args:
|
||||
shape (list(int), optional): Shape of the output. Default is ``()``.
|
||||
dtype (Dtype, optional): Type of the output. Default is ``float32``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
|
||||
Returns:
|
||||
array: The output array of random values.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"randint",
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
Dtype type,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
return randint(to_array(low), to_array(high), shape, type, key, s);
|
||||
},
|
||||
"low"_a,
|
||||
"high"_a,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = int32,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Generate random integers from the given interval.
|
||||
|
||||
The values are sampled with equal probability from the integers in
|
||||
half-open interval ``[low, high)``. The lower and upper bound can be
|
||||
scalars or arrays and must be roadcastable to ``shape``.
|
||||
|
||||
Args:
|
||||
low (scalar or array): Lower bound of the interval.
|
||||
high (scalar or array): Upper bound of the interval.
|
||||
shape (list(int), optional): Shape of the output. Defaults to ``()``.
|
||||
dtype (Dtype, optional): Type of the output. Defaults to ``int32``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
|
||||
Returns:
|
||||
array: The array of random integers.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"bernoulli",
|
||||
[](const ScalarOrArray& p_,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
auto p = to_array(p_);
|
||||
if (shape.has_value()) {
|
||||
return bernoulli(p, shape.value(), key, s);
|
||||
} else {
|
||||
return bernoulli(p, key, s);
|
||||
}
|
||||
},
|
||||
"p"_a = 0.5,
|
||||
"shape"_a = none,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Generate Bernoulli random values.
|
||||
|
||||
The values are sampled from the bernoulli distribution with parameter
|
||||
``p``. The parameter ``p`` can be a :obj:`float` or :obj:`array` and
|
||||
must be broadcastable to ``shape``.
|
||||
|
||||
Args:
|
||||
p (float or array, optional): Parameter of the Bernoulli
|
||||
distribution. Default is 0.5.
|
||||
shape (list(int), optional): Shape of the output. The default
|
||||
shape is ``p.shape``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
|
||||
Returns:
|
||||
array: The array of random integers.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"truncated_normal",
|
||||
[](const ScalarOrArray& lower_,
|
||||
const ScalarOrArray& upper_,
|
||||
const std::optional<std::vector<int>> shape_,
|
||||
Dtype dtype,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
auto lower = to_array(lower_);
|
||||
auto upper = to_array(upper_);
|
||||
if (shape_.has_value()) {
|
||||
return truncated_normal(lower, upper, shape_.value(), dtype, key, s);
|
||||
} else {
|
||||
return truncated_normal(lower, upper, dtype, key, s);
|
||||
}
|
||||
},
|
||||
"lower"_a,
|
||||
"upper"_a,
|
||||
"shape"_a = none,
|
||||
"dtype"_a = float32,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Generate values from a truncated normal distribution.
|
||||
|
||||
The values are sampled from the truncated normal distribution
|
||||
on the domain ``(lower, upper)``. The bounds ``lower`` and ``upper``
|
||||
can be scalars or arrays and must be broadcastable to ``shape``.
|
||||
|
||||
Args:
|
||||
lower (scalar or array): Lower bound of the domain.
|
||||
upper (scalar or array): Upper bound of the domain.
|
||||
shape (list(int), optional): The shape of the output.
|
||||
Default is ``()``.
|
||||
dtype (Dtype, optinoal): The data type of the output.
|
||||
Default is ``float32``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
|
||||
Returns:
|
||||
array: The output array of random values.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"gumbel",
|
||||
&gumbel,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = float32,
|
||||
"stream"_a = none,
|
||||
"key"_a = none,
|
||||
R"pbdoc(
|
||||
Sample from the standard Gumbel distribution.
|
||||
|
||||
The values are sampled from a standard Gumbel distribution
|
||||
which CDF ``exp(-exp(-x))``.
|
||||
|
||||
Args:
|
||||
shape (list(int)): The shape of the output.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
|
||||
Returns:
|
||||
array: The :class:`array` with shape ``shape`` and
|
||||
distributed according to the Gumbel distribution
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"categorical",
|
||||
[](const array& logits,
|
||||
int axis,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<int> num_samples,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
if (shape.has_value() && num_samples.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[categorical] At most one of shape or num_samples can be specified.");
|
||||
} else if (shape.has_value()) {
|
||||
return categorical(logits, axis, shape.value(), key, s);
|
||||
} else if (num_samples.has_value()) {
|
||||
return categorical(logits, axis, num_samples.value(), key, s);
|
||||
} else {
|
||||
return categorical(logits, axis, key, s);
|
||||
}
|
||||
},
|
||||
"logits"_a,
|
||||
"axis"_a = -1,
|
||||
"shape"_a = none,
|
||||
"num_samples"_a = none,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Sample from a categorical distribution.
|
||||
|
||||
The values are sampled from the categorical distribution specified by
|
||||
the unnormalized values in ``logits``. Note, at most one of ``shape``
|
||||
or ``num_samples`` can be specified. If both are ``None``, the output
|
||||
has the same shape as ``logits`` with the ``axis`` dimension removed.
|
||||
|
||||
Args:
|
||||
logits (array): The *unnormalized* categorical distribution(s).
|
||||
axis (int, optional): The axis which specifies the distribution.
|
||||
Default is ``-1``.
|
||||
shape (list(int), optional): The shape of the output. This must
|
||||
be broadcast compatable with ``logits.shape`` with the ``axis``
|
||||
dimension removed. Default: ``None``
|
||||
num_samples (int, optional): The number of samples to draw from each
|
||||
of the categorical distributions in ``logits``. The output will have
|
||||
``num_samples`` in the last dimension. Default: ``None``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
|
||||
Returns:
|
||||
array: The ``shape``-sized output array with type ``uint32``.
|
||||
)pbdoc");
|
||||
}
|
71
python/src/utils.h
Normal file
71
python/src/utils.h
Normal file
@@ -0,0 +1,71 @@
|
||||
#pragma once
|
||||
#include <numeric>
|
||||
#include <variant>
|
||||
|
||||
#include <pybind11/complex.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
||||
using ScalarOrArray =
|
||||
std::variant<py::bool_, py::int_, py::float_, std::complex<float>, array>;
|
||||
static constexpr std::monostate none{};
|
||||
|
||||
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||
std::vector<int> axes;
|
||||
if (std::holds_alternative<std::monostate>(v)) {
|
||||
axes.resize(dims);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
} else if (auto pv = std::get_if<int>(&v); pv) {
|
||||
axes.push_back(*pv);
|
||||
} else {
|
||||
axes = std::get<std::vector<int>>(v);
|
||||
}
|
||||
return axes;
|
||||
}
|
||||
|
||||
inline array to_array(
|
||||
const ScalarOrArray& v,
|
||||
std::optional<Dtype> dtype = std::nullopt) {
|
||||
if (auto pv = std::get_if<py::bool_>(&v); pv) {
|
||||
return array(py::cast<bool>(*pv), dtype.value_or(bool_));
|
||||
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
|
||||
auto out_t = dtype.value_or(int32);
|
||||
// bool_ is an exception and is always promoted
|
||||
return array(py::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
|
||||
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
|
||||
auto out_t = dtype.value_or(float32);
|
||||
return array(
|
||||
py::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32);
|
||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||
return array(static_cast<complex64_t>(*pv), complex64);
|
||||
} else {
|
||||
return std::get<array>(v);
|
||||
}
|
||||
}
|
||||
|
||||
inline std::pair<array, array> to_arrays(
|
||||
const ScalarOrArray& a,
|
||||
const ScalarOrArray& b) {
|
||||
// Four cases:
|
||||
// - If both a and b are arrays leave their types alone
|
||||
// - If a is an array but b is not, treat b as a weak python type
|
||||
// - If b is an array but a is not, treat a as a weak python type
|
||||
// - If neither is an array convert to arrays but leave their types alone
|
||||
if (auto pa = std::get_if<array>(&a); pa) {
|
||||
if (auto pb = std::get_if<array>(&b); pb) {
|
||||
return {*pa, *pb};
|
||||
}
|
||||
return {*pa, to_array(b, pa->dtype())};
|
||||
} else if (auto pb = std::get_if<array>(&b); pb) {
|
||||
return {to_array(a, pb->dtype()), *pb};
|
||||
} else {
|
||||
return {to_array(a), to_array(b)};
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user