Optimizations for mamba1 (#1213)

* added mx.einsum() operations: before: 41.293 tokens-per-sec, after: 57.822 tokens-per-sec

* Fused Operations in delta, B, C = ... :. Before: 57.822 tokens-per-sec, after: 83.890 tokens-per-sec

* Pre-computing A_log. After: 83.890 tokens-per-sec, before: 85.848 tokens-per-sec

* Update MambaBlock, Batched Input Processing, Improved Cache Handling, Pre-computed Constants, Cleaner State Management, Explicit Return Values:. Before: 82.442 tokens-per-sec, after: 129.130 tokens-per-sec.

* cleaning up and adding apple copyright to helium modelfile

* update Copyright to this year

* nits + even faster

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Gökdeniz Gülmez
2025-02-03 22:36:08 +01:00
committed by GitHub
parent d9924d08d1
commit 0989c073b0
3 changed files with 42 additions and 26 deletions

View File

@@ -1,3 +1,5 @@
# Copyright © 2025 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional, Tuple