mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
3 Commits
9f0ba3ddf1
...
5bcf3a6794
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5bcf3a6794 | ||
|
|
7707196297 | ||
|
|
7e3471c987 |
@@ -57,9 +57,16 @@ Shape get_shape(const gguf_tensor& tensor) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
|
std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
|
||||||
|
if (tensor == nullptr) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[extract_tensor_data] Input tensor pointer is null.");
|
||||||
|
}
|
||||||
std::optional<Dtype> equivalent_dtype = gguf_type_to_dtype(tensor->type);
|
std::optional<Dtype> equivalent_dtype = gguf_type_to_dtype(tensor->type);
|
||||||
// If there's an equivalent type, we can simply copy.
|
// If there's an equivalent type, we can simply copy.
|
||||||
if (equivalent_dtype.has_value()) {
|
if (equivalent_dtype.has_value()) {
|
||||||
|
if (tensor->weights_data == nullptr) {
|
||||||
|
throw std::runtime_error("[load_gguf] NULL tensor data pointer");
|
||||||
|
}
|
||||||
allocator::Buffer buffer = allocator::malloc(tensor->bsize);
|
allocator::Buffer buffer = allocator::malloc(tensor->bsize);
|
||||||
memcpy(
|
memcpy(
|
||||||
buffer.raw_ptr(),
|
buffer.raw_ptr(),
|
||||||
|
|||||||
@@ -265,7 +265,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
|||||||
std::vector<char> buffer(header_len + 1);
|
std::vector<char> buffer(header_len + 1);
|
||||||
in_stream->read(&buffer[0], header_len);
|
in_stream->read(&buffer[0], header_len);
|
||||||
buffer[header_len] = 0;
|
buffer[header_len] = 0;
|
||||||
std::string header(&buffer[0]);
|
std::string header(buffer.data(), header_len);
|
||||||
|
|
||||||
// Read data type from header
|
// Read data type from header
|
||||||
std::string dtype_str = header.substr(11, 3);
|
std::string dtype_str = header.substr(11, 3);
|
||||||
@@ -273,7 +273,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
|||||||
Dtype dtype = dtype_from_array_protocol(dtype_str);
|
Dtype dtype = dtype_from_array_protocol(dtype_str);
|
||||||
|
|
||||||
// Read contiguity order
|
// Read contiguity order
|
||||||
bool col_contiguous = header[34] == 'T';
|
bool col_contiguous = header.at(34) == 'T';
|
||||||
|
|
||||||
// Read array shape from header
|
// Read array shape from header
|
||||||
Shape shape;
|
Shape shape;
|
||||||
|
|||||||
Reference in New Issue
Block a user