Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25 01:41:19 +08:00)
Updating README for current example, making it python >= 3.8 compatible, and fixing a code typo
This commit is contained in:
parent 20d920a7eb
commit d873e10dfe
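For context on the Python >= 3.8 compatibility change below: the commit swaps the PEP 604 union annotation `mx.array | None` for `typing.Optional[mx.array]`. A minimal sketch of why, using a hypothetical stand-in class instead of `mx.array` (illustrative only, not repository code):

```
# The "X | None" union syntax in annotations is evaluated when the function is
# defined and only exists on Python >= 3.10 (unless `from __future__ import
# annotations` defers evaluation). typing.Optional also works on Python 3.8.
from typing import Optional


class Array:  # hypothetical stand-in for mx.array, for illustration only
    pass


# Python 3.10+ only -- raises "TypeError: unsupported operand type(s) for |"
# at definition time on 3.8/3.9:
#     def forward(mask: Array | None = None): ...

# Works on Python >= 3.8:
def forward(mask: Optional[Array] = None) -> Optional[Array]:
    return mask
```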
@@ -49,12 +49,13 @@ python model.py \
 Which will show the following outputs:
 ```
 MLX BERT:
-[[[-0.17057164  0.08602728 -0.12471077 ... -0.09469379 -0.00275938
-    0.28314582]
-  [ 0.15222196 -0.48997563 -0.26665813 ... -0.19935863 -0.17162783
-   -0.51360303]
-  [ 0.9460105   0.1358298  -0.2945672  ...  0.00868467 -0.90271163
-   -0.2785422 ]]]
+[[[-0.52508914 -0.1993871  -0.28210318 ... -0.61125606  0.19114694
+    0.8227601 ]
+  [-0.8783862  -0.37107834 -0.52238125 ... -0.5067165   1.0847603
+    0.31066895]
+  [-0.70010054 -0.5424497  -0.26593682 ... -0.2688697   0.38338926
+    0.6557663 ]
+   ...
 ```
 
 They can be compared against the 🤗 implementation with:
@@ -67,10 +68,11 @@ python hf_model.py \
 Which will show:
 ```
 HF BERT:
-[[[-0.17057131  0.08602707 -0.12471108 ... -0.09469365 -0.00275959
-    0.28314728]
-  [ 0.15222463 -0.48997375 -0.26665992 ... -0.19936043 -0.17162988
-   -0.5136028 ]
-  [ 0.946011    0.13582966 -0.29456618 ...  0.00868565 -0.90271175
-   -0.27854213]]]
+[[[-0.52508944 -0.1993877  -0.28210333 ... -0.6112575   0.19114678
+    0.8227603 ]
+  [-0.878387   -0.371079   -0.522381   ... -0.50671494  1.0847601
+    0.31066933]
+  [-0.7001008  -0.5424504  -0.26593733 ... -0.26887015  0.38339025
+    0.65576553]
+   ...
 ```
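The MLX and 🤗 feature values shown above agree to within float32 rounding. A minimal sketch of how the two printed tensors could be checked numerically, assuming they have been collected into NumPy arrays (the names `mlx_features` and `hf_features` are hypothetical, not from the repository):

```
import numpy as np

def outputs_match(mlx_features: np.ndarray, hf_features: np.ndarray) -> bool:
    # Float32 kernels on different backends round differently, so compare with
    # a loose tolerance instead of exact equality.
    return np.allclose(mlx_features, hf_features, atol=1e-4)
```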
@@ -85,7 +85,7 @@ class MultiHeadAttention(nn.Module):
         scale = math.sqrt(1 / queries.shape[-1])
         scores = (queries * scale) @ keys
         if mask is not None:
-            mask = self.converrt_mask_to_additive_causal_mask(mask)
+            mask = self.convert_mask_to_additive_causal_mask(mask)
             mask = mx.expand_dims(mask, (1, 2))
             mask = mx.broadcast_to(mask, scores.shape)
             scores = scores + mask.astype(scores.dtype)
@@ -94,7 +94,7 @@ class MultiHeadAttention(nn.Module):
 
         return self.out_proj(values_hat)
 
-    def converrt_mask_to_additive_causal_mask(
+    def convert_mask_to_additive_causal_mask(
         self, mask: mx.array, dtype: mx.Dtype = mx.float32
     ) -> mx.array:
         mask = mask == 0
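The helper renamed above (`convert_mask_to_additive_causal_mask`) turns a 0/1 attention mask into an additive mask that is broadcast onto the attention scores, as the call site in the first model.py hunk shows. A rough NumPy sketch of that idea, assuming the usual -1e9 masking constant (illustrative, not the repository's implementation):

```
import numpy as np

def to_additive_mask(attention_mask: np.ndarray, dtype=np.float32) -> np.ndarray:
    # True where the token is padding (mirrors the `mask = mask == 0` line above)
    pad = attention_mask == 0
    # 0.0 for real tokens, a large negative value for padding; adding this to
    # the scores drives masked positions to ~0 after the softmax.
    return np.where(pad, -1e9, 0.0).astype(dtype)

# to_additive_mask(np.array([1, 1, 1, 0, 0])) -> [0, 0, 0, -1e9, -1e9]
```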
@@ -186,7 +186,7 @@ class Bert(nn.Module):
         self,
         input_ids: mx.array,
         token_type_ids: mx.array,
-        attention_mask: mx.array | None = None,
+        attention_mask: Optional[mx.array] = None,
     ) -> tuple[mx.array, mx.array]:
         x = self.embeddings(input_ids, token_type_ids)
         y = self.encoder(x, attention_mask)