Change tuple type definitions to use Tuple (#308)

This commit is contained in:
Angelos Katharopoulos 2024-01-12 11:15:09 -08:00 committed by GitHub
parent c1342b8e89
commit 1fa40067fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 5 deletions

View File

@ -1,7 +1,7 @@
import argparse import argparse
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@ -120,7 +120,7 @@ class Bert(nn.Module):
input_ids: mx.array, input_ids: mx.array,
token_type_ids: mx.array, token_type_ids: mx.array,
attention_mask: mx.array = None, attention_mask: mx.array = None,
) -> tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
x = self.embeddings(input_ids, token_type_ids) x = self.embeddings(input_ids, token_type_ids)
if attention_mask is not None: if attention_mask is not None:
@ -132,7 +132,7 @@ class Bert(nn.Module):
return y, mx.tanh(self.pooler(y[:, 0])) return y, mx.tanh(self.pooler(y[:, 0]))
def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]: def load_model(bert_model: str, weights_path: str) -> Tuple[Bert, BertTokenizer]:
if not Path(weights_path).exists(): if not Path(weights_path).exists():
raise ValueError(f"No model weights found in {weights_path}") raise ValueError(f"No model weights found in {weights_path}")

View File

@ -1,5 +1,6 @@
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@ -128,7 +129,7 @@ class Model(nn.Module):
x: mx.array, x: mx.array,
mask: mx.array = None, mask: mx.array = None,
cache: mx.array = None, cache: mx.array = None,
) -> tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
mask = None mask = None
if x.shape[1] > 1: if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])

View File

@ -8,7 +8,7 @@ with open(Path(__file__).parent / "mlx_lm/requirements.txt") as fid:
requirements = [str(r) for r in pkg_resources.parse_requirements(fid)] requirements = [str(r) for r in pkg_resources.parse_requirements(fid)]
setup( setup(
name="mlx-lm", name="mlx-lm",
version="0.0.1", version="0.0.2",
description="LLMs on Apple silicon with MLX and the Hugging Face Hub", description="LLMs on Apple silicon with MLX and the Hugging Face Hub",
long_description=open("README.md", encoding="utf-8").read(), long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",