mlx/python/mlx/nn/layers/embedding.py

# Copyright © 2023-2024 Apple Inc.

import math

import mlx.core as mx

from mlx.nn.layers.base import Module
from mlx.nn.layers.quantized import QuantizedEmbedding


class Embedding(Module):
    """Implements a simple lookup table that maps each input integer to a
    high-dimensional vector.

    Typically used to embed discrete tokens for processing by neural networks.

    Args:
        num_embeddings (int): How many possible discrete tokens can we embed.
            Usually called the vocabulary size.
        dims (int): The dimensionality of the embeddings.
    """

    def __init__(self, num_embeddings: int, dims: int):
        super().__init__()
        scale = math.sqrt(1 / dims)
        self.weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)

    def _extra_repr(self):
        return f"{self.weight.shape[0]}, {self.weight.shape[1]}"

    def __call__(self, x):
        return self.weight[x]

    def as_linear(self, x):
        """
        Call the embedding layer as a linear layer.

        Use this for example when input embedding and output projection
        weights are tied.
        """
        return x @ self.weight.T

    def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"):
        """Return a :obj:`QuantizedEmbedding` layer that approximates this
        embedding layer."""
        return QuantizedEmbedding.from_embedding(self, group_size, bits, mode)
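
A minimal usage sketch of the three entry points above: the lookup call, the
tied-weight linear projection, and conversion to a quantized layer. The
vocabulary size, embedding width, and token ids here are illustrative, not
from the source; the sketch assumes the layer is used through the public
mlx.nn namespace.

    import mlx.core as mx
    import mlx.nn as nn

    # Hypothetical sizes for illustration: 100-token vocabulary, 32-dim embeddings.
    emb = nn.Embedding(num_embeddings=100, dims=32)

    # Lookup: integer token ids index rows of the weight table.
    tokens = mx.array([1, 5, 42])
    vectors = emb(tokens)  # shape (3, 32)

    # Tied weights: project hidden states back to vocabulary logits
    # with the same table, via x @ weight.T.
    hidden = mx.random.normal(shape=(3, 32))
    logits = emb.as_linear(hidden)  # shape (3, 100)

    # Replace the layer with a quantized approximation; the defaults
    # mirror the to_quantized signature (groups of 64, 4 bits, affine mode).
    qemb = emb.to_quantized(group_size=64, bits=4, mode="affine")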