format with later version of black

This commit is contained in:
Awni Hannun
2024-01-03 14:59:45 -08:00
parent d097652adc
commit 99581115a0
6 changed files with 16 additions and 8 deletions

View File

@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 22.10.0
rev: 23.12.1
hooks:
- id: black
- repo: https://github.com/pycqa/isort

View File

@@ -32,7 +32,6 @@ def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
def main(args):
# Data loading
x, y, adj = load_data(args)
train_mask, val_mask, test_mask = train_val_test_mask()
@@ -55,7 +54,6 @@ def main(args):
# Training loop
for epoch in range(args.epochs):
# Loss
(loss, y_hat), grads = loss_and_grad_fn(
gcn, x, adj, y, train_mask, args.weight_decay
@@ -96,7 +94,6 @@ def main(args):
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
parser.add_argument("--edges_path", type=str, default="cora/cora.cites")

View File

@@ -285,7 +285,11 @@ if __name__ == "__main__":
model, tokenizer = load_model(args.model_path)
prompt = tokenizer(args.prompt, return_tensors="np", return_attention_mask=False,)[
prompt = tokenizer(
args.prompt,
return_tensors="np",
return_attention_mask=False,
)[
"input_ids"
][0]

View File

@@ -15,7 +15,11 @@ def generate(
max_tokens: int,
temp: float = 0.0,
):
prompt = tokenizer(args.prompt, return_tensors="np", return_attention_mask=False,)[
prompt = tokenizer(
args.prompt,
return_tensors="np",
return_attention_mask=False,
)[
"input_ids"
][0]
prompt = mx.array(prompt)

View File

@@ -27,7 +27,11 @@ class Tokenizer:
def encode(self, s: str) -> mx.array:
return mx.array(
self._tokenizer(s, return_tensors="np", return_attention_mask=False,)[
self._tokenizer(
s,
return_tensors="np",
return_attention_mask=False,
)[
"input_ids"
].squeeze(0)
)

View File

@@ -381,7 +381,6 @@ class UNetModel(nn.Module):
)
def __call__(self, x, timestep, encoder_x, attn_mask=None, encoder_attn_mask=None):
# Compute the time embeddings
temb = self.timesteps(timestep).astype(x.dtype)
temb = self.time_embedding(temb)