mlx/python/src/metal.cpp
2025-05-11 06:03:47 -07:00

109 lines
3.0 KiB
C++

// Copyright © 2023-2024 Apple Inc.
#include <iostream>
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include "mlx/backend/metal/metal.h"
#include "mlx/memory.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
// Write a one-line deprecation notice to stderr telling the user that
// `old_fn` is deprecated and that `new_fn` should be used instead.
//
// Parameters:
//   old_fn - name of the deprecated function as it should appear in the notice.
//   new_fn - name of the replacement function as it should appear in the notice.
//
// Returns: always `true`. The value carries no information; it lets the call
// serve as the initializer of a variable (see the DEPRECATE macro below).
bool DEPRECATE(const std::string& old_fn, const std::string& new_fn) {
  // Both names are taken by const reference so no string is copied per call.
  // std::endl terminates the line and flushes the stream.
  std::cerr << old_fn << " is deprecated and will be removed in a future "
            << "version. Use " << new_fn << " instead." << std::endl;
  return true;
}
// Statement-style wrapper around the DEPRECATE function defined above.
//
// Expands to the definition of a `static bool dep` initialized by a call to
// that function. When used inside a function or lambda body, the static
// initializer runs only the first time control passes through it, so the
// notice is emitted once per use site rather than on every call.
//
// Notes:
//  * Within the expansion the name DEPRECATE is not re-expanded as a macro,
//    so it refers to the function.
//  * The variable name `dep` is fixed, so the macro can appear at most once
//    in any given scope.
#define DEPRECATE(oldfn, newfn) static bool dep = DEPRECATE(oldfn, newfn)
// Create the `metal` submodule on `m` and register its Python-callable
// functions.
//
// The docstrings below are raw string literals, so every character between
// the delimiters (including leading whitespace on each line) is part of the
// text handed to nanobind.
void init_metal(nb::module_& m) {
  nb::module_ metal_mod = m.def_submodule("metal", "mlx.metal");

  // --- Availability --------------------------------------------------------
  metal_mod.def(
      "is_available",
      &mx::metal::is_available,
      R"pbdoc(
Check if the Metal back-end is available.
)pbdoc");

  // --- Deprecated memory helpers -------------------------------------------
  // Each wrapper runs the DEPRECATE macro and then forwards to the function
  // of the same name in the top-level `mx` namespace.
  metal_mod.def("get_active_memory", [] {
    DEPRECATE("mx.metal.get_active_memory", "mx.get_active_memory");
    return mx::get_active_memory();
  });

  metal_mod.def("get_peak_memory", [] {
    DEPRECATE("mx.metal.get_peak_memory", "mx.get_peak_memory");
    return mx::get_peak_memory();
  });

  metal_mod.def("reset_peak_memory", [] {
    DEPRECATE("mx.metal.reset_peak_memory", "mx.reset_peak_memory");
    mx::reset_peak_memory();
  });

  metal_mod.def("get_cache_memory", [] {
    DEPRECATE("mx.metal.get_cache_memory", "mx.get_cache_memory");
    return mx::get_cache_memory();
  });

  // The three limit setters take a single `limit` keyword argument and return
  // whatever the forwarded-to function returns.
  metal_mod.def(
      "set_memory_limit",
      [](size_t value) {
        DEPRECATE("mx.metal.set_memory_limit", "mx.set_memory_limit");
        return mx::set_memory_limit(value);
      },
      nb::arg("limit"));

  metal_mod.def(
      "set_cache_limit",
      [](size_t value) {
        DEPRECATE("mx.metal.set_cache_limit", "mx.set_cache_limit");
        return mx::set_cache_limit(value);
      },
      nb::arg("limit"));

  metal_mod.def(
      "set_wired_limit",
      [](size_t value) {
        DEPRECATE("mx.metal.set_wired_limit", "mx.set_wired_limit");
        return mx::set_wired_limit(value);
      },
      nb::arg("limit"));

  metal_mod.def("clear_cache", [] {
    DEPRECATE("mx.metal.clear_cache", "mx.clear_cache");
    mx::clear_cache();
  });

  // --- Capture -------------------------------------------------------------
  metal_mod.def(
      "start_capture",
      &mx::metal::start_capture,
      nb::arg("path"),
      R"pbdoc(
Start a Metal capture.
Args:
path (str): The path to save the capture which should have
the extension ``.gputrace``.
)pbdoc");

  metal_mod.def(
      "stop_capture",
      &mx::metal::stop_capture,
      R"pbdoc(
Stop a Metal capture.
)pbdoc");

  // --- Device information --------------------------------------------------
  metal_mod.def(
      "device_info",
      &mx::metal::device_info,
      R"pbdoc(
Get information about the GPU device and system settings.
Currently returns:
* ``architecture``
* ``max_buffer_size``
* ``max_recommended_working_set_size``
* ``memory_size``
* ``resource_limit``
Returns:
dict: A dictionary with string keys and string or integer values.
)pbdoc");
}