Index _ | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | Q | R | S | T | U | V | W | Z _ __init__() (mlx.core.array method) (mlx.core.Device method) (mlx.core.Dtype method) (mlx.core.Stream method) _cos_sin_theta_key (mlx.nn.RoPE attribute) _cos_sin_theta_value (mlx.nn.RoPE attribute) A abs() (in module mlx.core) (mlx.core.array method) AdaDelta (class in mlx.optimizers) Adagrad (class in mlx.optimizers) Adam (class in mlx.optimizers) Adamax (class in mlx.optimizers) AdamW (class in mlx.optimizers) add() (in module mlx.core) ALiBi (class in mlx.nn) all() (in module mlx.core) (mlx.core.array method) allclose() (in module mlx.core) any() (in module mlx.core) (mlx.core.array method) apply() (mlx.nn.Module method) apply_to_modules() (mlx.nn.Module method) arange() (in module mlx.core) arccos() (in module mlx.core) arccosh() (in module mlx.core) arcsin() (in module mlx.core) arcsinh() (in module mlx.core) arctan() (in module mlx.core) arctanh() (in module mlx.core) argmax() (in module mlx.core) (mlx.core.array method) argmin() (in module mlx.core) (mlx.core.array method) argpartition() (in module mlx.core) argsort() (in module mlx.core) array (class in mlx.core) array_equal() (in module mlx.core) astype() (mlx.core.array method) B BatchNorm (class in mlx.nn) bernoulli() (in module mlx.core.random) binary_cross_entropy (class in mlx.nn.losses) broadcast_to() (in module mlx.core) C categorical() (in module mlx.core.random) ceil() (in module mlx.core) children() (mlx.nn.Module method) clip() (in module mlx.core) concatenate() (in module mlx.core) Conv1d (class in mlx.nn) conv1d() (in module mlx.core) Conv2d (class in mlx.nn) conv2d() (in module mlx.core) convolve() (in module mlx.core) cos() (in module mlx.core) (mlx.core.array method) cosh() (in module mlx.core) cosine_similarity_loss (class in mlx.nn.losses) cross_entropy (class in mlx.nn.losses) D default_device() (in module mlx.core) default_stream() (in module mlx.core) dequantize() (in module mlx.core) Device (class in mlx.core) divide() (in module mlx.core) divmod() (in module mlx.core) Dropout (class in mlx.nn) Dropout2d (class in mlx.nn) Dropout3d (class in mlx.nn) Dtype (class in mlx.core) dtype (mlx.core.array property) E Embedding (class in mlx.nn) equal() (in module mlx.core) erf() (in module mlx.core) erfinv() (in module mlx.core) eval() (in module mlx.core) (mlx.nn.Module method) exp() (in module mlx.core) (mlx.core.array method) expand_dims() (in module mlx.core) eye() (in module mlx.core) F fft() (in module mlx.core.fft) fft2() (in module mlx.core.fft) fftn() (in module mlx.core.fft) filter_and_map() (mlx.nn.Module method) flatten() (in module mlx.core) floor() (in module mlx.core) floor_divide() (in module mlx.core) freeze() (mlx.nn.Module method) full() (in module mlx.core) G GELU (class in mlx.nn) gelu (class in mlx.nn) gelu_approx (class in mlx.nn) gelu_fast_approx (class in mlx.nn) grad() (in module mlx.core) greater() (in module mlx.core) greater_equal() (in module mlx.core) GroupNorm (class in mlx.nn) gumbel() (in module mlx.core.random) H hinge_loss (class in mlx.nn.losses) huber_loss (class in mlx.nn.losses) I identity() (in module mlx.core) ifft() (in module mlx.core.fft) ifft2() (in module mlx.core.fft) ifftn() (in module mlx.core.fft) inner() (in module mlx.core) InstanceNorm (class in mlx.nn) irfft() (in module mlx.core.fft) irfft2() (in module mlx.core.fft) irfftn() (in module mlx.core.fft) item() (mlx.core.array method) J jvp() (in module mlx.core) K key() (in module mlx.core.random) kl_div_loss (class in mlx.nn.losses) L l1_loss (class in mlx.nn.losses) LayerNorm (class in mlx.nn) leaf_modules() (mlx.nn.Module method) less() (in module mlx.core) less_equal() (in module mlx.core) Linear (class in mlx.nn) linspace() (in module mlx.core) Lion (class in mlx.optimizers) load() (in module mlx.core) load_weights() (mlx.nn.Module method) log() (in module mlx.core) (mlx.core.array method) log10() (in module mlx.core) log1p() (in module mlx.core) (mlx.core.array method) log2() (in module mlx.core) log_cosh_loss (class in mlx.nn.losses) logaddexp() (in module mlx.core) logical_and() (in module mlx.core) logical_not() (in module mlx.core) logical_or() (in module mlx.core) logsumexp() (in module mlx.core) (mlx.core.array method) M matmul() (in module mlx.core) max() (in module mlx.core) (mlx.core.array method) maximum() (in module mlx.core) mean() (in module mlx.core) (mlx.core.array method) min() (in module mlx.core) (mlx.core.array method) minimum() (in module mlx.core) Mish (class in mlx.nn) mish (class in mlx.nn) Module (class in mlx.nn) modules() (mlx.nn.Module method) moveaxis() (in module mlx.core) mse_loss (class in mlx.nn.losses) MultiHeadAttention (class in mlx.nn) multiply() (in module mlx.core) N named_modules() (mlx.nn.Module method) ndim (mlx.core.array property) negative() (in module mlx.core) new_stream() (in module mlx.core) nll_loss (class in mlx.nn.losses) norm() (in module mlx.core.linalg) normal() (in module mlx.core.random) O ones() (in module mlx.core) ones_like() (in module mlx.core) Optimizer (class in mlx.optimizers) OptimizerState (class in mlx.optimizers) outer() (in module mlx.core) P pad() (in module mlx.core) parameters() (mlx.nn.Module method) partition() (in module mlx.core) PReLU (class in mlx.nn) prelu (class in mlx.nn) prod() (in module mlx.core) (mlx.core.array method) Q quantize() (in module mlx.core) quantized_matmul() (in module mlx.core) QuantizedLinear (class in mlx.nn) R randint() (in module mlx.core.random) reciprocal() (in module mlx.core) (mlx.core.array method) ReLU (class in mlx.nn) relu (class in mlx.nn) repeat() (in module mlx.core) reshape() (in module mlx.core) (mlx.core.array method) rfft() (in module mlx.core.fft) rfft2() (in module mlx.core.fft) rfftn() (in module mlx.core.fft) RMSNorm (class in mlx.nn) RMSprop (class in mlx.optimizers) RoPE (class in mlx.nn) round() (in module mlx.core) (mlx.core.array method) rsqrt() (in module mlx.core) (mlx.core.array method) S save() (in module mlx.core) save_gguf() (in module mlx.core) save_safetensors() (in module mlx.core) save_weights() (mlx.nn.Module method) savez() (in module mlx.core) savez_compressed() (in module mlx.core) seed() (in module mlx.core.random) SELU (class in mlx.nn) selu (class in mlx.nn) Sequential (class in mlx.nn) set_default_device() (in module mlx.core) set_default_stream() (in module mlx.core) SGD (class in mlx.optimizers) shape (mlx.core.array property) sigmoid() (in module mlx.core) sign() (in module mlx.core) SiLU (class in mlx.nn) silu (class in mlx.nn) simplify() (in module mlx.core) sin() (in module mlx.core) (mlx.core.array method) sinh() (in module mlx.core) SinusoidalPositionalEncoding (class in mlx.nn) size (mlx.core.array property) smooth_l1_loss (class in mlx.nn.losses) softmax() (in module mlx.core) sort() (in module mlx.core) split() (in module mlx.core) (in module mlx.core.random) (mlx.core.array method) sqrt() (in module mlx.core) (mlx.core.array method) square() (in module mlx.core) (mlx.core.array method) squeeze() (in module mlx.core) stack() (in module mlx.core) state (mlx.optimizers.Optimizer attribute) Step (class in mlx.nn) step (class in mlx.nn) stop_gradient() (in module mlx.core) Stream (class in mlx.core) subtract() (in module mlx.core) sum() (in module mlx.core) (mlx.core.array method) swapaxes() (in module mlx.core) T T (mlx.core.array property) take() (in module mlx.core) take_along_axis() (in module mlx.core) tan() (in module mlx.core) tanh() (in module mlx.core) tensordot() (in module mlx.core) tolist() (mlx.core.array method) train() (mlx.nn.Module method) trainable_parameters() (mlx.nn.Module method) training (mlx.nn.Module property) Transformer (class in mlx.nn) transpose() (in module mlx.core) (mlx.core.array method) tree_flatten() (in module mlx.utils) tree_map() (in module mlx.utils) tree_unflatten() (in module mlx.utils) tri() (in module mlx.core) tril() (in module mlx.core) triplet_loss (class in mlx.nn.losses) triu() (in module mlx.core) truncated_normal() (in module mlx.core.random) U unfreeze() (mlx.nn.Module method) uniform() (in module mlx.core.random) update() (mlx.nn.Module method) update_modules() (mlx.nn.Module method) V value_and_grad() (in module mlx.core) (in module mlx.nn) var() (in module mlx.core) (mlx.core.array method) vjp() (in module mlx.core) vmap() (in module mlx.core) W where() (in module mlx.core) Z zeros() (in module mlx.core) zeros_like() (in module mlx.core)