diff --git a/mlx/io/load.h b/mlx/io/load.h index 8aa80bbb7..637df1b19 100644 --- a/mlx/io/load.h +++ b/mlx/io/load.h @@ -14,7 +14,7 @@ class Reader { public: virtual bool is_open() const = 0; virtual bool good() const = 0; - virtual size_t tell() const = 0; + virtual size_t tell() = 0; // tellp is non-const in iostream virtual void seek( int64_t off, std::ios_base::seekdir way = std::ios_base::beg) = 0; @@ -26,7 +26,7 @@ class Writer { public: virtual bool is_open() const = 0; virtual bool good() const = 0; - virtual size_t tell() const = 0; + virtual size_t tell() = 0; virtual void seek( int64_t off, std::ios_base::seekdir way = std::ios_base::beg) = 0; @@ -36,31 +36,31 @@ class Writer { class FileReader : public Reader { public: - explicit FileReader(const std::shared_ptr& is) - : is_(is), label_("stream") {} - explicit FileReader(const std::string& file_path) - : is_(std::make_shared(file_path, std::ios::binary)), - label_(file_path) {} + explicit FileReader(std::ifstream is) + : is_(std::move(is)), label_("stream") {} + explicit FileReader(std::string file_path) + : is_(std::ifstream(file_path, std::ios::binary)), + label_(std::move(file_path)) {} bool is_open() const override { - return is_->is_open(); + return is_.is_open(); } bool good() const override { - return is_->good(); + return is_.good(); } - size_t tell() const override { - return is_->tellg(); + size_t tell() override { + return is_.tellg(); } void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) override { - is_->seekg(off, way); + is_.seekg(off, way); } void read(char* data, size_t n) override { - is_->read(data, n); + is_.read(data, n); } std::string label() const override { @@ -68,37 +68,37 @@ class FileReader : public Reader { } private: - std::shared_ptr is_; + std::ifstream is_; std::string label_; }; class FileWriter : public Writer { public: - explicit FileWriter(const std::shared_ptr& is) - : os_(is), label_("stream") {} - explicit FileWriter(const std::string& file_path) - : os_(std::make_shared(file_path, std::ios::binary)), - label_(file_path) {} + explicit FileWriter(std::ofstream os) + : os_(std::move(os)), label_("stream") {} + explicit FileWriter(std::string file_path) + : os_(std::ofstream(file_path, std::ios::binary)), + label_(std::move(file_path)) {} bool is_open() const override { - return os_->is_open(); + return os_.is_open(); } bool good() const override { - return os_->good(); + return os_.good(); } - size_t tell() const override { - return os_->tellp(); + size_t tell() override { + return os_.tellp(); } void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) override { - os_->seekp(off, way); + os_.seekp(off, way); } void write(const char* data, size_t n) override { - os_->write(data, n); + os_.write(data, n); } std::string label() const override { @@ -106,7 +106,7 @@ class FileWriter : public Writer { } private: - std::shared_ptr os_; + std::ofstream os_; std::string label_; }; diff --git a/python/src/load.cpp b/python/src/load.cpp index db0804e61..efad3d97d 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -121,7 +121,7 @@ class PyFileReader : public io::Reader { return out; } - size_t tell() const override { + size_t tell() override { size_t out; { nb::gil_scoped_acquire gil; @@ -334,7 +334,7 @@ class PyFileWriter : public io::Writer { return out; } - size_t tell() const override { + size_t tell() override { size_t out; { nb::gil_scoped_acquire gil;