fix: fix issue #54, use CPU device to load the Torch model

This commit is contained in:
Haixing Hu 2023-12-11 10:54:55 +08:00
parent 3a3ea3cfb0
commit 5b62270556

View File

@ -46,7 +46,7 @@ if __name__ == "__main__":
parser.add_argument("output_file") parser.add_argument("output_file")
args = parser.parse_args() args = parser.parse_args()
state = torch.load(args.torch_weights) state = torch.load(args.torch_weights, map_location=torch.device('cpu'))
np.savez( np.savez(
args.output_file, args.output_file,
**{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None} **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}