2024-03-19 11:12:25 +08:00
|
|
|
// Copyright © 2024 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include <cstring>
#include <memory>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <nanobind/nanobind.h>

#include "mlx/array.h"
#include "mlx/utils.h"
|
|
|
|
|
|
|
|
// Type-slot IDs for the buffer-protocol slots. CPython's typeslots.h only
// provides these from Python 3.9 on, so fall back to the same numeric values
// when building against older headers:
// https://github.com/python/cpython/blob/f6cdc6b4a191b75027de342aa8b5d344fb31313e/Include/typeslots.h#L2-L3
#if !defined(Py_bf_getbuffer)
#define Py_bf_getbuffer 1
#define Py_bf_releasebuffer 2
#endif
|
|
|
|
|
2024-12-12 07:45:39 +08:00
|
|
|
namespace mx = mlx::core;
|
2024-03-19 11:12:25 +08:00
|
|
|
namespace nb = nanobind;
|
|
|
|
|
2024-12-12 07:45:39 +08:00
|
|
|
// Maps the dtype of `a` to the PEP 3118 / `struct`-module format string that
// is published as Py_buffer::format when the array is exported through the
// buffer protocol.
// https://docs.python.org/3.10/library/struct.html#format-characters
//
// Marked `inline` because this definition lives in a header: without it,
// including the header from more than one translation unit produces
// duplicate definitions of this function (ODR violation / link error).
//
// Throws std::runtime_error("bad dtype: ...") for dtypes not mapped below.
inline std::string buffer_format(const mx::array& a) {
  switch (a.dtype()) {
    case mx::bool_:
      return "?";
    case mx::uint8:
      return "B";
    case mx::uint16:
      return "H";
    case mx::uint32:
      return "I";
    case mx::uint64:
      return "Q";
    case mx::int8:
      return "b";
    case mx::int16:
      return "h";
    case mx::int32:
      return "i";
    case mx::int64:
      return "q";
    case mx::float16:
      return "e";
    case mx::float32:
      return "f";
    case mx::bfloat16:
      // NOTE(review): "B" is the 1-byte unsigned-char code, while bfloat16
      // items are 2 bytes wide, so consumers that cross-check the format
      // against itemsize will reject this buffer. `struct` has no bfloat16
      // code; confirm this is the intended behavior before changing it.
      return "B";
    case mx::complex64:
      // "Z" prefix = complex of two "f" (float32) parts. Constructing a
      // std::string from this literal stops at the embedded '\0', so the
      // returned format is "Zf".
      return "Zf\0";
    default: {
      std::ostringstream os;
      os << "bad dtype: " << a.dtype();
      throw std::runtime_error(os.str());
    }
  }
}
|
|
|
|
|
|
|
|
struct buffer_info {
|
|
|
|
std::string format;
|
2024-12-10 04:59:19 +08:00
|
|
|
std::vector<Py_ssize_t> shape;
|
|
|
|
std::vector<Py_ssize_t> strides;
|
2024-03-19 11:12:25 +08:00
|
|
|
|
|
|
|
buffer_info(
|
2024-03-29 04:14:59 +08:00
|
|
|
std::string format,
|
2024-12-10 04:59:19 +08:00
|
|
|
std::vector<Py_ssize_t> shape_in,
|
|
|
|
std::vector<Py_ssize_t> strides_in)
|
2024-03-29 04:14:59 +08:00
|
|
|
: format(std::move(format)),
|
2024-03-19 11:12:25 +08:00
|
|
|
shape(std::move(shape_in)),
|
|
|
|
strides(std::move(strides_in)) {}
|
|
|
|
|
|
|
|
buffer_info(const buffer_info&) = delete;
|
|
|
|
buffer_info& operator=(const buffer_info&) = delete;
|
|
|
|
|
|
|
|
buffer_info(buffer_info&& other) noexcept {
|
|
|
|
(*this) = std::move(other);
|
|
|
|
}
|
|
|
|
|
|
|
|
buffer_info& operator=(buffer_info&& rhs) noexcept {
|
|
|
|
format = std::move(rhs.format);
|
|
|
|
shape = std::move(rhs.shape);
|
|
|
|
strides = std::move(rhs.strides);
|
|
|
|
return *this;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) {
|
|
|
|
std::memset(view, 0, sizeof(Py_buffer));
|
2024-12-12 07:45:39 +08:00
|
|
|
auto a = nb::cast<mx::array>(nb::handle(obj));
|
2024-03-19 11:12:25 +08:00
|
|
|
|
2024-04-17 21:16:02 +08:00
|
|
|
{
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::gil_scoped_release nogil;
|
|
|
|
a.eval();
|
|
|
|
}
|
|
|
|
|
2024-12-10 04:59:19 +08:00
|
|
|
std::vector<Py_ssize_t> shape(a.shape().begin(), a.shape().end());
|
|
|
|
std::vector<Py_ssize_t> strides(a.strides().begin(), a.strides().end());
|
2024-03-19 11:12:25 +08:00
|
|
|
for (auto& s : strides) {
|
|
|
|
s *= a.itemsize();
|
|
|
|
}
|
|
|
|
buffer_info* info =
|
|
|
|
new buffer_info(buffer_format(a), std::move(shape), std::move(strides));
|
|
|
|
|
|
|
|
view->obj = obj;
|
|
|
|
view->ndim = a.ndim();
|
|
|
|
view->internal = info;
|
|
|
|
view->buf = a.data<void>();
|
|
|
|
view->itemsize = a.itemsize();
|
2024-04-19 21:06:13 +08:00
|
|
|
view->len = a.nbytes();
|
2024-03-19 11:12:25 +08:00
|
|
|
view->readonly = false;
|
|
|
|
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
|
|
|
|
view->format = const_cast<char*>(info->format.c_str());
|
|
|
|
}
|
|
|
|
if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
|
|
|
|
view->strides = info->strides.data();
|
|
|
|
view->shape = info->shape.data();
|
|
|
|
}
|
|
|
|
Py_INCREF(view->obj);
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Buffer-protocol release slot (bf_releasebuffer) paired with getbuffer.
// Frees the buffer_info that getbuffer stored in view->internal, which owns
// the format/shape/strides storage the view pointed at, and clears the field
// so a repeated release cannot double-free. The reference held in view->obj
// is dropped by PyBuffer_Release itself, not here.
extern "C" inline void releasebuffer(PyObject*, Py_buffer* view) {
  delete static_cast<buffer_info*>(view->internal);
  view->internal = nullptr;
}
|