mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
109 lines
3.0 KiB
C++
109 lines
3.0 KiB
C++
// Copyright © 2023-2024 Apple Inc.
|
|
#include <iostream>
|
|
|
|
#include <nanobind/nanobind.h>
|
|
#include <nanobind/stl/optional.h>
|
|
#include <nanobind/stl/string.h>
|
|
#include <nanobind/stl/unordered_map.h>
|
|
#include <nanobind/stl/variant.h>
|
|
#include <nanobind/stl/vector.h>
|
|
#include "mlx/backend/metal/metal.h"
|
|
#include "mlx/memory.h"
|
|
|
|
namespace mx = mlx::core;
|
|
namespace nb = nanobind;
|
|
using namespace nb::literals;
|
|
|
|
// Prints a one-line deprecation notice for a Python-visible function to
// std::cerr.
//
// Args:
//   old_fn: fully qualified name of the deprecated function.
//   new_fn: fully qualified name of the function that replaces it.
//
// Returns:
//   Always true, so that the call can be used to initialize a `static bool`
//   (see the DEPRECATE macro below), which limits the notice to a single
//   print per call site.
bool DEPRECATE(const std::string& old_fn, const std::string& new_fn) {
  // std::cerr is unit-buffered, so a plain '\n' is flushed immediately, just
  // as std::endl was; the explicit flush manipulator is redundant here.
  std::cerr << old_fn << " is deprecated and will be removed in a future "
            << "version. Use " << new_fn << " instead." << '\n';
  return true;
}

// Emits the deprecation notice for (oldfn, newfn) at most once per enclosing
// function: the function-local static is initialized only on the first pass
// through that function (and such initialization is thread-safe since C++11).
// Inside its own expansion the name DEPRECATE is not re-expanded by the
// preprocessor, so it refers to the function defined above.
#define DEPRECATE(oldfn, newfn) \
  [[maybe_unused]] static bool dep = DEPRECATE(oldfn, newfn)
|
|
|
|
// Registers the `metal` Python submodule (docstring "mlx.metal") on `m`.
//
// The bindings fall into two groups:
//   * Metal-specific helpers (is_available, start_capture, stop_capture,
//     device_info), which bind the mx::metal functions directly.
//   * Deprecated memory-management wrappers (get_active_memory through
//     clear_cache). Each prints a one-time notice via the DEPRECATE macro and
//     then forwards to the top-level mx:: function named in that notice. They
//     also carry a short docstring so that help() in Python points users at
//     the replacement, not only the runtime notice.
void init_metal(nb::module_& m) {
  nb::module_ metal = m.def_submodule("metal", "mlx.metal");

  metal.def(
      "is_available",
      &mx::metal::is_available,
      R"pbdoc(
      Check if the Metal back-end is available.
      )pbdoc");

  // --- Deprecated wrappers, kept only for backwards compatibility. ---
  metal.def(
      "get_active_memory",
      []() {
        DEPRECATE("mx.metal.get_active_memory", "mx.get_active_memory");
        return mx::get_active_memory();
      },
      R"pbdoc(
      Deprecated: use ``mx.get_active_memory`` instead.
      )pbdoc");
  metal.def(
      "get_peak_memory",
      []() {
        DEPRECATE("mx.metal.get_peak_memory", "mx.get_peak_memory");
        return mx::get_peak_memory();
      },
      R"pbdoc(
      Deprecated: use ``mx.get_peak_memory`` instead.
      )pbdoc");
  metal.def(
      "reset_peak_memory",
      []() {
        DEPRECATE("mx.metal.reset_peak_memory", "mx.reset_peak_memory");
        mx::reset_peak_memory();
      },
      R"pbdoc(
      Deprecated: use ``mx.reset_peak_memory`` instead.
      )pbdoc");
  metal.def(
      "get_cache_memory",
      []() {
        DEPRECATE("mx.metal.get_cache_memory", "mx.get_cache_memory");
        return mx::get_cache_memory();
      },
      R"pbdoc(
      Deprecated: use ``mx.get_cache_memory`` instead.
      )pbdoc");
  // The "limit"_a annotations keep the Python keyword argument name `limit`.
  metal.def(
      "set_memory_limit",
      [](size_t limit) {
        DEPRECATE("mx.metal.set_memory_limit", "mx.set_memory_limit");
        return mx::set_memory_limit(limit);
      },
      "limit"_a,
      R"pbdoc(
      Deprecated: use ``mx.set_memory_limit`` instead.
      )pbdoc");
  metal.def(
      "set_cache_limit",
      [](size_t limit) {
        DEPRECATE("mx.metal.set_cache_limit", "mx.set_cache_limit");
        return mx::set_cache_limit(limit);
      },
      "limit"_a,
      R"pbdoc(
      Deprecated: use ``mx.set_cache_limit`` instead.
      )pbdoc");
  metal.def(
      "set_wired_limit",
      [](size_t limit) {
        DEPRECATE("mx.metal.set_wired_limit", "mx.set_wired_limit");
        return mx::set_wired_limit(limit);
      },
      "limit"_a,
      R"pbdoc(
      Deprecated: use ``mx.set_wired_limit`` instead.
      )pbdoc");
  metal.def(
      "clear_cache",
      []() {
        DEPRECATE("mx.metal.clear_cache", "mx.clear_cache");
        mx::clear_cache();
      },
      R"pbdoc(
      Deprecated: use ``mx.clear_cache`` instead.
      )pbdoc");

  // --- Metal-specific debugging and introspection helpers. ---
  metal.def(
      "start_capture",
      &mx::metal::start_capture,
      "path"_a,
      R"pbdoc(
      Start a Metal capture.

      Args:
        path (str): The path to save the capture which should have
          the extension ``.gputrace``.
      )pbdoc");
  metal.def(
      "stop_capture",
      &mx::metal::stop_capture,
      R"pbdoc(
      Stop a Metal capture.
      )pbdoc");
  metal.def(
      "device_info",
      &mx::metal::device_info,
      R"pbdoc(
      Get information about the GPU device and system settings.

      Currently returns:

      * ``architecture``
      * ``max_buffer_size``
      * ``max_recommended_working_set_size``
      * ``memory_size``
      * ``resource_limit``

      Returns:
          dict: A dictionary with string keys and string or integer values.
      )pbdoc");
}
|