mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Updating README for current example, making the code compatible with python>=3.8, and fixing a code typo
This commit is contained in:
@@ -85,7 +85,7 @@ class MultiHeadAttention(nn.Module):
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
scores = (queries * scale) @ keys
|
||||
if mask is not None:
|
||||
mask = self.converrt_mask_to_additive_causal_mask(mask)
|
||||
mask = self.convert_mask_to_additive_causal_mask(mask)
|
||||
mask = mx.expand_dims(mask, (1, 2))
|
||||
mask = mx.broadcast_to(mask, scores.shape)
|
||||
scores = scores + mask.astype(scores.dtype)
|
||||
@@ -94,7 +94,7 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
return self.out_proj(values_hat)
|
||||
|
||||
def converrt_mask_to_additive_causal_mask(
|
||||
def convert_mask_to_additive_causal_mask(
|
||||
self, mask: mx.array, dtype: mx.Dtype = mx.float32
|
||||
) -> mx.array:
|
||||
mask = mask == 0
|
||||
@@ -186,7 +186,7 @@ class Bert(nn.Module):
|
||||
self,
|
||||
input_ids: mx.array,
|
||||
token_type_ids: mx.array,
|
||||
attention_mask: mx.array | None = None,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
x = self.embeddings(input_ids, token_type_ids)
|
||||
y = self.encoder(x, attention_mask)
|
||||
|
Reference in New Issue
Block a user