mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
angelos's commit files
This commit is contained in:
1071
python/src/array.cpp
Normal file
1071
python/src/array.cpp
Normal file
File diff suppressed because it is too large
Load Diff
42
python/src/device.cpp
Normal file
42
python/src/device.cpp
Normal file
@@ -0,0 +1,42 @@
|
||||
#include <sstream>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_device(py::module_& m) {
|
||||
py::enum_<Device::DeviceType>(m, "DeviceType")
|
||||
.value("cpu", Device::DeviceType::cpu)
|
||||
.value("gpu", Device::DeviceType::gpu)
|
||||
.export_values()
|
||||
.def(
|
||||
"__eq__",
|
||||
[](const Device::DeviceType& d1, const Device& d2) {
|
||||
return d1 == d2;
|
||||
},
|
||||
py::prepend());
|
||||
|
||||
py::class_<Device>(m, "Device")
|
||||
.def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
|
||||
.def_readonly("type", &Device::type)
|
||||
.def(
|
||||
"__repr__",
|
||||
[](const Device& d) {
|
||||
std::ostringstream os;
|
||||
os << d;
|
||||
return os.str();
|
||||
})
|
||||
.def("__eq__", [](const Device& d1, const Device& d2) {
|
||||
return d1 == d2;
|
||||
});
|
||||
|
||||
py::implicitly_convertible<Device::DeviceType, Device>();
|
||||
|
||||
m.def("default_device", &default_device);
|
||||
m.def("set_default_device", &set_default_device, "device"_a);
|
||||
}
|
12
python/src/metal.cpp
Normal file
12
python/src/metal.cpp
Normal file
@@ -0,0 +1,12 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_metal(py::module_& m) {
|
||||
py::module_ metal = m.def_submodule("metal", "mlx.metal");
|
||||
metal.def("is_available", &metal::is_available);
|
||||
}
|
32
python/src/stream.cpp
Normal file
32
python/src/stream.cpp
Normal file
@@ -0,0 +1,32 @@
|
||||
#include <sstream>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "mlx/stream.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_stream(py::module_& m) {
|
||||
py::class_<Stream>(m, "Stream")
|
||||
.def(py::init<int, Device>(), "index"_a, "device"_a)
|
||||
.def_readonly("device", &Stream::device)
|
||||
.def(
|
||||
"__repr__",
|
||||
[](const Stream& s) {
|
||||
std::ostringstream os;
|
||||
os << s;
|
||||
return os.str();
|
||||
})
|
||||
.def("__eq__", [](const Stream& s1, const Stream& s2) {
|
||||
return s1 == s2;
|
||||
});
|
||||
|
||||
py::implicitly_convertible<Device::DeviceType, Device>();
|
||||
|
||||
m.def("default_stream", &default_stream, "device"_a);
|
||||
m.def("set_default_stream", &set_default_stream, "stream"_a);
|
||||
m.def("new_stream", &new_stream, "device"_a);
|
||||
}
|
Reference in New Issue
Block a user