Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
Change tuple type definitions to use Tuple (#308)
This commit is contained in:
parent: c1342b8e89
commit: 1fa40067fe
@ -1,7 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -120,7 +120,7 @@ class Bert(nn.Module):
|
|||||||
input_ids: mx.array,
|
input_ids: mx.array,
|
||||||
token_type_ids: mx.array,
|
token_type_ids: mx.array,
|
||||||
attention_mask: mx.array = None,
|
attention_mask: mx.array = None,
|
||||||
) -> tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
x = self.embeddings(input_ids, token_type_ids)
|
x = self.embeddings(input_ids, token_type_ids)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
@ -132,7 +132,7 @@ class Bert(nn.Module):
|
|||||||
return y, mx.tanh(self.pooler(y[:, 0]))
|
return y, mx.tanh(self.pooler(y[:, 0]))
|
||||||
|
|
||||||
|
|
||||||
def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
|
def load_model(bert_model: str, weights_path: str) -> Tuple[Bert, BertTokenizer]:
|
||||||
if not Path(weights_path).exists():
|
if not Path(weights_path).exists():
|
||||||
raise ValueError(f"No model weights found in {weights_path}")
|
raise ValueError(f"No model weights found in {weights_path}")
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -128,7 +129,7 @@ class Model(nn.Module):
|
|||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: mx.array = None,
|
mask: mx.array = None,
|
||||||
cache: mx.array = None,
|
cache: mx.array = None,
|
||||||
) -> tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
mask = None
|
mask = None
|
||||||
if x.shape[1] > 1:
|
if x.shape[1] > 1:
|
||||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||||
|
@ -8,7 +8,7 @@ with open(Path(__file__).parent / "mlx_lm/requirements.txt") as fid:
|
|||||||
requirements = [str(r) for r in pkg_resources.parse_requirements(fid)]
|
requirements = [str(r) for r in pkg_resources.parse_requirements(fid)]
|
||||||
setup(
|
setup(
|
||||||
name="mlx-lm",
|
name="mlx-lm",
|
||||||
version="0.0.1",
|
version="0.0.2",
|
||||||
description="LLMs on Apple silicon with MLX and the Hugging Face Hub",
|
description="LLMs on Apple silicon with MLX and the Hugging Face Hub",
|
||||||
long_description=open("README.md", encoding="utf-8").read(),
|
long_description=open("README.md", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
Loading…
Reference in New Issue
Block a user