mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
Make sure gguf_ctx is closed when error happens (#1699)
This commit is contained in:
parent
dfccd17ab9
commit
4768c61b57
@ -241,13 +241,13 @@ GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
|
||||
throw std::invalid_argument("[load_gguf] Failed to open " + file);
|
||||
}
|
||||
|
||||
gguf_ctx* ctx = gguf_open(file.data());
|
||||
std::unique_ptr<gguf_ctx, decltype(&gguf_close)> ctx(
|
||||
gguf_open(file.data()), gguf_close);
|
||||
if (!ctx) {
|
||||
throw std::runtime_error("[load_gguf] gguf_init failed");
|
||||
}
|
||||
auto metadata = load_metadata(ctx);
|
||||
auto arrays = load_arrays(ctx);
|
||||
gguf_close(ctx);
|
||||
auto metadata = load_metadata(ctx.get());
|
||||
auto arrays = load_arrays(ctx.get());
|
||||
return {arrays, metadata};
|
||||
}
|
||||
|
||||
@ -293,7 +293,8 @@ void save_gguf(
|
||||
file += ".gguf";
|
||||
}
|
||||
|
||||
gguf_ctx* ctx = gguf_create(file.c_str(), GGUF_OVERWRITE);
|
||||
std::unique_ptr<gguf_ctx, decltype(&gguf_close)> ctx(
|
||||
gguf_create(file.c_str(), GGUF_OVERWRITE), gguf_close);
|
||||
if (!ctx) {
|
||||
throw std::runtime_error("[save_gguf] gguf_create failed");
|
||||
}
|
||||
@ -312,7 +313,7 @@ void save_gguf(
|
||||
std::vector<char> val_vec(size);
|
||||
string_to_gguf(val_vec.data(), str);
|
||||
gguf_append_kv(
|
||||
ctx,
|
||||
ctx.get(),
|
||||
key.c_str(),
|
||||
key.length(),
|
||||
GGUF_VALUE_TYPE_STRING,
|
||||
@ -335,7 +336,7 @@ void save_gguf(
|
||||
str_ptr += str.length() + sizeof(gguf_string);
|
||||
}
|
||||
gguf_append_kv(
|
||||
ctx,
|
||||
ctx.get(),
|
||||
key.c_str(),
|
||||
key.length(),
|
||||
GGUF_VALUE_TYPE_ARRAY,
|
||||
@ -361,34 +362,34 @@ void save_gguf(
|
||||
}
|
||||
switch (v.dtype()) {
|
||||
case float32:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_FLOAT32);
|
||||
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_FLOAT32);
|
||||
break;
|
||||
case int64:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT64);
|
||||
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT64);
|
||||
break;
|
||||
case int32:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT32);
|
||||
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT32);
|
||||
break;
|
||||
case int16:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT16);
|
||||
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT16);
|
||||
break;
|
||||
case int8:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT8);
|
||||
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT8);
|
||||
break;
|
||||
case uint64:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT64);
|
||||
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT64);
|
||||
break;
|
||||
case uint32:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT32);
|
||||
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT32);
|
||||
break;
|
||||
case uint16:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT16);
|
||||
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT16);
|
||||
break;
|
||||
case uint8:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT8);
|
||||
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT8);
|
||||
break;
|
||||
case bool_:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_BOOL);
|
||||
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_BOOL);
|
||||
break;
|
||||
default:
|
||||
std::ostringstream msg;
|
||||
@ -438,7 +439,7 @@ void save_gguf(
|
||||
dim[i] = arr.shape()[num_dim - 1 - i];
|
||||
}
|
||||
if (!gguf_append_tensor_info(
|
||||
ctx,
|
||||
ctx.get(),
|
||||
tensorname,
|
||||
namelen,
|
||||
num_dim,
|
||||
@ -452,11 +453,11 @@ void save_gguf(
|
||||
|
||||
// Then, append the tensor weights
|
||||
for (const auto& [key, arr] : array_map) {
|
||||
if (!gguf_append_tensor_data(ctx, (void*)arr.data<void>(), arr.nbytes())) {
|
||||
if (!gguf_append_tensor_data(
|
||||
ctx.get(), (void*)arr.data<void>(), arr.nbytes())) {
|
||||
throw std::runtime_error("[save_gguf] gguf_append_tensor_data failed");
|
||||
}
|
||||
}
|
||||
gguf_close(ctx);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Loading…
Reference in New Issue
Block a user