Remove "using namespace mlx::core" in python/src (#1689)

This commit is contained in:
Cheng
2024-12-12 08:45:39 +09:00
committed by GitHub
parent f3dfa36a3a
commit 0bf19037ca
22 changed files with 1423 additions and 1302 deletions

View File

@@ -14,9 +14,9 @@
#include "python/src/load.h"
#include "python/src/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
///////////////////////////////////////////////////////////////////////////////
// Helpers
@@ -86,7 +86,7 @@ class ZipFileWrapper {
// Loading
///////////////////////////////////////////////////////////////////////////////
class PyFileReader : public io::Reader {
class PyFileReader : public mx::io::Reader {
public:
PyFileReader(nb::object file)
: pyistream_(file),
@@ -168,14 +168,14 @@ class PyFileReader : public io::Reader {
};
std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, mx::array>,
std::unordered_map<std::string, std::string>>
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
return load_safetensors(nb::cast<std::string>(file), s);
return mx::load_safetensors(nb::cast<std::string>(file), s);
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
auto res = mx::load_safetensors(std::make_shared<PyFileReader>(file), s);
{
nb::gil_scoped_release gil;
for (auto& [key, arr] : std::get<0>(res)) {
@@ -189,17 +189,17 @@ mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
"[load_safetensors] Input must be a file-like object, or string");
}
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
return load_gguf(nb::cast<std::string>(file), s);
return mx::load_gguf(nb::cast<std::string>(file), s);
}
throw std::invalid_argument("[load_gguf] Input must be a string");
}
std::unordered_map<std::string, array> mlx_load_npz_helper(
std::unordered_map<std::string, mx::array> mlx_load_npz_helper(
nb::object file,
StreamOrDevice s) {
mx::StreamOrDevice s) {
bool own_file = nb::isinstance<nb::str>(file);
nb::module_ zipfile = nb::module_::import_("zipfile");
@@ -209,7 +209,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
"opened with zipfile.ZipFile");
}
// Output dictionary filename in zip -> loaded array
std::unordered_map<std::string, array> array_dict;
std::unordered_map<std::string, mx::array> array_dict;
// Create python ZipFile object
ZipFileWrapper zipfile_object(zipfile, file);
@@ -218,7 +218,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
nb::object sub_file = zipfile_object.open(st);
// Create array from python file stream
auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
auto arr = mx::load(std::make_shared<PyFileReader>(sub_file), s);
// Remove .npy from file if it is there
auto key = st;
@@ -240,12 +240,12 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
return array_dict;
}
array mlx_load_npy_helper(nb::object file, StreamOrDevice s) {
mx::array mlx_load_npy_helper(nb::object file, mx::StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string
return load(nb::cast<std::string>(file), s);
return mx::load(nb::cast<std::string>(file), s);
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto arr = load(std::make_shared<PyFileReader>(file), s);
auto arr = mx::load(std::make_shared<PyFileReader>(file), s);
{
nb::gil_scoped_release gil;
arr.eval();
@@ -260,7 +260,7 @@ LoadOutputTypes mlx_load_helper(
nb::object file,
std::optional<std::string> format,
bool return_metadata,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (!format.has_value()) {
std::string fname;
if (nb::isinstance<nb::str>(file)) {
@@ -309,7 +309,7 @@ LoadOutputTypes mlx_load_helper(
// Saving
///////////////////////////////////////////////////////////////////////////////
class PyFileWriter : public io::Writer {
class PyFileWriter : public mx::io::Writer {
public:
PyFileWriter(nb::object file)
: pyostream_(file),
@@ -382,15 +382,15 @@ class PyFileWriter : public io::Writer {
nb::object tell_func_;
};
void mlx_save_helper(nb::object file, array a) {
void mlx_save_helper(nb::object file, mx::array a) {
if (nb::isinstance<nb::str>(file)) {
save(nb::cast<std::string>(file), a);
mx::save(nb::cast<std::string>(file), a);
return;
} else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file);
{
nb::gil_scoped_release gil;
save(writer, a);
mx::save(writer, a);
}
return;
@@ -419,8 +419,9 @@ void mlx_savez_helper(
}
// Collect args and kwargs
auto arrays_dict = nb::cast<std::unordered_map<std::string, array>>(kwargs);
auto arrays_list = nb::cast<std::vector<array>>(args);
auto arrays_dict =
nb::cast<std::unordered_map<std::string, mx::array>>(kwargs);
auto arrays_list = nb::cast<std::vector<mx::array>>(args);
for (int i = 0; i < arrays_list.size(); i++) {
std::string arr_name = "arr_" + std::to_string(i);
@@ -447,7 +448,7 @@ void mlx_savez_helper(
auto writer = std::make_shared<PyFileWriter>(py_ostream);
{
nb::gil_scoped_release nogil;
save(writer, a);
mx::save(writer, a);
}
}
@@ -470,17 +471,18 @@ void mlx_save_safetensor_helper(
} else {
metadata_map = std::unordered_map<std::string, std::string>();
}
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(d);
auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(d);
if (nb::isinstance<nb::str>(file)) {
{
nb::gil_scoped_release nogil;
save_safetensors(nb::cast<std::string>(file), arrays_map, metadata_map);
mx::save_safetensors(
nb::cast<std::string>(file), arrays_map, metadata_map);
}
} else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file);
{
nb::gil_scoped_release nogil;
save_safetensors(writer, arrays_map, metadata_map);
mx::save_safetensors(writer, arrays_map, metadata_map);
}
} else {
throw std::invalid_argument(
@@ -492,19 +494,20 @@ void mlx_save_gguf_helper(
nb::object file,
nb::dict a,
std::optional<nb::dict> m) {
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(a);
auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(a);
if (nb::isinstance<nb::str>(file)) {
if (m) {
auto metadata_map =
nb::cast<std::unordered_map<std::string, GGUFMetaData>>(m.value());
nb::cast<std::unordered_map<std::string, mx::GGUFMetaData>>(
m.value());
{
nb::gil_scoped_release nogil;
save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
mx::save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
}
} else {
{
nb::gil_scoped_release nogil;
save_gguf(nb::cast<std::string>(file), arrays_map);
mx::save_gguf(nb::cast<std::string>(file), arrays_map);
}
}
} else {