Merge pull request #66 from Haixing-Hu/fix-issue-54

fix: fix issue #54, use CPU device to load the Torch model
This commit is contained in:
Awni Hannun 2023-12-10 18:57:51 -08:00 committed by GitHub
commit ecd96acfe4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -46,7 +46,7 @@ if __name__ == "__main__":
parser.add_argument("output_file") parser.add_argument("output_file")
args = parser.parse_args() args = parser.parse_args()
state = torch.load(args.torch_weights) state = torch.load(args.torch_weights, map_location=torch.device('cpu'))
np.savez( np.savez(
args.output_file, args.output_file,
**{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None} **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}