mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
reformatted
This commit is contained in:
parent
d7cab9d5f5
commit
298178d669
@ -8,6 +8,7 @@ import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import torch
|
||||
@ -149,7 +150,7 @@ def quantize(weights, config, args):
|
||||
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||
max_file_size_bytes = max_file_size_gibibyte << 30
|
||||
shards = []
|
||||
shard : Dict[str, mx.array] = {}
|
||||
shard: Dict[str, mx.array] = {}
|
||||
shard_size = 0
|
||||
for k, v in weights.items():
|
||||
if shard_size + v.nbytes > max_file_size_bytes:
|
||||
|
@ -91,7 +91,7 @@ class FeedForward(nn.Module):
|
||||
class MOEFeedForward(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
|
||||
if args.moe is None:
|
||||
raise ValueError("args.moe must not be None for MOEFeedForward")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user