2024-03-19 11:12:25 +08:00
|
|
|
// Copyright © 2023-2024 Apple Inc.
|
2023-12-01 03:12:53 +08:00
|
|
|
|
2023-11-30 02:42:59 +08:00
|
|
|
#include <sstream>
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
#include <nanobind/nanobind.h>
|
|
|
|
#include <nanobind/stl/string.h>
|
2023-11-30 02:42:59 +08:00
|
|
|
|
|
|
|
#include "mlx/device.h"
|
|
|
|
#include "mlx/utils.h"
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
namespace nb = nanobind;
|
|
|
|
using namespace nb::literals;
|
2023-11-30 02:42:59 +08:00
|
|
|
using namespace mlx::core;
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
// Registers the Python-facing device API on module `m`: the `Device` class,
// the `DeviceType` enum, and the `default_device` / `set_default_device`
// module-level functions.
void init_device(nb::module_& m) {
  // The Device class object is created first; its methods are attached further
  // down, after the DeviceType enum has been registered.
  auto device_class = nb::class_<Device>(
      m, "Device", R"pbdoc(A device to run operations on.)pbdoc");

  nb::enum_<Device::DeviceType>(m, "DeviceType")
      .value("cpu", Device::DeviceType::cpu)
      .value("gpu", Device::DeviceType::gpu)
      .export_values()
      .def(
          "__eq__",
          [](const Device::DeviceType& self, const nb::object& rhs) {
            // Operands that are neither a Device nor a DeviceType compare
            // unequal; the cast is only evaluated for comparable operands.
            const bool comparable = nb::isinstance<Device>(rhs) ||
                nb::isinstance<Device::DeviceType>(rhs);
            return comparable && self == nb::cast<Device>(rhs);
          });

  device_class
      .def(nb::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
      .def_ro("type", &Device::type)
      .def(
          "__repr__",
          [](const Device& self) {
            // Reuse the C++ stream insertion operator for the textual form.
            std::ostringstream stream;
            stream << self;
            return stream.str();
          })
      .def("__eq__", [](const Device& self, const nb::object& rhs) {
        // Same rule as the DeviceType comparison above: only Device and
        // DeviceType operands are compared, everything else is unequal.
        // NOTE(review): casting a DeviceType operand to Device presumably
        // relies on the implicit conversion registered below -- confirm.
        const bool comparable = nb::isinstance<Device>(rhs) ||
            nb::isinstance<Device::DeviceType>(rhs);
        return comparable && self == nb::cast<Device>(rhs);
      });

  // Let a DeviceType be passed wherever a Device argument is expected.
  nb::implicitly_convertible<Device::DeviceType, Device>();

  m.def(
      "default_device",
      &default_device,
      R"pbdoc(Get the default device.)pbdoc");
  m.def(
      "set_default_device",
      &set_default_device,
      "device"_a,
      R"pbdoc(Set the default device.)pbdoc");
}
|