Ran pre-commit

This commit is contained in:
Hazem 2023-12-23 00:06:59 +02:00
parent 273d558117
commit a415bac4f7
2 changed files with 13 additions and 4 deletions

View File

@ -88,5 +88,3 @@ class Dropout2d(Module):
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
return (1 / self._p_1) * mask * x

View File

@ -25,7 +25,13 @@ class RoPE(Module):
each dimension in the positional encodings. Default: ``10000``
"""
def __init__(self, dims: int, traditional: bool = False, base: float = 10000, scale: float = 1.0):
def __init__(
self,
dims: int,
traditional: bool = False,
base: float = 10000,
scale: float = 1.0,
):
super().__init__()
self.dims = dims
self.traditional = traditional
@ -80,7 +86,12 @@ class RoPE(Module):
@staticmethod
def create_cos_sin_theta(
N: int, D: int, offset: int = 0, base: float = 10000, scale: float = 1.0, dtype=mx.float32
N: int,
D: int,
offset: int = 0,
base: float = 10000,
scale: float = 1.0,
dtype=mx.float32,
):
D = D // 2
positions = mx.arange(offset, N, dtype=dtype) * scale