Add support for ibm granite (#758)

* add support for granite 3-8B config

* add gpt_bigcode

* add positional embedding condition.

* add support for granite 3-8B config

* add gpt_bigcode

* add positional embedding condition.

* remove unused function

* rebase fix

* move position embedding to mask creation

* add to tuner and format

* add support for granite 3-8B config

* add gpt_bigcode

* add positional embedding condition.

* add support for granite 3-8B config

* add gpt_bigcode

* add positional embedding condition.

* rebase fix

* move position embedding to mask creation

* add to tuner and format

* refactor mask

* remove dropout layers
This commit is contained in:
Prince Canuma
2024-05-22 05:16:31 +02:00
committed by GitHub
parent 9fc6efbd90
commit b044ce2acf
4 changed files with 238 additions and 20 deletions

View File

@@ -4,6 +4,13 @@ from dataclasses import dataclass
import mlx.core as mx
def create_additive_causal_mask(N: int, offset: int = 0):
    """Build an additive causal mask for ``N`` rows and ``offset + N`` columns.

    Row ``i`` stands for position ``offset + i`` and column ``j`` for
    position ``j``. An entry is ``-1e9`` where the column position is
    strictly greater than the row position, and zero everywhere else.

    Args:
        N: Number of rows (new positions) in the mask.
        offset: Number of positions that precede the first row. They only
            add columns to the mask.

    Returns:
        The boolean comparison mask multiplied by ``-1e9``, with ``N`` rows
        and ``offset + N`` columns.
    """
    column_positions = mx.arange(offset + N)
    if offset:
        row_positions = mx.arange(offset, offset + N)
    else:
        # With no offset the rows cover exactly the same positions as the
        # columns, so the same array is reused.
        row_positions = column_positions
    # True wherever a column lies strictly after the row's own position.
    is_future = row_positions[:, None] < column_positions[None]
    return is_future * -1e9
class KVCache:
def __init__(self, head_dim, n_kv_heads):