mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 13:07:51 +08:00
Ran pre-commit
This commit is contained in:
parent
273d558117
commit
a415bac4f7
@ -88,5 +88,3 @@ class Dropout2d(Module):
|
|||||||
|
|
||||||
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
|
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
|
||||||
return (1 / self._p_1) * mask * x
|
return (1 / self._p_1) * mask * x
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,7 +25,13 @@ class RoPE(Module):
|
|||||||
each dimension in the positional encodings. Default: ``10000``
|
each dimension in the positional encodings. Default: ``10000``
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dims: int, traditional: bool = False, base: float = 10000, scale: float = 1.0):
|
def __init__(
|
||||||
|
self,
|
||||||
|
dims: int,
|
||||||
|
traditional: bool = False,
|
||||||
|
base: float = 10000,
|
||||||
|
scale: float = 1.0,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dims = dims
|
self.dims = dims
|
||||||
self.traditional = traditional
|
self.traditional = traditional
|
||||||
@ -80,7 +86,12 @@ class RoPE(Module):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_cos_sin_theta(
|
def create_cos_sin_theta(
|
||||||
N: int, D: int, offset: int = 0, base: float = 10000, scale: float = 1.0, dtype=mx.float32
|
N: int,
|
||||||
|
D: int,
|
||||||
|
offset: int = 0,
|
||||||
|
base: float = 10000,
|
||||||
|
scale: float = 1.0,
|
||||||
|
dtype=mx.float32,
|
||||||
):
|
):
|
||||||
D = D // 2
|
D = D // 2
|
||||||
positions = mx.arange(offset, N, dtype=dtype) * scale
|
positions = mx.arange(offset, N, dtype=dtype) * scale
|
||||||
|
Loading…
Reference in New Issue
Block a user