// Copyright © 2023 Apple Inc. #include #include #include "mlx/device.h" #include "mlx/utils.h" namespace py = pybind11; using namespace py::literals; using namespace mlx::core; void init_device(py::module_& m) { auto device_class = py::class_( m, "Device", R"pbdoc(A device to run operations on.)pbdoc"); py::enum_(m, "DeviceType") .value("cpu", Device::DeviceType::cpu) .value("gpu", Device::DeviceType::gpu) .export_values() .def( "__eq__", [](const Device::DeviceType& d1, const Device& d2) { return d1 == d2; }, py::prepend()); device_class.def(py::init(), "type"_a, "index"_a = 0) .def_readonly("type", &Device::type) .def( "__repr__", [](const Device& d) { std::ostringstream os; os << d; return os.str(); }) .def("__eq__", [](const Device& d1, const Device& d2) { return d1 == d2; }); py::implicitly_convertible(); m.def( "default_device", &default_device, R"pbdoc(Get the default device.)pbdoc"); m.def( "set_default_device", &set_default_device, "device"_a, R"pbdoc(Set the default device.)pbdoc"); }