Updating README for current example, making python>=3.8 compatibile, and fixing code type

This commit is contained in:
Joe Barrow
2023-12-09 12:01:58 -05:00
parent 20d920a7eb
commit d873e10dfe
2 changed files with 17 additions and 15 deletions

View File

@@ -85,7 +85,7 @@ class MultiHeadAttention(nn.Module):
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys
if mask is not None:
mask = self.converrt_mask_to_additive_causal_mask(mask)
mask = self.convert_mask_to_additive_causal_mask(mask)
mask = mx.expand_dims(mask, (1, 2))
mask = mx.broadcast_to(mask, scores.shape)
scores = scores + mask.astype(scores.dtype)
@@ -94,7 +94,7 @@ class MultiHeadAttention(nn.Module):
return self.out_proj(values_hat)
def converrt_mask_to_additive_causal_mask(
def convert_mask_to_additive_causal_mask(
self, mask: mx.array, dtype: mx.Dtype = mx.float32
) -> mx.array:
mask = mask == 0
@@ -186,7 +186,7 @@ class Bert(nn.Module):
self,
input_ids: mx.array,
token_type_ids: mx.array,
attention_mask: mx.array | None = None,
attention_mask: Optional[mx.array] = None,
) -> tuple[mx.array, mx.array]:
x = self.embeddings(input_ids, token_type_ids)
y = self.encoder(x, attention_mask)