Removed debug message regarding orphaned weights

This commit is contained in:
Pawel Kowalski
2023-12-20 18:43:40 +01:00
parent dac547367d
commit de2c1022e3

View File

@@ -195,12 +195,6 @@ def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
dtype = np.float16 if float16 else np.float32
with safetensor_open(weight_file, framework="numpy") as f:
weights = _flatten([mapper(k, f.get_tensor(k).astype(dtype)) for k in f.keys()])
# debug
bar = tree_flatten(model)
missing_weights = [w[0] for w in weights if w[0] not in [b[0] for b in bar]]
if missing_weights:
print("warning: missing weights")
print(missing_weights)
model.update(tree_unflatten(weights))