# Copyright © 2023-2024 Apple Inc.

import math

import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.quantized import QuantizedEmbedding


class Embedding(Module):
    """Implements a simple lookup table that maps each input integer to a
    high-dimensional vector.

    Typically used to embed discrete tokens for processing by neural networks.

    Args:
        num_embeddings (int): How many possible discrete tokens can we embed.
            Usually called the vocabulary size.
        dims (int): The dimensionality of the embeddings.
    """

    def __init__(self, num_embeddings: int, dims: int):
        super().__init__()
        scale = math.sqrt(1 / dims)
        self.weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)

    def _extra_repr(self):
        return f"{self.weight.shape[0]}, {self.weight.shape[1]}"

    def __call__(self, x):
        return self.weight[x]
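
    # Usage sketch (illustrative, not part of the original file): calling the
    # layer indexes the weight table, mapping integer token ids to vectors.
    #
    #     emb = Embedding(num_embeddings=100, dims=32)
    #     tokens = mx.array([1, 2, 3])
    #     vectors = emb(tokens)  # shape (3, 32)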

    def as_linear(self, x):
        """
        Call the embedding layer as a linear layer.

        Use this, for example, when the input embedding and output projection
        weights are tied.
        """
        return x @ self.weight.T
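
    # Usage sketch (illustrative): with tied weights, the same table serves
    # both the input lookup and the output projection to vocabulary logits.
    #
    #     hidden = emb(tokens)            # (3, 32) token embeddings
    #     logits = emb.as_linear(hidden)  # (3, 100) scores over the vocabulary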

    def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"):
        """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
        return QuantizedEmbedding.from_embedding(self, group_size, bits, mode)
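

# Minimal runnable sketch (an addition for illustration, not part of the
# library file): quantize the lookup table, e.g. for memory-efficient
# inference. The quantized layer keeps the same call signature. Note that
# ``group_size`` must evenly divide ``dims``.
if __name__ == "__main__":
    emb = Embedding(num_embeddings=100, dims=64)
    qemb = emb.to_quantized(group_size=64, bits=4, mode="affine")
    tokens = mx.array([1, 2, 3])
    print(qemb(tokens).shape)  # (3, 64), computed from the quantized weights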