mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 09:51:19 +08:00
T5: Change default dtype to bfloat16 (#147)
* T5: Change default to bfloat16
* Add myself to contributors
* t5: Change convert.py default to float32
This commit is contained in:
parent
62b455f801
commit
ebbb7083cc
@ -6,3 +6,5 @@ with a short description of your contribution(s) below. For example:
|
|||||||
- Jane Smith: Added the `foo` example.
|
- Jane Smith: Added the `foo` example.
|
||||||
|
|
||||||
MLX Examples was developed with contributions from the following individuals:
|
MLX Examples was developed with contributions from the following individuals:
|
||||||
|
|
||||||
|
- Juarez Bochi: Added support for T5 models.
|
||||||
|
@ -44,13 +44,15 @@ def replace_key(key: str) -> str:
|
|||||||
return key
|
return key
|
||||||
|
|
||||||
|
|
||||||
def convert(model_name):
|
def convert(model_name, dtype):
|
||||||
|
dtype = getattr(np, dtype)
|
||||||
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
||||||
weights = {
|
weights = {
|
||||||
replace_key(k): v.numpy().astype(np.float16)
|
replace_key(k): v.numpy().astype(dtype)
|
||||||
for k, v in model.state_dict().items()
|
for k, v in model.state_dict().items()
|
||||||
}
|
}
|
||||||
file_name = model_name.replace("/", "-")
|
file_name = model_name.replace("/", "-")
|
||||||
|
print(f"Saving weights to {file_name}.npz")
|
||||||
np.savez(f"{file_name}.npz", **weights)
|
np.savez(f"{file_name}.npz", **weights)
|
||||||
|
|
||||||
|
|
||||||
@ -64,5 +66,12 @@ if __name__ == "__main__":
|
|||||||
help="Name of the T5 model.",
|
help="Name of the T5 model.",
|
||||||
default="t5-small",
|
default="t5-small",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
help="The model data type.",
|
||||||
|
type=str,
|
||||||
|
choices=["float16", "float32"],
|
||||||
|
default="float32",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert(args.model)
|
convert(args.model, args.dtype)
|
||||||
|
4
t5/t5.py
4
t5/t5.py
@ -337,7 +337,7 @@ class Tokenizer:
|
|||||||
self._tokenizer = T5Tokenizer.from_pretrained(
|
self._tokenizer = T5Tokenizer.from_pretrained(
|
||||||
args.model,
|
args.model,
|
||||||
legacy=False,
|
legacy=False,
|
||||||
model_max_length=config.n_positions,
|
model_max_length=getattr(config, 'n_positions', 512)
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -430,7 +430,7 @@ if __name__ == "__main__":
|
|||||||
help="The model data type.",
|
help="The model data type.",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["float16", "bfloat16", "float32"],
|
choices=["float16", "bfloat16", "float32"],
|
||||||
default="float32",
|
default="bfloat16",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||||
|
Loading…
Reference in New Issue
Block a user