mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
initial commit
This commit is contained in:
parent
8385f93cea
commit
87ec7b3cf9
@ -11,6 +11,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
#include "mlx/fft.h"
|
#include "mlx/fft.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/random.h"
|
#include "mlx/random.h"
|
||||||
|
#include "mlx/safetensor.h"
|
||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
#include "mlx/transforms.h"
|
#include "mlx/transforms.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
@ -1057,4 +1057,12 @@ array dequantize(
|
|||||||
int bits = 4,
|
int bits = 4,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Load array from .safetensor file format */
|
||||||
|
std::map<std::string, array> load_safetensor(
|
||||||
|
std::shared_ptr<io::Reader> in_stream,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
std::map<std::string, array> load_safetensor(
|
||||||
|
const std::string& file,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
160
mlx/safetensor.cpp
Normal file
160
mlx/safetensor.cpp
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
#include "mlx/safetensor.h"
|
||||||
|
|
||||||
|
#include <stack>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace io {
|
||||||
|
Token Tokenizer::getToken() {
|
||||||
|
if (!this->hasMoreTokens()) {
|
||||||
|
return Token{TOKEN::NULL_TYPE};
|
||||||
|
}
|
||||||
|
char nextChar = this->_data[this->_loc];
|
||||||
|
while ((nextChar == ' ' || nextChar == '\n') && this->hasMoreTokens()) {
|
||||||
|
nextChar = this->_data[++this->_loc];
|
||||||
|
}
|
||||||
|
if (!this->hasMoreTokens()) {
|
||||||
|
return Token{TOKEN::NULL_TYPE};
|
||||||
|
}
|
||||||
|
switch (nextChar) {
|
||||||
|
case '{':
|
||||||
|
this->_loc++;
|
||||||
|
return Token{TOKEN::CURLY_OPEN};
|
||||||
|
case '}':
|
||||||
|
this->_loc++;
|
||||||
|
return Token{TOKEN::CURLY_CLOSE};
|
||||||
|
case ':':
|
||||||
|
this->_loc++;
|
||||||
|
return Token{TOKEN::COLON};
|
||||||
|
case '[':
|
||||||
|
this->_loc++;
|
||||||
|
return Token{TOKEN::ARRAY_OPEN};
|
||||||
|
case ']':
|
||||||
|
this->_loc++;
|
||||||
|
return Token{TOKEN::ARRAY_CLOSE};
|
||||||
|
case ',':
|
||||||
|
this->_loc++;
|
||||||
|
return Token{TOKEN::COMMA};
|
||||||
|
case '"': {
|
||||||
|
size_t start = this->_loc;
|
||||||
|
this->_loc++;
|
||||||
|
while (_data[this->_loc] != '"' && this->hasMoreTokens()) {
|
||||||
|
this->_loc++;
|
||||||
|
}
|
||||||
|
if (!this->hasMoreTokens()) {
|
||||||
|
throw new std::runtime_error("no more chars to parse");
|
||||||
|
}
|
||||||
|
// pass the last "
|
||||||
|
this->_loc++;
|
||||||
|
return Token{TOKEN::STRING, start, this->_loc};
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
size_t start = this->_loc;
|
||||||
|
while ((nextChar != ',' && nextChar != '}' && nextChar != ']' &&
|
||||||
|
nextChar != ' ' && nextChar != '\n') &&
|
||||||
|
this->hasMoreTokens()) {
|
||||||
|
nextChar = this->_data[++this->_loc];
|
||||||
|
}
|
||||||
|
if (!this->hasMoreTokens()) {
|
||||||
|
throw new std::runtime_error("no more chars to parse");
|
||||||
|
}
|
||||||
|
return Token{TOKEN::NUMBER, start, this->_loc};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// JSONNode parseJson(char* data, size_t len) {
|
||||||
|
// auto tokenizer = Tokenizer(data, len);
|
||||||
|
// std::stack<JSONNode*> ctx;
|
||||||
|
// auto token = tokenizer.getToken();
|
||||||
|
// auto parent = new JSONNode();
|
||||||
|
|
||||||
|
// switch (token.type) {
|
||||||
|
// case TOKEN::CURLY_OPEN:
|
||||||
|
// parent->setObject(new JSONObject());
|
||||||
|
// break;
|
||||||
|
// case TOKEN::ARRAY_OPEN:
|
||||||
|
// parent->setList(new JSONList());
|
||||||
|
// break;
|
||||||
|
// default:
|
||||||
|
// throw new std::runtime_error("invalid json");
|
||||||
|
// }
|
||||||
|
// ctx.push(parent);
|
||||||
|
|
||||||
|
// while (tokenizer.hasMoreTokens()) {
|
||||||
|
// auto token = tokenizer.getToken();
|
||||||
|
// switch (token.type) {
|
||||||
|
// case TOKEN::CURLY_OPEN:
|
||||||
|
// ctx.push(new JSONNode(JSONNode::Type::OBJECT));
|
||||||
|
// break;
|
||||||
|
// case TOKEN::CURLY_CLOSE:
|
||||||
|
// if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
||||||
|
// auto obj = ctx.top();
|
||||||
|
// ctx.pop();
|
||||||
|
// if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||||
|
// auto list = ctx.top()->getList();
|
||||||
|
// list->push_back(obj);
|
||||||
|
// } else if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
||||||
|
// //
|
||||||
|
// auto key = ctx.top();
|
||||||
|
// ctx.pop();
|
||||||
|
// if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
||||||
|
// ctx.top()->getObject()->insert({key->getString(), obj});
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// } else {
|
||||||
|
// throw new std::runtime_error("invalid json");
|
||||||
|
// }
|
||||||
|
// break;
|
||||||
|
// case TOKEN::COLON:
|
||||||
|
// break;
|
||||||
|
// case TOKEN::ARRAY_OPEN:
|
||||||
|
// break;
|
||||||
|
// case TOKEN::ARRAY_CLOSE:
|
||||||
|
// break;
|
||||||
|
// case TOKEN::COMMA:
|
||||||
|
// break;
|
||||||
|
// case TOKEN::NULL_TYPE:
|
||||||
|
// break;
|
||||||
|
// case TOKEN::STRING:
|
||||||
|
// break;
|
||||||
|
// case TOKEN::NUMBER:
|
||||||
|
// break;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
} // namespace io
|
||||||
|
|
||||||
|
/** Load array from reader in safetensor format */
|
||||||
|
std::map<std::string, array> load_safetensor(
|
||||||
|
std::shared_ptr<io::Reader> in_stream,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
////////////////////////////////////////////////////////
|
||||||
|
// Open and check file
|
||||||
|
if (!in_stream->good() || !in_stream->is_open()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[load_safetensor] Failed to open " + in_stream->label());
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t jsonHeaderLength = 0;
|
||||||
|
in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
|
||||||
|
if (jsonHeaderLength <= 0) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[load_safetensor] Invalid json header lenght " + in_stream->label());
|
||||||
|
}
|
||||||
|
// Load the json metadata
|
||||||
|
char json[jsonHeaderLength];
|
||||||
|
in_stream->read(json, jsonHeaderLength);
|
||||||
|
// Parse the json raw data
|
||||||
|
std::map<std::string, array> res;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::map<std::string, array> load_safetensor(
|
||||||
|
const std::string& file,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
return load_safetensor(std::make_shared<io::FileReader>(file), s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
79
mlx/safetensor.h
Normal file
79
mlx/safetensor.h
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "mlx/load.h"
|
||||||
|
#include "mlx/ops.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace io {
|
||||||
|
class JSONNode;
|
||||||
|
using JSONObject = std::map<std::string, std::shared_ptr<JSONNode>>;
|
||||||
|
using JSONList = std::vector<std::shared_ptr<JSONNode>>;
|
||||||
|
|
||||||
|
class JSONNode {
|
||||||
|
public:
|
||||||
|
enum class Type { OBJECT, LIST, STRING, NUMBER, BOOLEAN, NULL_TYPE };
|
||||||
|
|
||||||
|
JSONNode() : _type(Type::NULL_TYPE){};
|
||||||
|
JSONNode(Type type) : _type(type) {
|
||||||
|
// set the default value
|
||||||
|
if (type == Type::OBJECT) {
|
||||||
|
this->_values.object = new JSONObject();
|
||||||
|
} else if (type == Type::LIST) {
|
||||||
|
this->_values.list = new JSONList();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
inline bool is_type(Type t) {
|
||||||
|
return this->_type == t;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
union Values {
|
||||||
|
JSONObject* object;
|
||||||
|
JSONList* list;
|
||||||
|
std::string* s;
|
||||||
|
float fValue;
|
||||||
|
} _values;
|
||||||
|
Type _type;
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class TOKEN {
|
||||||
|
CURLY_OPEN,
|
||||||
|
CURLY_CLOSE,
|
||||||
|
COLON,
|
||||||
|
STRING,
|
||||||
|
NUMBER,
|
||||||
|
ARRAY_OPEN,
|
||||||
|
ARRAY_CLOSE,
|
||||||
|
COMMA,
|
||||||
|
NULL_TYPE,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Token {
|
||||||
|
TOKEN type;
|
||||||
|
size_t start;
|
||||||
|
size_t end;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Tokenizer {
|
||||||
|
public:
|
||||||
|
Tokenizer(const char* data, size_t len) : _data(data), _loc(0), _len(len){};
|
||||||
|
Token getToken();
|
||||||
|
inline bool hasMoreTokens() {
|
||||||
|
return this->_loc < this->_len;
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
const char* _data;
|
||||||
|
size_t _len;
|
||||||
|
size_t _loc;
|
||||||
|
};
|
||||||
|
} // namespace io
|
||||||
|
} // namespace mlx::core
|
@ -14,6 +14,25 @@ std::string get_temp_file(const std::string& name) {
|
|||||||
return std::filesystem::temp_directory_path().append(name);
|
return std::filesystem::temp_directory_path().append(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test tokenizer") {
|
||||||
|
auto raw = std::string(" { \"testing\": [1 , \"test\"]} ");
|
||||||
|
auto tokenizer = io::Tokenizer(raw.c_str(), raw.size());
|
||||||
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_OPEN);
|
||||||
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING);
|
||||||
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::COLON);
|
||||||
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::ARRAY_OPEN);
|
||||||
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NUMBER);
|
||||||
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::COMMA);
|
||||||
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING);
|
||||||
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::ARRAY_CLOSE);
|
||||||
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_CLOSE);
|
||||||
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TEST_CASE("test load_safetensor") {
|
||||||
|
// auto array = load_safetensor("../../temp.safe");
|
||||||
|
// }
|
||||||
|
|
||||||
TEST_CASE("test single array serialization") {
|
TEST_CASE("test single array serialization") {
|
||||||
// Basic test
|
// Basic test
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user