mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 18:39:45 +08:00
Fix buffer protocol buffer size designation (#1010)
This commit is contained in:
parent
090ff659dc
commit
ef5f7d1aea
@ -104,7 +104,7 @@ extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) {
|
|||||||
view->internal = info;
|
view->internal = info;
|
||||||
view->buf = a.data<void>();
|
view->buf = a.data<void>();
|
||||||
view->itemsize = a.itemsize();
|
view->itemsize = a.itemsize();
|
||||||
view->len = a.size();
|
view->len = a.nbytes();
|
||||||
view->readonly = false;
|
view->readonly = false;
|
||||||
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
|
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
|
||||||
view->format = const_cast<char*>(info->format.c_str());
|
view->format = const_cast<char*>(info->format.c_str());
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import operator
|
import operator
|
||||||
import pickle
|
import pickle
|
||||||
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
import weakref
|
import weakref
|
||||||
from copy import copy, deepcopy
|
from copy import copy, deepcopy
|
||||||
@ -1497,6 +1498,17 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
e = cm.exception
|
e = cm.exception
|
||||||
self.assertTrue("Item size 2 for PEP 3118 buffer format string" in str(e))
|
self.assertTrue("Item size 2 for PEP 3118 buffer format string" in str(e))
|
||||||
|
|
||||||
|
# Test buffer protocol with non-arrays ie bytes
|
||||||
|
a = ord("a") * 257 + mx.arange(10).astype(mx.int16)
|
||||||
|
ab = bytes(a)
|
||||||
|
self.assertEqual(len(ab), 20)
|
||||||
|
if sys.byteorder == "little":
|
||||||
|
self.assertEqual(b"aaaaaaaaaa", ab[1::2])
|
||||||
|
self.assertEqual(b"abcdefghij", ab[::2])
|
||||||
|
else:
|
||||||
|
self.assertEqual(b"aaaaaaaaaa", ab[::2])
|
||||||
|
self.assertEqual(b"abcdefghij", ab[1::2])
|
||||||
|
|
||||||
def test_buffer_protocol_ref_counting(self):
|
def test_buffer_protocol_ref_counting(self):
|
||||||
a = mx.arange(3)
|
a = mx.arange(3)
|
||||||
wr = weakref.ref(a)
|
wr = weakref.ref(a)
|
||||||
|
Loading…
Reference in New Issue
Block a user