Make sure gguf_ctx is closed when error happens (#1699)

This commit is contained in:
Cheng 2024-12-14 12:50:19 +09:00 committed by GitHub
parent dfccd17ab9
commit 4768c61b57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -241,13 +241,13 @@ GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
throw std::invalid_argument("[load_gguf] Failed to open " + file);
}
gguf_ctx* ctx = gguf_open(file.data());
std::unique_ptr<gguf_ctx, decltype(&gguf_close)> ctx(
gguf_open(file.data()), gguf_close);
if (!ctx) {
throw std::runtime_error("[load_gguf] gguf_init failed");
}
auto metadata = load_metadata(ctx);
auto arrays = load_arrays(ctx);
gguf_close(ctx);
auto metadata = load_metadata(ctx.get());
auto arrays = load_arrays(ctx.get());
return {arrays, metadata};
}
@ -293,7 +293,8 @@ void save_gguf(
file += ".gguf";
}
gguf_ctx* ctx = gguf_create(file.c_str(), GGUF_OVERWRITE);
std::unique_ptr<gguf_ctx, decltype(&gguf_close)> ctx(
gguf_create(file.c_str(), GGUF_OVERWRITE), gguf_close);
if (!ctx) {
throw std::runtime_error("[save_gguf] gguf_create failed");
}
@ -312,7 +313,7 @@ void save_gguf(
std::vector<char> val_vec(size);
string_to_gguf(val_vec.data(), str);
gguf_append_kv(
ctx,
ctx.get(),
key.c_str(),
key.length(),
GGUF_VALUE_TYPE_STRING,
@ -335,7 +336,7 @@ void save_gguf(
str_ptr += str.length() + sizeof(gguf_string);
}
gguf_append_kv(
ctx,
ctx.get(),
key.c_str(),
key.length(),
GGUF_VALUE_TYPE_ARRAY,
@ -361,34 +362,34 @@ void save_gguf(
}
switch (v.dtype()) {
case float32:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_FLOAT32);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_FLOAT32);
break;
case int64:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT64);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT64);
break;
case int32:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT32);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT32);
break;
case int16:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT16);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT16);
break;
case int8:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT8);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT8);
break;
case uint64:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT64);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT64);
break;
case uint32:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT32);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT32);
break;
case uint16:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT16);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT16);
break;
case uint8:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT8);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT8);
break;
case bool_:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_BOOL);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_BOOL);
break;
default:
std::ostringstream msg;
@ -438,7 +439,7 @@ void save_gguf(
dim[i] = arr.shape()[num_dim - 1 - i];
}
if (!gguf_append_tensor_info(
ctx,
ctx.get(),
tensorname,
namelen,
num_dim,
@ -452,11 +453,11 @@ void save_gguf(
// Then, append the tensor weights
for (const auto& [key, arr] : array_map) {
if (!gguf_append_tensor_data(ctx, (void*)arr.data<void>(), arr.nbytes())) {
if (!gguf_append_tensor_data(
ctx.get(), (void*)arr.data<void>(), arr.nbytes())) {
throw std::runtime_error("[save_gguf] gguf_append_tensor_data failed");
}
}
gguf_close(ctx);
}
} // namespace mlx::core