mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
Change tuple type definitions to use Tuple (#308)
This commit is contained in:
committed by
GitHub
parent
c1342b8e89
commit
1fa40067fe
@@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -120,7 +120,7 @@ class Bert(nn.Module):
|
||||
input_ids: mx.array,
|
||||
token_type_ids: mx.array,
|
||||
attention_mask: mx.array = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
x = self.embeddings(input_ids, token_type_ids)
|
||||
|
||||
if attention_mask is not None:
|
||||
@@ -132,7 +132,7 @@ class Bert(nn.Module):
|
||||
return y, mx.tanh(self.pooler(y[:, 0]))
|
||||
|
||||
|
||||
def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
|
||||
def load_model(bert_model: str, weights_path: str) -> Tuple[Bert, BertTokenizer]:
|
||||
if not Path(weights_path).exists():
|
||||
raise ValueError(f"No model weights found in {weights_path}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user