Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 12:49:50 +08:00
format with later version of black

This commit bumps the pinned black pre-commit hook from rev 22.10.0 to 23.12.1 and re-runs it across the repo. The resulting hunks are mechanical: the newer black deletes blank lines at the very start of a block body, and it now honors the "magic" trailing comma in the tokenizer calls, exploding them to one argument per line.
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 22.10.0
+    rev: 23.12.1
     hooks:
       - id: black
   - repo: https://github.com/pycqa/isort
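The rev bump above is the only hand-authored change; every hunk that follows is formatter output. The commit does not record the exact commands, but the usual workflow is `pre-commit autoupdate` (rewrites the rev: pin to the latest tag) followed by `pre-commit run black --all-files` (regenerates the files). The same pass can be reproduced from Python via black's format_str API, as in this sketch (assumes black==23.12.1 is installed; "example.py" is a placeholder, not a file from this commit):

    from pathlib import Path

    import black

    # Format one file the way the updated hook would.
    path = Path("example.py")
    formatted = black.format_str(path.read_text(), mode=black.Mode())
    path.write_text(formatted)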
@@ -32,7 +32,6 @@ def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
 
 
 def main(args):
-
     # Data loading
     x, y, adj = load_data(args)
     train_mask, val_mask, test_mask = train_val_test_mask()
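This hunk and the next two (and the UNetModel hunk at the end) all show the same newer-black rule: a blank line sitting at the very start of a block body, directly under the def, for, or if header, is deleted. A toy reproduction, assuming black>=23 is installed (the snippet being formatted is made up, not repo code):

    import black

    before = "def main(args):\n\n    x = load_data(args)\n"
    after = black.format_str(before, mode=black.Mode())
    print(after)  # the blank line under the def is gone under black >= 23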
@@ -55,7 +54,6 @@ def main(args):
 
     # Training loop
     for epoch in range(args.epochs):
-
         # Loss
         (loss, y_hat), grads = loss_and_grad_fn(
             gcn, x, adj, y, train_mask, args.weight_decay
@@ -96,7 +94,6 @@ def main(args):
 
 
 if __name__ == "__main__":
-
     parser = ArgumentParser()
     parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
    parser.add_argument("--edges_path", type=str, default="cora/cora.cites")
@@ -285,7 +285,11 @@ if __name__ == "__main__":
 
     model, tokenizer = load_model(args.model_path)
 
-    prompt = tokenizer(args.prompt, return_tensors="np", return_attention_mask=False,)[
+    prompt = tokenizer(
+        args.prompt,
+        return_tensors="np",
+        return_attention_mask=False,
+    )[
         "input_ids"
     ][0]
 
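This hunk and the two tokenizer hunks below show the second change: the call already carried black's "magic" trailing comma after return_attention_mask=False, and 23.x honors it here, exploding the call one argument per line instead of the old right-hand split that kept the arguments together. A reproduction sketch, assuming black>=23 (the expected output in the comment is copied from the + lines above):

    import black

    src = (
        'prompt = tokenizer(args.prompt, return_tensors="np",'
        ' return_attention_mask=False,)["input_ids"][0]\n'
    )
    print(black.format_str(src, mode=black.Mode()))
    # prompt = tokenizer(
    #     args.prompt,
    #     return_tensors="np",
    #     return_attention_mask=False,
    # )[
    #     "input_ids"
    # ][0]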
@@ -15,7 +15,11 @@ def generate(
     max_tokens: int,
     temp: float = 0.0,
 ):
-    prompt = tokenizer(args.prompt, return_tensors="np", return_attention_mask=False,)[
+    prompt = tokenizer(
+        args.prompt,
+        return_tensors="np",
+        return_attention_mask=False,
+    )[
         "input_ids"
     ][0]
     prompt = mx.array(prompt)
@@ -27,7 +27,11 @@ class Tokenizer:
 
     def encode(self, s: str) -> mx.array:
         return mx.array(
-            self._tokenizer(s, return_tensors="np", return_attention_mask=False,)[
+            self._tokenizer(
+                s,
+                return_tensors="np",
+                return_attention_mask=False,
+            )[
                 "input_ids"
             ].squeeze(0)
         )
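If the collapsed call is preferred, the fix is to drop the trailing comma, or to disable the behavior wholesale with black's --skip-magic-trailing-comma flag; the Python-API equivalent is a Mode field, as in this sketch (assumes black>=23):

    import black

    mode = black.Mode(magic_trailing_comma=False)
    print(black.format_str("x = f(a, b,)\n", mode=mode))  # -> x = f(a, b)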
@@ -381,7 +381,6 @@ class UNetModel(nn.Module):
         )
 
     def __call__(self, x, timestep, encoder_x, attn_mask=None, encoder_attn_mask=None):
-
         # Compute the time embeddings
         temb = self.timesteps(timestep).astype(x.dtype)
         temb = self.time_embedding(temb)