mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-18 23:21:16 +08:00
Adds device context manager (#679)
This commit is contained in:
parent
ccf1645995
commit
35431a4ac8
@ -10,7 +10,7 @@ MLX was developed with contributions from the following individuals:
|
|||||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
|
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
|
||||||
- Juarez Bochi: Fixed bug in cross attention.
|
- Juarez Bochi: Fixed bug in cross attention.
|
||||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
|
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support
|
||||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
|
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
|
||||||
|
|
||||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||||
|
@ -26,6 +26,7 @@ extensions = [
|
|||||||
|
|
||||||
python_use_unqualified_type_names = True
|
python_use_unqualified_type_names = True
|
||||||
autosummary_generate = True
|
autosummary_generate = True
|
||||||
|
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
|
||||||
|
|
||||||
intersphinx_mapping = {
|
intersphinx_mapping = {
|
||||||
"https://docs.python.org/3": None,
|
"https://docs.python.org/3": None,
|
||||||
|
@ -9,9 +9,10 @@ Devices and Streams
|
|||||||
:toctree: _autosummary
|
:toctree: _autosummary
|
||||||
|
|
||||||
Device
|
Device
|
||||||
|
Stream
|
||||||
default_device
|
default_device
|
||||||
set_default_device
|
set_default_device
|
||||||
Stream
|
|
||||||
default_stream
|
default_stream
|
||||||
new_stream
|
new_stream
|
||||||
set_default_stream
|
set_default_stream
|
||||||
|
stream
|
||||||
|
10
mlx/ops.cpp
10
mlx/ops.cpp
@ -59,16 +59,6 @@ Dtype at_least_float(const Dtype& d) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Stream to_stream(StreamOrDevice s) {
|
|
||||||
if (std::holds_alternative<std::monostate>(s)) {
|
|
||||||
return default_stream(default_device());
|
|
||||||
} else if (std::holds_alternative<Device>(s)) {
|
|
||||||
return default_stream(std::get<Device>(s));
|
|
||||||
} else {
|
|
||||||
return std::get<Stream>(s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
array arange(
|
array arange(
|
||||||
double start,
|
double start,
|
||||||
double stop,
|
double stop,
|
||||||
|
@ -3,18 +3,14 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <variant>
|
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/device.h"
|
#include "mlx/device.h"
|
||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
|
|
||||||
|
|
||||||
Stream to_stream(StreamOrDevice s);
|
|
||||||
|
|
||||||
/** Creation operations */
|
/** Creation operations */
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -7,6 +7,16 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
Stream to_stream(StreamOrDevice s) {
|
||||||
|
if (std::holds_alternative<std::monostate>(s)) {
|
||||||
|
return default_stream(default_device());
|
||||||
|
} else if (std::holds_alternative<Device>(s)) {
|
||||||
|
return default_stream(std::get<Device>(s));
|
||||||
|
} else {
|
||||||
|
return std::get<Stream>(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void PrintFormatter::print(std::ostream& os, bool val) {
|
void PrintFormatter::print(std::ostream& os, bool val) {
|
||||||
if (capitalize_bool) {
|
if (capitalize_bool) {
|
||||||
os << (val ? "True" : "False");
|
os << (val ? "True" : "False");
|
||||||
|
26
mlx/utils.h
26
mlx/utils.h
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <variant>
|
||||||
|
|
||||||
#include "array.h"
|
#include "array.h"
|
||||||
#include "device.h"
|
#include "device.h"
|
||||||
#include "dtype.h"
|
#include "dtype.h"
|
||||||
@ -9,6 +11,30 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
|
||||||
|
Stream to_stream(StreamOrDevice s);
|
||||||
|
|
||||||
|
struct StreamContext {
|
||||||
|
public:
|
||||||
|
StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) {
|
||||||
|
if (std::holds_alternative<std::monostate>(s)) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[StreamContext] Invalid argument, please specify a stream or device.");
|
||||||
|
}
|
||||||
|
auto _s = to_stream(s);
|
||||||
|
set_default_device(_s.device);
|
||||||
|
set_default_stream(_s);
|
||||||
|
}
|
||||||
|
|
||||||
|
~StreamContext() {
|
||||||
|
set_default_device(_stream.device);
|
||||||
|
set_default_stream(_stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Stream _stream;
|
||||||
|
};
|
||||||
|
|
||||||
struct PrintFormatter {
|
struct PrintFormatter {
|
||||||
inline void print(std::ostream& os, bool val);
|
inline void print(std::ostream& os, bool val);
|
||||||
inline void print(std::ostream& os, int16_t val);
|
inline void print(std::ostream& os, int16_t val);
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@ pybind11_add_module(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
|
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
|
||||||
|
@ -12,7 +12,8 @@ using namespace py::literals;
|
|||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
|
|
||||||
void init_device(py::module_& m) {
|
void init_device(py::module_& m) {
|
||||||
auto device_class = py::class_<Device>(m, "Device");
|
auto device_class = py::class_<Device>(
|
||||||
|
m, "Device", R"pbdoc(A device to run operations on.)pbdoc");
|
||||||
py::enum_<Device::DeviceType>(m, "DeviceType")
|
py::enum_<Device::DeviceType>(m, "DeviceType")
|
||||||
.value("cpu", Device::DeviceType::cpu)
|
.value("cpu", Device::DeviceType::cpu)
|
||||||
.value("gpu", Device::DeviceType::gpu)
|
.value("gpu", Device::DeviceType::gpu)
|
||||||
@ -39,6 +40,13 @@ void init_device(py::module_& m) {
|
|||||||
|
|
||||||
py::implicitly_convertible<Device::DeviceType, Device>();
|
py::implicitly_convertible<Device::DeviceType, Device>();
|
||||||
|
|
||||||
m.def("default_device", &default_device);
|
m.def(
|
||||||
m.def("set_default_device", &set_default_device, "device"_a);
|
"default_device",
|
||||||
|
&default_device,
|
||||||
|
R"pbdoc(Get the default device.)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"set_default_device",
|
||||||
|
&set_default_device,
|
||||||
|
"device"_a,
|
||||||
|
R"pbdoc(Set the default device.)pbdoc");
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,7 @@ void init_fft(py::module_&);
|
|||||||
void init_linalg(py::module_&);
|
void init_linalg(py::module_&);
|
||||||
void init_constants(py::module_&);
|
void init_constants(py::module_&);
|
||||||
void init_extensions(py::module_&);
|
void init_extensions(py::module_&);
|
||||||
|
void init_utils(py::module_&);
|
||||||
|
|
||||||
PYBIND11_MODULE(core, m) {
|
PYBIND11_MODULE(core, m) {
|
||||||
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
||||||
@ -35,5 +36,7 @@ PYBIND11_MODULE(core, m) {
|
|||||||
init_linalg(m);
|
init_linalg(m);
|
||||||
init_constants(m);
|
init_constants(m);
|
||||||
init_extensions(m);
|
init_extensions(m);
|
||||||
|
init_utils(m);
|
||||||
|
|
||||||
m.attr("__version__") = TOSTRING(_VERSION_);
|
m.attr("__version__") = TOSTRING(_VERSION_);
|
||||||
}
|
}
|
||||||
|
@ -12,7 +12,12 @@ using namespace py::literals;
|
|||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
|
|
||||||
void init_stream(py::module_& m) {
|
void init_stream(py::module_& m) {
|
||||||
py::class_<Stream>(m, "Stream")
|
py::class_<Stream>(
|
||||||
|
m,
|
||||||
|
"Stream",
|
||||||
|
R"pbdoc(
|
||||||
|
A stream for running operations on a given device.
|
||||||
|
)pbdoc")
|
||||||
.def(py::init<int, Device>(), "index"_a, "device"_a)
|
.def(py::init<int, Device>(), "index"_a, "device"_a)
|
||||||
.def_readonly("device", &Stream::device)
|
.def_readonly("device", &Stream::device)
|
||||||
.def(
|
.def(
|
||||||
@ -28,7 +33,27 @@ void init_stream(py::module_& m) {
|
|||||||
|
|
||||||
py::implicitly_convertible<Device::DeviceType, Device>();
|
py::implicitly_convertible<Device::DeviceType, Device>();
|
||||||
|
|
||||||
m.def("default_stream", &default_stream, "device"_a);
|
m.def(
|
||||||
m.def("set_default_stream", &set_default_stream, "stream"_a);
|
"default_stream",
|
||||||
m.def("new_stream", &new_stream, "device"_a);
|
&default_stream,
|
||||||
|
"device"_a,
|
||||||
|
R"pbdoc(Get the device's default stream.)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"set_default_stream",
|
||||||
|
&set_default_stream,
|
||||||
|
"stream"_a,
|
||||||
|
R"pbdoc(
|
||||||
|
Set the default stream.
|
||||||
|
|
||||||
|
This will make the given stream the default for the
|
||||||
|
streams device. It will not change the default device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream (stream): Stream to make the default.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"new_stream",
|
||||||
|
&new_stream,
|
||||||
|
"device"_a,
|
||||||
|
R"pbdoc(Make a new stream on the given device.)pbdoc");
|
||||||
}
|
}
|
||||||
|
81
python/src/utils.cpp
Normal file
81
python/src/utils.cpp
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
using namespace py::literals;
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
// Slightly different from the original, with python context on init we are not
|
||||||
|
// in the context yet. Only create the inner context on enter then delete on
|
||||||
|
// exit.
|
||||||
|
class PyStreamContext {
|
||||||
|
public:
|
||||||
|
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
|
||||||
|
if (std::holds_alternative<std::monostate>(s)) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[StreamContext] Invalid argument, please specify a stream or device.");
|
||||||
|
}
|
||||||
|
_s = s;
|
||||||
|
}
|
||||||
|
|
||||||
|
void enter() {
|
||||||
|
_inner = new StreamContext(_s);
|
||||||
|
}
|
||||||
|
|
||||||
|
void exit() {
|
||||||
|
if (_inner != nullptr) {
|
||||||
|
delete _inner;
|
||||||
|
_inner = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
StreamOrDevice _s;
|
||||||
|
StreamContext* _inner;
|
||||||
|
};
|
||||||
|
|
||||||
|
void init_utils(py::module_& m) {
|
||||||
|
py::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
|
||||||
|
A context manager for setting the current device and stream.
|
||||||
|
|
||||||
|
See :func:`stream` for usage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s: The stream or device to set as the default.
|
||||||
|
)pbdoc")
|
||||||
|
.def(py::init<StreamOrDevice>(), "s"_a)
|
||||||
|
.def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
|
||||||
|
.def(
|
||||||
|
"__exit__",
|
||||||
|
[](PyStreamContext& scm,
|
||||||
|
const std::optional<py::type>& exc_type,
|
||||||
|
const std::optional<py::object>& exc_value,
|
||||||
|
const std::optional<py::object>& traceback) { scm.exit(); });
|
||||||
|
m.def(
|
||||||
|
"stream",
|
||||||
|
[](StreamOrDevice s) { return PyStreamContext(s); },
|
||||||
|
"s"_a,
|
||||||
|
R"pbdoc(
|
||||||
|
Create a context manager to set the default device and stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s: The :obj:`Stream` or :obj:`Device` to set as the default.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A context manager that sets the default device and stream.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
.. code-block::python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
# Create a context manager for the default device and stream.
|
||||||
|
with mx.stream(mx.cpu):
|
||||||
|
# Operations here will use mx.cpu by default.
|
||||||
|
pass
|
||||||
|
)pbdoc");
|
||||||
|
}
|
@ -38,6 +38,17 @@ class TestDevice(mlx_tests.MLXTestCase):
|
|||||||
# Restore device
|
# Restore device
|
||||||
mx.set_default_device(device)
|
mx.set_default_device(device)
|
||||||
|
|
||||||
|
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||||
|
def test_device_context(self):
|
||||||
|
default = mx.default_device()
|
||||||
|
diff = mx.cpu if default == mx.gpu else mx.gpu
|
||||||
|
self.assertNotEqual(default, diff)
|
||||||
|
with mx.stream(diff):
|
||||||
|
a = mx.add(mx.zeros((2, 2)), mx.ones((2, 2)))
|
||||||
|
mx.eval(a)
|
||||||
|
self.assertEqual(mx.default_device(), diff)
|
||||||
|
self.assertEqual(mx.default_device(), default)
|
||||||
|
|
||||||
def test_op_on_device(self):
|
def test_op_on_device(self):
|
||||||
x = mx.array(1.0)
|
x = mx.array(1.0)
|
||||||
y = mx.array(1.0)
|
y = mx.array(1.0)
|
||||||
|
@ -19,72 +19,73 @@ class TestFFT(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
||||||
|
|
||||||
def test_fft(self):
|
def test_fft(self):
|
||||||
default = mx.default_device()
|
|
||||||
mx.set_default_device(mx.cpu)
|
|
||||||
|
|
||||||
def check_mx_np(op_mx, op_np, a_np, **kwargs):
|
def check_mx_np(op_mx, op_np, a_np, **kwargs):
|
||||||
out_np = op_np(a_np, **kwargs)
|
out_np = op_np(a_np, **kwargs)
|
||||||
a_mx = mx.array(a_np)
|
a_mx = mx.array(a_np)
|
||||||
out_mx = op_mx(a_mx, **kwargs)
|
out_mx = op_mx(a_mx, **kwargs)
|
||||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
||||||
|
|
||||||
r = np.random.rand(100).astype(np.float32)
|
with mx.stream(mx.cpu):
|
||||||
i = np.random.rand(100).astype(np.float32)
|
r = np.random.rand(100).astype(np.float32)
|
||||||
a_np = r + 1j * i
|
i = np.random.rand(100).astype(np.float32)
|
||||||
check_mx_np(mx.fft.fft, np.fft.fft, a_np)
|
a_np = r + 1j * i
|
||||||
|
check_mx_np(mx.fft.fft, np.fft.fft, a_np)
|
||||||
|
|
||||||
# Check with slicing and padding
|
# Check with slicing and padding
|
||||||
r = np.random.rand(100).astype(np.float32)
|
r = np.random.rand(100).astype(np.float32)
|
||||||
i = np.random.rand(100).astype(np.float32)
|
i = np.random.rand(100).astype(np.float32)
|
||||||
a_np = r + 1j * i
|
a_np = r + 1j * i
|
||||||
check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
|
check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
|
||||||
check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
|
check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
|
||||||
|
|
||||||
# Check different axes
|
# Check different axes
|
||||||
r = np.random.rand(100, 100).astype(np.float32)
|
r = np.random.rand(100, 100).astype(np.float32)
|
||||||
i = np.random.rand(100, 100).astype(np.float32)
|
i = np.random.rand(100, 100).astype(np.float32)
|
||||||
a_np = r + 1j * i
|
a_np = r + 1j * i
|
||||||
check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
|
check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
|
||||||
check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
|
check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
|
||||||
|
|
||||||
# Check real fft
|
# Check real fft
|
||||||
a_np = np.random.rand(100).astype(np.float32)
|
a_np = np.random.rand(100).astype(np.float32)
|
||||||
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
|
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
|
||||||
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
|
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
|
||||||
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
|
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
|
||||||
|
|
||||||
# Check real inverse
|
# Check real inverse
|
||||||
r = np.random.rand(100, 100).astype(np.float32)
|
r = np.random.rand(100, 100).astype(np.float32)
|
||||||
i = np.random.rand(100, 100).astype(np.float32)
|
i = np.random.rand(100, 100).astype(np.float32)
|
||||||
a_np = r + 1j * i
|
a_np = r + 1j * i
|
||||||
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
|
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
|
||||||
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
|
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
|
||||||
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
|
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
|
||||||
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
|
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
|
||||||
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
|
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
|
||||||
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)
|
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)
|
||||||
|
|
||||||
mx.set_default_device(default)
|
|
||||||
|
|
||||||
def test_fftn(self):
|
def test_fftn(self):
|
||||||
default = mx.default_device()
|
with mx.stream(mx.cpu):
|
||||||
mx.set_default_device(mx.cpu)
|
r = np.random.randn(8, 8, 8).astype(np.float32)
|
||||||
|
i = np.random.randn(8, 8, 8).astype(np.float32)
|
||||||
|
a = r + 1j * i
|
||||||
|
|
||||||
r = np.random.randn(8, 8, 8).astype(np.float32)
|
axes = [None, (1, 2), (2, 1), (0, 2)]
|
||||||
i = np.random.randn(8, 8, 8).astype(np.float32)
|
shapes = [None, (10, 5), (5, 10)]
|
||||||
a = r + 1j * i
|
ops = [
|
||||||
|
"fft2",
|
||||||
|
"ifft2",
|
||||||
|
"rfft2",
|
||||||
|
"irfft2",
|
||||||
|
"fftn",
|
||||||
|
"ifftn",
|
||||||
|
"rfftn",
|
||||||
|
"irfftn",
|
||||||
|
]
|
||||||
|
|
||||||
axes = [None, (1, 2), (2, 1), (0, 2)]
|
for op, ax, s in itertools.product(ops, axes, shapes):
|
||||||
shapes = [None, (10, 5), (5, 10)]
|
x = a
|
||||||
ops = ["fft2", "ifft2", "rfft2", "irfft2", "fftn", "ifftn", "rfftn", "irfftn"]
|
if op in ["rfft2", "rfftn"]:
|
||||||
|
x = r
|
||||||
for op, ax, s in itertools.product(ops, axes, shapes):
|
self.check_mx_np(op, x, axes=ax, s=s)
|
||||||
x = a
|
|
||||||
if op in ["rfft2", "rfftn"]:
|
|
||||||
x = r
|
|
||||||
self.check_mx_np(op, x, axes=ax, s=s)
|
|
||||||
|
|
||||||
mx.set_default_device(default)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user