diff --git a/docs/src/python/metal.rst b/docs/src/python/metal.rst index 589ec0a82..d333e09ca 100644 --- a/docs/src/python/metal.rst +++ b/docs/src/python/metal.rst @@ -7,6 +7,7 @@ Metal :toctree: _autosummary is_available + device_info get_active_memory get_peak_memory get_cache_memory diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 1c1934b83..059e12b01 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -556,4 +556,15 @@ void new_stream(Stream stream) { } } +std::unordered_map> +device_info() { + auto raw_device = device(default_device()).mtl_device(); + auto arch = std::string(raw_device->architecture()->name()->utf8String()); + return { + {"architecture", arch}, + {"max_buffer_length", raw_device->maxBufferLength()}, + {"max_recommended_working_set_size", + raw_device->recommendedMaxWorkingSetSize()}}; +} + } // namespace mlx::core::metal diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index 63e4bff5e..c4d51dd17 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -2,6 +2,8 @@ #pragma once +#include + #include "mlx/array.h" namespace mlx::core::metal { @@ -61,4 +63,8 @@ void clear_cache(); void start_capture(std::string path = ""); void stop_capture(); +/** Get information about the GPU and system settings. */ +std::unordered_map> +device_info(); + } // namespace mlx::core::metal diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index d3c011397..0aeda5c9d 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -49,4 +49,10 @@ void start_capture(std::string path) {} void stop_capture() {} void clear_cache() {} +std::unordered_map> +device_info() { + throw std::runtime_error( + "[metal::device_info] Cannot get device info without metal backend"); +}; + } // namespace mlx::core::metal diff --git a/python/src/metal.cpp b/python/src/metal.cpp index 0c806cd3e..b29255e2a 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -4,6 +4,8 @@ #include #include #include +#include +#include namespace nb = nanobind; using namespace nb::literals; @@ -116,4 +118,19 @@ void init_metal(nb::module_& m) { R"pbdoc( Stop a Metal capture. )pbdoc"); + metal.def( + "device_info", + &metal::device_info, + R"pbdoc( + Get information about the GPU device and system settings. + + Currently returns: + + * ``architecture`` + * ``max_buffer_size`` + * ``max_recommended_working_set_size`` + + Returns: + dict: A dictionary with string keys and string or integer values. + )pbdoc"); }