mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Add support for ibm granite (#758)
* add support for granite 3-8B config * add gpt_bigcode * add positional embedding condition. * add support for granite 3-8B config * add gpt_bigcode * add positional embedding condition. * remove unused function * rebase fix * move position emebedding to mask creation * add to tuner and format * add support for granite 3-8B config * add gpt_bigcode * add positional embedding condition. * add support for granite 3-8B config * add gpt_bigcode * add positional embedding condition. * rebase fix * move position emebedding to mask creation * add to tuner and format * refactor mask * remove dropout layers
This commit is contained in:
@@ -4,6 +4,13 @@ from dataclasses import dataclass
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def create_additive_causal_mask(N: int, offset: int = 0):
|
||||
rinds = mx.arange(offset + N)
|
||||
linds = mx.arange(offset, offset + N) if offset else rinds
|
||||
mask = linds[:, None] < rinds[None]
|
||||
return mask * -1e9
|
||||
|
||||
|
||||
class KVCache:
|
||||
|
||||
def __init__(self, head_dim, n_kv_heads):
|
||||
|
||||
Reference in New Issue
Block a user