angelos's commit files

This commit is contained in:
Angelos Katharopoulos
2023-11-29 10:42:59 -08:00
parent 8ca7f9e8e9
commit d1f86272a2
56 changed files with 12350 additions and 0 deletions

1071
python/src/array.cpp Normal file

File diff suppressed because it is too large Load Diff

42
python/src/device.cpp Normal file
View File

@@ -0,0 +1,42 @@
#include <sstream>
#include <pybind11/pybind11.h>
#include "mlx/device.h"
#include "mlx/utils.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
void init_device(py::module_& m) {
py::enum_<Device::DeviceType>(m, "DeviceType")
.value("cpu", Device::DeviceType::cpu)
.value("gpu", Device::DeviceType::gpu)
.export_values()
.def(
"__eq__",
[](const Device::DeviceType& d1, const Device& d2) {
return d1 == d2;
},
py::prepend());
py::class_<Device>(m, "Device")
.def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
.def_readonly("type", &Device::type)
.def(
"__repr__",
[](const Device& d) {
std::ostringstream os;
os << d;
return os.str();
})
.def("__eq__", [](const Device& d1, const Device& d2) {
return d1 == d2;
});
py::implicitly_convertible<Device::DeviceType, Device>();
m.def("default_device", &default_device);
m.def("set_default_device", &set_default_device, "device"_a);
}

12
python/src/metal.cpp Normal file
View File

@@ -0,0 +1,12 @@
#include <pybind11/pybind11.h>
#include "mlx/backend/metal/metal.h"
namespace py = pybind11;
using namespace mlx::core;
void init_metal(py::module_& m) {
py::module_ metal = m.def_submodule("metal", "mlx.metal");
metal.def("is_available", &metal::is_available);
}

32
python/src/stream.cpp Normal file
View File

@@ -0,0 +1,32 @@
#include <sstream>
#include <pybind11/pybind11.h>
#include "mlx/stream.h"
#include "mlx/utils.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
void init_stream(py::module_& m) {
py::class_<Stream>(m, "Stream")
.def(py::init<int, Device>(), "index"_a, "device"_a)
.def_readonly("device", &Stream::device)
.def(
"__repr__",
[](const Stream& s) {
std::ostringstream os;
os << s;
return os.str();
})
.def("__eq__", [](const Stream& s1, const Stream& s2) {
return s1 == s2;
});
py::implicitly_convertible<Device::DeviceType, Device>();
m.def("default_stream", &default_stream, "device"_a);
m.def("set_default_stream", &set_default_stream, "stream"_a);
m.def("new_stream", &new_stream, "device"_a);
}