Change tuple type definitions to use Tuple (#308)

This commit is contained in:
Angelos Katharopoulos
2024-01-12 11:15:09 -08:00
committed by GitHub
parent c1342b8e89
commit 1fa40067fe
3 changed files with 6 additions and 5 deletions

View File

@@ -1,5 +1,6 @@
import math
from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -128,7 +129,7 @@ class Model(nn.Module):
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> tuple[mx.array, mx.array]:
) -> Tuple[mx.array, mx.array]:
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])