Make array conform to the Python Buffer Protocol (#323)
commit 1331fa19f6
parent dfdb284e16
@@ -62,6 +62,7 @@ jobs:
       pip install --upgrade pybind11[global]
       pip install numpy
       pip install torch
+      pip install tensorflow
       pip install unittest-xml-reporting
   - run:
       name: Build python package
@@ -61,7 +61,7 @@ variety of examples, including:
 ## Quickstart
 
 See the [quick start
-guide](https://ml-explore.github.io/mlx/build/html/quick_start.html)
+guide](https://ml-explore.github.io/mlx/build/html/usage/quick_start.html)
 in the documentation.
 
 ## Installation
@@ -35,9 +35,10 @@ are the CPU and GPU.
    :caption: Usage
    :maxdepth: 1
 
-   quick_start
-   unified_memory
-   using_streams
+   usage/quick_start
+   usage/unified_memory
+   usage/using_streams
+   usage/numpy
 
 .. toctree::
    :caption: Examples
docs/src/usage/numpy.rst (new file, 103 lines)
@@ -0,0 +1,103 @@

.. _numpy:

Conversion to NumPy and Other Frameworks
========================================

MLX arrays implement the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
Let's convert an array to NumPy and back.

.. code-block:: python

   import mlx.core as mx
   import numpy as np

   a = mx.arange(3)
   b = np.array(a)  # copy of a
   c = mx.array(b)  # copy of b
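Both conversions above produce copies; as a quick sketch (our illustration of the semantics stated above, not part of the file itself), mutating the NumPy copy leaves the MLX array untouched:

.. code-block:: python

   import mlx.core as mx
   import numpy as np

   a = mx.arange(3)
   b = np.array(a)     # independent copy of a
   b[0] = 42           # mutates only the copy
   print(a[0].item())  # 0: the original is unchanged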
.. note::

   Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first:
   ``np.array(a.astype(mx.float32))``.
   Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.``
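A minimal sketch of that workaround (our illustration; the round trip back to ``bfloat16`` is optional):

.. code-block:: python

   import mlx.core as mx
   import numpy as np

   a = mx.arange(3).astype(mx.bfloat16)
   b = np.array(a.astype(mx.float32))   # cast first; bfloat16 has no NumPy dtype
   c = mx.array(b).astype(mx.bfloat16)  # optionally restore bfloat16 on the MLX side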
By default, NumPy copies data to a new array. This can be prevented by creating an array view:

.. code-block:: python

   a = mx.arange(3)
   a_view = np.array(a, copy=False)
   print(a_view.flags.owndata)  # False
   a_view[0] = 1
   print(a[0].item())  # 1

A NumPy array view is a normal NumPy array, except that it does not own its memory.
This means writing to the view is reflected in the original array.

While this is quite powerful for avoiding data copies, be aware that external changes to the memory of arrays cannot be reflected in gradients.
Let's demonstrate this with an example:

.. code-block:: python

   def f(x):
       x_view = np.array(x, copy=False)
       x_view[:] *= x_view  # modify memory without telling mx
       return x.sum()

   x = mx.array([3.0])
   y, df = mx.value_and_grad(f)(x)
   print("f(x) = x² =", y.item())  # 9.0
   print("f'(x) = 2x !=", df.item())  # 1.0

The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``,
representing the gradient of the sum operation alone.
The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated.
It's important to note that a similar issue arises during array conversion and copying.
For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
even though no in-place operations on MLX memory are executed.
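To get the expected gradient, the squaring has to go through MLX operations so that it is recorded; a minimal sketch of a corrected ``f`` (our rewrite, not from the original doc):

.. code-block:: python

   def f_correct(x):
       return (x * x).sum()  # stays inside MLX, so the gradient is tracked

   x = mx.array([3.0])
   y, df = mx.value_and_grad(f_correct)(x)
   print(y.item())   # 9.0
   print(df.item())  # 6.0, i.e. 2x as expected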
PyTorch
-------

PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.

.. code-block:: python

   import mlx.core as mx
   import torch

   a = mx.arange(3)
   b = torch.tensor(memoryview(a))
   c = mx.array(b.numpy())

Conversion from PyTorch tensors back to MLX arrays must be done via intermediate NumPy arrays with ``numpy()``.
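A related pitfall worth noting (our addition, not from the doc above): ``numpy()`` refuses tensors that require grad, so detach first:

.. code-block:: python

   import mlx.core as mx
   import torch

   t = torch.ones(3, requires_grad=True)
   a = mx.array(t.detach().numpy())  # t.numpy() alone raises a RuntimeError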
JAX
---

JAX fully supports the buffer protocol.

.. code-block:: python

   import mlx.core as mx
   import jax.numpy as jnp

   a = mx.arange(3)
   b = jnp.array(a)
   c = mx.array(b)
TensorFlow
----------

TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.

.. code-block:: python

   import mlx.core as mx
   import tensorflow as tf

   a = mx.arange(3)
   b = tf.constant(memoryview(a))
   c = mx.array(b)
@@ -278,108 +278,6 @@ array array_from_list(T pl, std::optional<Dtype> dtype) {
   return stack(arrays);
 }
 
-///////////////////////////////////////////////////////////////////////////////
-// MLX -> Numpy
-///////////////////////////////////////////////////////////////////////////////
-
-size_t elem_to_loc(
-    int elem,
-    const std::vector<int>& shape,
-    const std::vector<size_t>& strides) {
-  size_t loc = 0;
-  for (int i = shape.size() - 1; i >= 0; --i) {
-    auto q_and_r = ldiv(elem, shape[i]);
-    loc += q_and_r.rem * strides[i];
-    elem = q_and_r.quot;
-  }
-  return loc;
-}
-
-struct PyArrayPayload {
-  array a;
-};
-
-template <typename T>
-py::array_t<T> mlx_array_to_np_t(const array& src) {
-  // Let py::capsule hold onto a copy of the array which holds a shared ptr to
-  // the data
-  const py::capsule freeWhenDone(new PyArrayPayload({src}), [](void* payload) {
-    delete reinterpret_cast<PyArrayPayload*>(payload);
-  });
-  // Collect strides
-  std::vector<size_t> strides{src.strides().begin(), src.strides().end()};
-  for (int i = 0; i < src.ndim(); i++) {
-    strides[i] *= src.itemsize();
-  }
-  // Pack the capsule with the array
-  py::array_t<T> out(src.shape(), strides, src.data<T>(), freeWhenDone);
-  // Mark array as read-only
-  py::detail::array_proxy(out.ptr())->flags &=
-      ~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;
-  // Return array
-  return py::array_t(src.shape(), strides, src.data<T>(), out);
-}
-
-template <typename T>
-py::array mlx_array_to_np_t(const array& src, const py::dtype& dt) {
-  // Let py::capsule hold onto a copy of the array which holds a shared ptr to
-  // the data
-  const py::capsule freeWhenDone(new PyArrayPayload({src}), [](void* payload) {
-    delete reinterpret_cast<PyArrayPayload*>(payload);
-  });
-  // Collect strides
-  std::vector<size_t> strides{src.strides().begin(), src.strides().end()};
-  for (int i = 0; i < src.ndim(); i++) {
-    strides[i] *= src.itemsize();
-  }
-  // Pack the capsule with the array
-  py::array out(dt, src.shape(), strides, src.data<T>(), freeWhenDone);
-  // Mark array as read-only
-  py::detail::array_proxy(out.ptr())->flags &=
-      ~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;
-  // Return array
-  return py::array(dt, src.shape(), strides, src.data<T>(), out);
-}
-
-py::array mlx_array_to_np(const array& src) {
-  // Eval if not already evaled
-  if (!src.is_evaled()) {
-    eval({src}, src.is_tracer());
-  }
-
-  switch (src.dtype()) {
-    case bool_:
-      return mlx_array_to_np_t<bool>(src);
-    case uint8:
-      return mlx_array_to_np_t<uint8_t>(src);
-    case uint16:
-      return mlx_array_to_np_t<uint16_t>(src);
-    case uint32:
-      return mlx_array_to_np_t<uint32_t>(src);
-    case uint64:
-      return mlx_array_to_np_t<uint64_t>(src);
-    case int8:
-      return mlx_array_to_np_t<int8_t>(src);
-    case int16:
-      return mlx_array_to_np_t<int16_t>(src);
-    case int32:
-      return mlx_array_to_np_t<int32_t>(src);
-    case int64:
-      return mlx_array_to_np_t<int64_t>(src);
-    case float16:
-      return mlx_array_to_np_t<float16_t>(src, py::dtype("float16"));
-    case float32:
-      return mlx_array_to_np_t<float>(src);
-    case bfloat16: {
-      auto a = astype(src, float32);
-      eval({a}, src.is_tracer());
-      return mlx_array_to_np_t<float>(a);
-    }
-    case complex64:
-      return mlx_array_to_np_t<complex64_t>(src, py::dtype("complex64"));
-  }
-}
-
 ///////////////////////////////////////////////////////////////////////////////
 // Numpy -> MLX
 ///////////////////////////////////////////////////////////////////////////////
@@ -479,6 +377,61 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
   }
 }
 
+///////////////////////////////////////////////////////////////////////////////
+// Python Buffer Protocol (MLX -> Numpy)
+///////////////////////////////////////////////////////////////////////////////
+
+std::optional<std::string> buffer_format(const array& a) {
+  // https://docs.python.org/3.10/library/struct.html#format-characters
+  switch (a.dtype()) {
+    case bool_:
+      return pybind11::format_descriptor<bool>::format();
+    case uint8:
+      return pybind11::format_descriptor<uint8_t>::format();
+    case uint16:
+      return pybind11::format_descriptor<uint16_t>::format();
+    case uint32:
+      return pybind11::format_descriptor<uint32_t>::format();
+    case uint64:
+      return pybind11::format_descriptor<uint64_t>::format();
+    case int8:
+      return pybind11::format_descriptor<int8_t>::format();
+    case int16:
+      return pybind11::format_descriptor<int16_t>::format();
+    case int32:
+      return pybind11::format_descriptor<int32_t>::format();
+    case int64:
+      return pybind11::format_descriptor<int64_t>::format();
+    case float16:
+      // https://github.com/pybind/pybind11/issues/4998
+      return "e";
+    case float32: {
+      return pybind11::format_descriptor<float>::format();
+    }
+    case bfloat16:
+      // not supported by the python buffer protocol or numpy;
+      // must be null according to
+      // https://docs.python.org/3.10/c-api/buffer.html#c.PyBUF_FORMAT
+      return {};
+    case complex64:
+      return pybind11::format_descriptor<std::complex<float>>::format();
+    default: {
+      std::ostringstream os;
+      os << "bad dtype: " << a.dtype();
+      throw std::runtime_error(os.str());
+    }
+  }
+}
+
+std::vector<size_t> buffer_strides(const array& a) {
+  std::vector<size_t> py_strides;
+  py_strides.reserve(a.strides().size());
+  for (const size_t stride : a.strides()) {
+    py_strides.push_back(stride * a.itemsize());
+  }
+  return py_strides;
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // Module
 ///////////////////////////////////////////////////////////////////////////////
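``buffer_strides`` converts MLX's element-counted strides into the byte-counted strides the buffer protocol expects. The same computation, sketched from Python with hypothetical values for a C-contiguous 3x4 float32 array:

    import numpy as np

    a = np.zeros((3, 4), dtype=np.float32)
    elem_strides = (4, 1)                        # strides counted in elements
    byte_strides = tuple(s * a.itemsize for s in elem_strides)
    assert byte_strides == a.strides             # (16, 4): counted in bytes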
@@ -546,7 +499,10 @@ void init_array(py::module_& m) {
   m.attr("complex64") = py::cast(complex64);
 
-  auto array_class = py::class_<array>(
-      m, "array", R"pbdoc(An N-dimensional array object.)pbdoc");
+  auto array_class = py::class_<array>(
+      m,
+      "array",
+      R"pbdoc(An N-dimensional array object.)pbdoc",
+      py::buffer_protocol());
 
   {
     py::options options;
@@ -564,6 +520,19 @@ void init_array(py::module_& m) {
   }
 
   array_class
+      .def_buffer([](array& a) {
+        // Eval if not already evaled
+        if (!a.is_evaled()) {
+          eval({a}, a.is_tracer());
+        }
+        return pybind11::buffer_info(
+            a.data<void>(),
+            a.itemsize(),
+            buffer_format(a).value_or(nullptr),
+            a.ndim(),
+            a.shape(),
+            buffer_strides(a));
+      })
       .def_property_readonly(
           "size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
       .def_property_readonly(
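With ``py::buffer_protocol()`` on the class and the ``def_buffer`` lambda above, any PEP 3118 consumer can read an mlx array directly. A small sketch of what this enables (the ``'i'`` format assumes ``mx.arange`` defaults to ``int32``):

    import mlx.core as mx

    a = mx.arange(3)
    mv = memoryview(a)  # no copy; evaluates the array if needed
    print(mv.format)    # 'i' for int32
    print(mv.shape)     # (3,)
    print(mv.strides)   # (4,): strides are reported in bytes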
@@ -620,7 +589,6 @@ void init_array(py::module_& m) {
         The value type of the list corresponding to the last dimension is either
         ``bool``, ``int`` or ``float`` depending on the ``dtype`` of the array.
       )pbdoc")
-      .def("__array__", &mlx_array_to_np)
       .def(
           "astype",
           &astype,
@@ -2,12 +2,20 @@
 
 import operator
 import unittest
+import weakref
 from itertools import permutations
 
 import mlx.core as mx
 import mlx_tests
 import numpy as np
 
+try:
+    import tensorflow as tf
+
+    has_tf = True
+except ImportError:
+    has_tf = False
+
 
 class TestVersion(mlx_tests.MLXTestCase):
     def test_version(self):
@@ -1100,7 +1108,7 @@ class TestArray(mlx_tests.MLXTestCase):
 
         # Check that we get an array that does not own the underlying data
         self.assertFalse(a_np.flags.owndata)
-        self.assertFalse(a_np.flags.writeable)
+        self.assertTrue(a_np.flags.writeable)
 
         # Check contents
         self.assertTrue(np.array_equal(np.ones((2, 2), dtype=np.float32), a_np))
@@ -1109,6 +1117,157 @@ class TestArray(mlx_tests.MLXTestCase):
 
         # Check strides
         self.assertSequenceEqual(b_np.strides, (0, 8, 4))
 
+    def test_np_array_conversion_copies_by_default(self):
+        a_mx = mx.ones((2, 2))
+        a_np = np.array(a_mx)
+        self.assertTrue(a_np.flags.owndata)
+        self.assertTrue(a_np.flags.writeable)
+
+    def test_buffer_protocol(self):
+        dtypes_list = [
+            (mx.bool_, np.bool_, None),
+            (mx.uint8, np.uint8, np.iinfo),
+            (mx.uint16, np.uint16, np.iinfo),
+            (mx.uint32, np.uint32, np.iinfo),
+            (mx.uint64, np.uint64, np.iinfo),
+            (mx.int8, np.int8, np.iinfo),
+            (mx.int16, np.int16, np.iinfo),
+            (mx.int32, np.int32, np.iinfo),
+            (mx.int64, np.int64, np.iinfo),
+            (mx.float16, np.float16, np.finfo),
+            (mx.float32, np.float32, np.finfo),
+            (mx.complex64, np.complex64, np.finfo),
+        ]
+
+        for mlx_dtype, np_dtype, info_fn in dtypes_list:
+            a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
+            if info_fn is not None:
+                info = info_fn(np_dtype)
+                a_np[0, 0] = info.min
+                a_np[0, 1] = info.max
+            a_mx = mx.array(a_np)
+            for f in [lambda x: x, lambda x: x.T]:
+                mv_mx = memoryview(f(a_mx))
+                mv_np = memoryview(f(a_np))
+                self.assertEqual(mv_mx.strides, mv_np.strides, f"{mlx_dtype}{np_dtype}")
+                self.assertEqual(mv_mx.shape, mv_np.shape, f"{mlx_dtype}{np_dtype}")
+                # The correct buffer format for an 8-byte (unsigned) 'long long' is Q/q, see
+                # https://docs.python.org/3.10/library/struct.html#format-characters
+                # NumPy returns L/l, as 'long' is equivalent to 'long long' on 64-bit
+                # machines, so q and l are equivalent; see
+                # https://github.com/pybind/pybind11/issues/1908
+                if np_dtype == np.uint64:
+                    self.assertEqual(mv_mx.format, "Q", f"{mlx_dtype}{np_dtype}")
+                elif np_dtype == np.int64:
+                    self.assertEqual(mv_mx.format, "q", f"{mlx_dtype}{np_dtype}")
+                else:
+                    self.assertEqual(
+                        mv_mx.format, mv_np.format, f"{mlx_dtype}{np_dtype}"
+                    )
+                self.assertFalse(mv_mx.readonly)
+                back_to_npy = np.array(mv_mx, copy=False)
+                self.assertEqualArray(
+                    back_to_npy,
+                    f(a_np),
+                    atol=0,
+                    rtol=0,
+                    msg=f"{mlx_dtype}{np_dtype}",
+                )
+
+        # extra test for bfloat16, which is not numpy-convertible
+        a_mx = mx.random.uniform(low=0, high=100, shape=(3, 4), dtype=mx.bfloat16)
+        mv_mx = memoryview(a_mx)
+        self.assertEqual(mv_mx.strides, (8, 2))
+        self.assertEqual(mv_mx.shape, (3, 4))
+        self.assertEqual(mv_mx.format, "")
+        with self.assertRaises(RuntimeError) as cm:
+            np.array(a_mx)
+        e = cm.exception
+        self.assertTrue("Item size 2 for PEP 3118 buffer format string" in str(e))
+
+    def test_buffer_protocol_ref_counting(self):
+        a = mx.arange(3)
+        wr = weakref.ref(a)
+        self.assertIsNotNone(wr())
+        mv = memoryview(a)
+        a = None
+        self.assertIsNotNone(wr())
+        mv = None
+        self.assertIsNone(wr())
+
+    def test_array_view_ref_counting(self):
+        a = mx.arange(3)
+        wr = weakref.ref(a)
+        self.assertIsNotNone(wr())
+        a_np = np.array(a, copy=False)
+        a = None
+        self.assertIsNotNone(wr())
+        a_np = None
+        self.assertIsNone(wr())
+
+    @unittest.skipIf(not has_tf, "requires TensorFlow")
+    def test_buffer_protocol_tf(self):
+        dtypes_list = [
+            (mx.bool_, tf.bool, np.bool_),
+            (mx.uint8, tf.uint8, np.uint8),
+            (mx.uint16, tf.uint16, np.uint16),
+            (mx.uint32, tf.uint32, np.uint32),
+            (mx.uint64, tf.uint64, np.uint64),
+            (mx.int8, tf.int8, np.int8),
+            (mx.int16, tf.int16, np.int16),
+            (mx.int32, tf.int32, np.int32),
+            (mx.int64, tf.int64, np.int64),
+            (mx.float16, tf.float16, np.float16),
+            (mx.float32, tf.float32, np.float32),
+            (mx.complex64, tf.complex64, np.complex64),
+        ]
+
+        for mlx_dtype, tf_dtype, np_dtype in dtypes_list:
+            a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
+            a_tf = tf.constant(a_np, dtype=tf_dtype)
+            a_mx = mx.array(a_tf)
+            for f in [
+                lambda x: x,
+                lambda x: tf.transpose(x) if isinstance(x, tf.Tensor) else x.T,
+            ]:
+                mv_mx = memoryview(f(a_mx))
+                mv_tf = memoryview(f(a_tf))
+                if (mv_mx.c_contiguous and mv_tf.c_contiguous) or (
+                    mv_mx.f_contiguous and mv_tf.f_contiguous
+                ):
+                    self.assertEqual(
+                        mv_mx.strides, mv_tf.strides, f"{mlx_dtype}{tf_dtype}"
+                    )
+                self.assertEqual(mv_mx.shape, mv_tf.shape, f"{mlx_dtype}{tf_dtype}")
+                self.assertFalse(mv_mx.readonly)
+                back_to_npy = np.array(mv_mx)
+                self.assertEqualArray(
+                    back_to_npy,
+                    f(a_tf),
+                    atol=0,
+                    rtol=0,
+                    msg=f"{mlx_dtype}{tf_dtype}",
+                )
 
 
 if __name__ == "__main__":
     unittest.main()
@@ -222,6 +222,8 @@ TEST_CASE("test array types") {
 
   // complex64
   {
+    CHECK_EQ(sizeof(complex64_t), sizeof(std::complex<float>));
+
     complex64_t v = {1.0f, 1.0f};
     array x(v);
     CHECK_EQ(x.dtype(), complex64);