Change tuple type definitions to use Tuple (#308)

This commit is contained in:
Angelos Katharopoulos 2024-01-12 11:15:09 -08:00 committed by GitHub
parent c1342b8e89
commit 1fa40067fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 5 deletions

View File

@ -1,7 +1,7 @@
import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@ -120,7 +120,7 @@ class Bert(nn.Module):
input_ids: mx.array,
token_type_ids: mx.array,
attention_mask: mx.array = None,
) -> tuple[mx.array, mx.array]:
) -> Tuple[mx.array, mx.array]:
x = self.embeddings(input_ids, token_type_ids)
if attention_mask is not None:
@ -132,7 +132,7 @@ class Bert(nn.Module):
return y, mx.tanh(self.pooler(y[:, 0]))
def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
def load_model(bert_model: str, weights_path: str) -> Tuple[Bert, BertTokenizer]:
if not Path(weights_path).exists():
raise ValueError(f"No model weights found in {weights_path}")

View File

@ -1,5 +1,6 @@
import math
from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
@ -128,7 +129,7 @@ class Model(nn.Module):
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> tuple[mx.array, mx.array]:
) -> Tuple[mx.array, mx.array]:
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])

View File

@ -8,7 +8,7 @@ with open(Path(__file__).parent / "mlx_lm/requirements.txt") as fid:
requirements = [str(r) for r in pkg_resources.parse_requirements(fid)]
setup(
name="mlx-lm",
version="0.0.1",
version="0.0.2",
description="LLMs on Apple silicon with MLX and the Hugging Face Hub",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",