Change tuple type definitions to use Tuple (#308)

This commit is contained in:
Angelos Katharopoulos
2024-01-12 11:15:09 -08:00
committed by GitHub
parent c1342b8e89
commit 1fa40067fe
3 changed files with 6 additions and 5 deletions

View File

@@ -1,7 +1,7 @@
import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -120,7 +120,7 @@ class Bert(nn.Module):
input_ids: mx.array,
token_type_ids: mx.array,
attention_mask: mx.array = None,
) -> tuple[mx.array, mx.array]:
) -> Tuple[mx.array, mx.array]:
x = self.embeddings(input_ids, token_type_ids)
if attention_mask is not None:
@@ -132,7 +132,7 @@ class Bert(nn.Module):
return y, mx.tanh(self.pooler(y[:, 0]))
def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
def load_model(bert_model: str, weights_path: str) -> Tuple[Bert, BertTokenizer]:
if not Path(weights_path).exists():
raise ValueError(f"No model weights found in {weights_path}")