Index _ | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | Q | R | S | T | U | V | W | Z _ __init__() (array method) (Device method) (Dtype method) (DtypeCategory method) (Stream method) A abs() (array method) (in module mlx.core) AdaDelta (class in mlx.optimizers) Adafactor (class in mlx.optimizers) Adagrad (class in mlx.optimizers) Adam (class in mlx.optimizers) Adamax (class in mlx.optimizers) AdamW (class in mlx.optimizers) add() (in module mlx.core) ALiBi (class in mlx.nn) all() (array method) (in module mlx.core) allclose() (in module mlx.core) any() (array method) (in module mlx.core) apply() (Module method) apply_gradients() (Optimizer method) apply_to_modules() (Module method) arange() (in module mlx.core) arccos() (in module mlx.core) arccosh() (in module mlx.core) arcsin() (in module mlx.core) arcsinh() (in module mlx.core) arctan() (in module mlx.core) arctanh() (in module mlx.core) argmax() (array method) (in module mlx.core) argmin() (array method) (in module mlx.core) argpartition() (in module mlx.core) argsort() (in module mlx.core) array (class in mlx.core) array_equal() (in module mlx.core) astype() (array method) at (array property) atleast_1d() (in module mlx.core) atleast_2d() (in module mlx.core) atleast_3d() (in module mlx.core) AvgPool1d (class in mlx.nn) AvgPool2d (class in mlx.nn) B BatchNorm (class in mlx.nn) bernoulli() (in module mlx.core.random) binary_cross_entropy() (in module mlx.nn.losses) broadcast_to() (in module mlx.core) C categorical() (in module mlx.core.random) ceil() (in module mlx.core) children() (Module method) clip() (in module mlx.core) compile() (in module mlx.core) concatenate() (in module mlx.core) constant() (in module mlx.nn.init) Conv1d (class in mlx.nn) conv1d() (in module mlx.core) Conv2d (class in mlx.nn) conv2d() (in module mlx.core) conv_general() (in module mlx.core) convolve() (in module mlx.core) cos() (array method) (in module mlx.core) cosh() (in module mlx.core) cosine_decay() (in module mlx.optimizers) cosine_similarity_loss() (in module mlx.nn.losses) cross_entropy() (in module mlx.nn.losses) cummax() (array method) (in module mlx.core) cummin() (array method) (in module mlx.core) cumprod() (array method) (in module mlx.core) cumsum() (array method) (in module mlx.core) D default_device() (in module mlx.core) default_stream() (in module mlx.core) dequantize() (in module mlx.core) Device (class in mlx.core) diag() (array method) (in module mlx.core) diagonal() (array method) (in module mlx.core) disable_compile() (in module mlx.core) divide() (in module mlx.core) divmod() (in module mlx.core) Dropout (class in mlx.nn) Dropout2d (class in mlx.nn) Dropout3d (class in mlx.nn) dtype (array property) Dtype (class in mlx.core) DtypeCategory (class in mlx.core) E elu() (in module mlx.nn) Embedding (class in mlx.nn) enable_compile() (in module mlx.core) equal() (in module mlx.core) erf() (in module mlx.core) erfinv() (in module mlx.core) eval() (in module mlx.core) (Module method) exp() (array method) (in module mlx.core) expand_dims() (in module mlx.core) exponential_decay() (in module mlx.optimizers) eye() (in module mlx.core) F fft() (in module mlx.core.fft) fft2() (in module mlx.core.fft) fftn() (in module mlx.core.fft) filter_and_map() (Module method) flatten() (array method) (in module mlx.core) floor() (in module mlx.core) floor_divide() (in module mlx.core) freeze() (Module method) full() (in module mlx.core) G gaussian_nll_loss() (in module mlx.nn.losses) GELU (class in mlx.nn) gelu() (in module mlx.nn) gelu_approx() (in module mlx.nn) gelu_fast_approx() (in module mlx.nn) get_active_memory() (in module mlx.core.metal) get_cache_memory() (in module mlx.core.metal) get_peak_memory() (in module mlx.core.metal) glorot_normal() (in module mlx.nn.init) glorot_uniform() (in module mlx.nn.init) glu() (in module mlx.nn) grad() (in module mlx.core) greater() (in module mlx.core) greater_equal() (in module mlx.core) GroupNorm (class in mlx.nn) GRU (class in mlx.nn) gumbel() (in module mlx.core.random) H hardswish() (in module mlx.nn) he_normal() (in module mlx.nn.init) he_uniform() (in module mlx.nn.init) hinge_loss() (in module mlx.nn.losses) huber_loss() (in module mlx.nn.losses) I identity() (in module mlx.core) (in module mlx.nn.init) ifft() (in module mlx.core.fft) ifft2() (in module mlx.core.fft) ifftn() (in module mlx.core.fft) init() (Optimizer method) inner() (in module mlx.core) InstanceNorm (class in mlx.nn) irfft() (in module mlx.core.fft) irfft2() (in module mlx.core.fft) irfftn() (in module mlx.core.fft) is_available() (in module mlx.core.metal) isclose() (in module mlx.core) isinf() (in module mlx.core) isnan() (in module mlx.core) isneginf() (in module mlx.core) isposinf() (in module mlx.core) issubdtype() (in module mlx.core) item() (array method) itemsize (array property) J join_schedules() (in module mlx.optimizers) jvp() (in module mlx.core) K key() (in module mlx.core.random) kl_div_loss() (in module mlx.nn.losses) L l1_loss() (in module mlx.nn.losses) layer_norm() (in module mlx.core.fast) LayerNorm (class in mlx.nn) leaf_modules() (Module method) leaky_relu() (in module mlx.nn) less() (in module mlx.core) less_equal() (in module mlx.core) Linear (class in mlx.nn) linear_schedule() (in module mlx.optimizers) linspace() (in module mlx.core) Lion (class in mlx.optimizers) load() (in module mlx.core) load_weights() (Module method) log() (array method) (in module mlx.core) log10() (array method) (in module mlx.core) log1p() (array method) (in module mlx.core) log2() (array method) (in module mlx.core) log_cosh_loss() (in module mlx.nn.losses) log_sigmoid() (in module mlx.nn) log_softmax() (in module mlx.nn) logaddexp() (in module mlx.core) logical_and() (in module mlx.core) logical_not() (in module mlx.core) logical_or() (in module mlx.core) logsumexp() (array method) (in module mlx.core) LSTM (class in mlx.nn) M margin_ranking_loss() (in module mlx.nn.losses) matmul() (in module mlx.core) max() (array method) (in module mlx.core) maximum() (in module mlx.core) MaxPool1d (class in mlx.nn) MaxPool2d (class in mlx.nn) mean() (array method) (in module mlx.core) min() (array method) (in module mlx.core) minimum() (in module mlx.core) Mish (class in mlx.nn) mish() (in module mlx.nn) Module (class in mlx.nn) modules() (Module method) moveaxis() (array method) (in module mlx.core) mse_loss() (in module mlx.nn.losses) MultiHeadAttention (class in mlx.nn) multiply() (in module mlx.core) N named_modules() (Module method) nbytes (array property) ndim (array property) negative() (in module mlx.core) new_stream() (in module mlx.core) nll_loss() (in module mlx.nn.losses) norm() (in module mlx.core.linalg) normal() (in module mlx.core.random) (in module mlx.nn.init) O ones() (in module mlx.core) ones_like() (in module mlx.core) Optimizer (class in mlx.optimizers) outer() (in module mlx.core) P pad() (in module mlx.core) parameters() (Module method) partition() (in module mlx.core) PReLU (class in mlx.nn) prelu() (in module mlx.nn) prod() (array method) (in module mlx.core) Q qr() (in module mlx.core.linalg) quantize() (in module mlx.core) quantized_matmul() (in module mlx.core) QuantizedLinear (class in mlx.nn) R randint() (in module mlx.core.random) reciprocal() (array method) (in module mlx.core) ReLU (class in mlx.nn) relu() (in module mlx.nn) relu6() (in module mlx.nn) repeat() (in module mlx.core) reshape() (array method) (in module mlx.core) rfft() (in module mlx.core.fft) rfft2() (in module mlx.core.fft) rfftn() (in module mlx.core.fft) rms_norm() (in module mlx.core.fast) RMSNorm (class in mlx.nn) RMSprop (class in mlx.optimizers) RNN (class in mlx.nn) RoPE (class in mlx.nn) rope() (in module mlx.core.fast) round() (array method) (in module mlx.core) rsqrt() (array method) (in module mlx.core) S save() (in module mlx.core) save_gguf() (in module mlx.core) save_safetensors() (in module mlx.core) save_weights() (Module method) savez() (in module mlx.core) savez_compressed() (in module mlx.core) scaled_dot_product_attention() (in module mlx.core.fast) seed() (in module mlx.core.random) SELU (class in mlx.nn) selu() (in module mlx.nn) Sequential (class in mlx.nn) set_cache_limit() (in module mlx.core.metal) set_default_device() (in module mlx.core) set_default_stream() (in module mlx.core) set_dtype() (Module method) set_memory_limit() (in module mlx.core.metal) SGD (class in mlx.optimizers) shape (array property) sigmoid() (in module mlx.core) (in module mlx.nn) sign() (in module mlx.core) SiLU (class in mlx.nn) silu() (in module mlx.nn) sin() (array method) (in module mlx.core) sinh() (in module mlx.core) SinusoidalPositionalEncoding (class in mlx.nn) size (array property) smooth_l1_loss() (in module mlx.nn.losses) softmax() (in module mlx.core) (in module mlx.nn) softplus() (in module mlx.nn) Softshrink (class in mlx.nn) softshrink() (in module mlx.nn) sort() (in module mlx.core) split() (array method) (in module mlx.core) (in module mlx.core.random) sqrt() (array method) (in module mlx.core) square() (array method) (in module mlx.core) squeeze() (array method) (in module mlx.core) stack() (in module mlx.core) state (Module property) (Optimizer property) Step (class in mlx.nn) step() (in module mlx.nn) step_decay() (in module mlx.optimizers) stop_gradient() (in module mlx.core) Stream (class in mlx.core) stream() (in module mlx.core) subtract() (in module mlx.core) sum() (array method) (in module mlx.core) swapaxes() (array method) (in module mlx.core) T T (array property) take() (in module mlx.core) take_along_axis() (in module mlx.core) tan() (in module mlx.core) tanh() (in module mlx.core) (in module mlx.nn) tensordot() (in module mlx.core) tile() (in module mlx.core) tolist() (array method) topk() (in module mlx.core) train() (Module method) trainable_parameters() (Module method) training (Module property) Transformer (class in mlx.nn) transpose() (array method) (in module mlx.core) tree_flatten() (in module mlx.utils) tree_map() (in module mlx.utils) tree_unflatten() (in module mlx.utils) tri() (in module mlx.core) tril() (in module mlx.core) triplet_loss() (in module mlx.nn.losses) triu() (in module mlx.core) truncated_normal() (in module mlx.core.random) U unfreeze() (Module method) uniform() (in module mlx.core.random) (in module mlx.nn.init) update() (Module method) (Optimizer method) update_modules() (Module method) Upsample (class in mlx.nn) V value_and_grad() (in module mlx.core) (in module mlx.nn) var() (array method) (in module mlx.core) vjp() (in module mlx.core) vmap() (in module mlx.core) W where() (in module mlx.core) Z zeros() (in module mlx.core) zeros_like() (in module mlx.core)