Index _ | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | Q | R | S | T | U | V | W | Z _ __init__() (mlx.core.array method) (mlx.core.Device method) (mlx.core.Dtype method) (mlx.core.Stream method) A abs() (in module mlx.core) (mlx.core.array method) AdaDelta (class in mlx.optimizers) Adafactor (class in mlx.optimizers) Adagrad (class in mlx.optimizers) Adam (class in mlx.optimizers) Adamax (class in mlx.optimizers) AdamW (class in mlx.optimizers) add() (in module mlx.core) ALiBi (class in mlx.nn) all() (in module mlx.core) (mlx.core.array method) allclose() (in module mlx.core) any() (in module mlx.core) (mlx.core.array method) apply() (mlx.nn.Module method) apply_gradients() (mlx.optimizers.Optimizer method) apply_to_modules() (mlx.nn.Module method) arange() (in module mlx.core) arccos() (in module mlx.core) arccosh() (in module mlx.core) arcsin() (in module mlx.core) arcsinh() (in module mlx.core) arctan() (in module mlx.core) arctanh() (in module mlx.core) argmax() (in module mlx.core) (mlx.core.array method) argmin() (in module mlx.core) (mlx.core.array method) argpartition() (in module mlx.core) argsort() (in module mlx.core) array (class in mlx.core) array_equal() (in module mlx.core) astype() (mlx.core.array method) atleast_1d() (in module mlx.core) atleast_2d() (in module mlx.core) atleast_3d() (in module mlx.core) AvgPool1d (class in mlx.nn) AvgPool2d (class in mlx.nn) B BatchNorm (class in mlx.nn) bernoulli() (in module mlx.core.random) binary_cross_entropy() (in module mlx.nn.losses) broadcast_to() (in module mlx.core) C categorical() (in module mlx.core.random) ceil() (in module mlx.core) children() (mlx.nn.Module method) clip() (in module mlx.core) compile() (in module mlx.core) concatenate() (in module mlx.core) constant() (in module mlx.nn.init) Conv1d (class in mlx.nn) conv1d() (in module mlx.core) Conv2d (class in mlx.nn) conv2d() (in module mlx.core) conv_general() (in module mlx.core) convolve() (in module mlx.core) cos() (in module mlx.core) (mlx.core.array method) cosh() (in module mlx.core) cosine_decay() (in module mlx.optimizers) cosine_similarity_loss() (in module mlx.nn.losses) cross_entropy() (in module mlx.nn.losses) D default_device() (in module mlx.core) default_stream() (in module mlx.core) dequantize() (in module mlx.core) Device (class in mlx.core) diag() (in module mlx.core) diagonal() (in module mlx.core) disable_compile() (in module mlx.core) divide() (in module mlx.core) divmod() (in module mlx.core) Dropout (class in mlx.nn) Dropout2d (class in mlx.nn) Dropout3d (class in mlx.nn) Dtype (class in mlx.core) dtype (mlx.core.array property) E elu() (in module mlx.nn) Embedding (class in mlx.nn) enable_compile() (in module mlx.core) equal() (in module mlx.core) erf() (in module mlx.core) erfinv() (in module mlx.core) eval() (in module mlx.core) (mlx.nn.Module method) exp() (in module mlx.core) (mlx.core.array method) expand_dims() (in module mlx.core) exponential_decay() (in module mlx.optimizers) eye() (in module mlx.core) F fft() (in module mlx.core.fft) fft2() (in module mlx.core.fft) fftn() (in module mlx.core.fft) filter_and_map() (mlx.nn.Module method) flatten() (in module mlx.core) floor() (in module mlx.core) floor_divide() (in module mlx.core) freeze() (mlx.nn.Module method) full() (in module mlx.core) G gaussian_nll_loss() (in module mlx.nn.losses) GELU (class in mlx.nn) gelu() (in module mlx.nn) gelu_approx() (in module mlx.nn) gelu_fast_approx() (in module mlx.nn) get_active_memory() (in module mlx.core.metal) get_cache_memory() (in module mlx.core.metal) get_peak_memory() (in module mlx.core.metal) glorot_normal() (in module mlx.nn.init) glorot_uniform() (in module mlx.nn.init) glu() (in module mlx.nn) grad() (in module mlx.core) greater() (in module mlx.core) greater_equal() (in module mlx.core) GroupNorm (class in mlx.nn) GRU (class in mlx.nn) gumbel() (in module mlx.core.random) H hardswish() (in module mlx.nn) he_normal() (in module mlx.nn.init) he_uniform() (in module mlx.nn.init) hinge_loss() (in module mlx.nn.losses) huber_loss() (in module mlx.nn.losses) I identity() (in module mlx.core) (in module mlx.nn.init) ifft() (in module mlx.core.fft) ifft2() (in module mlx.core.fft) ifftn() (in module mlx.core.fft) init() (mlx.optimizers.Optimizer method) inner() (in module mlx.core) InstanceNorm (class in mlx.nn) irfft() (in module mlx.core.fft) irfft2() (in module mlx.core.fft) irfftn() (in module mlx.core.fft) is_available() (in module mlx.core.metal) isclose() (in module mlx.core) isinf() (in module mlx.core) isnan() (in module mlx.core) isneginf() (in module mlx.core) isposinf() (in module mlx.core) item() (mlx.core.array method) J join_schedules() (in module mlx.optimizers) jvp() (in module mlx.core) K key() (in module mlx.core.random) kl_div_loss() (in module mlx.nn.losses) L l1_loss() (in module mlx.nn.losses) LayerNorm (class in mlx.nn) leaf_modules() (mlx.nn.Module method) leaky_relu() (in module mlx.nn) less() (in module mlx.core) less_equal() (in module mlx.core) Linear (class in mlx.nn) linear_schedule() (in module mlx.optimizers) linspace() (in module mlx.core) Lion (class in mlx.optimizers) load() (in module mlx.core) load_weights() (mlx.nn.Module method) log() (in module mlx.core) (mlx.core.array method) log10() (in module mlx.core) log1p() (in module mlx.core) (mlx.core.array method) log2() (in module mlx.core) log_cosh_loss() (in module mlx.nn.losses) log_sigmoid() (in module mlx.nn) log_softmax() (in module mlx.nn) logaddexp() (in module mlx.core) logical_and() (in module mlx.core) logical_not() (in module mlx.core) logical_or() (in module mlx.core) logsumexp() (in module mlx.core) (mlx.core.array method) LSTM (class in mlx.nn) M margin_ranking_loss() (in module mlx.nn.losses) matmul() (in module mlx.core) max() (in module mlx.core) (mlx.core.array method) maximum() (in module mlx.core) MaxPool1d (class in mlx.nn) MaxPool2d (class in mlx.nn) mean() (in module mlx.core) (mlx.core.array method) min() (in module mlx.core) (mlx.core.array method) minimum() (in module mlx.core) Mish (class in mlx.nn) mish() (in module mlx.nn) Module (class in mlx.nn) modules() (mlx.nn.Module method) moveaxis() (in module mlx.core) mse_loss() (in module mlx.nn.losses) MultiHeadAttention (class in mlx.nn) multiply() (in module mlx.core) N named_modules() (mlx.nn.Module method) ndim (mlx.core.array property) negative() (in module mlx.core) new_stream() (in module mlx.core) nll_loss() (in module mlx.nn.losses) norm() (in module mlx.core.linalg) normal() (in module mlx.core.random) (in module mlx.nn.init) O ones() (in module mlx.core) ones_like() (in module mlx.core) Optimizer (class in mlx.optimizers) outer() (in module mlx.core) P pad() (in module mlx.core) parameters() (mlx.nn.Module method) partition() (in module mlx.core) PReLU (class in mlx.nn) prelu() (in module mlx.nn) prod() (in module mlx.core) (mlx.core.array method) Q qr() (in module mlx.core.linalg) quantize() (in module mlx.core) quantized_matmul() (in module mlx.core) QuantizedLinear (class in mlx.nn) R randint() (in module mlx.core.random) reciprocal() (in module mlx.core) (mlx.core.array method) ReLU (class in mlx.nn) relu() (in module mlx.nn) relu6() (in module mlx.nn) repeat() (in module mlx.core) reshape() (in module mlx.core) (mlx.core.array method) rfft() (in module mlx.core.fft) rfft2() (in module mlx.core.fft) rfftn() (in module mlx.core.fft) RMSNorm (class in mlx.nn) RMSprop (class in mlx.optimizers) RNN (class in mlx.nn) RoPE (class in mlx.nn) round() (in module mlx.core) (mlx.core.array method) rsqrt() (in module mlx.core) (mlx.core.array method) S save() (in module mlx.core) save_gguf() (in module mlx.core) save_safetensors() (in module mlx.core) save_weights() (mlx.nn.Module method) savez() (in module mlx.core) savez_compressed() (in module mlx.core) seed() (in module mlx.core.random) SELU (class in mlx.nn) selu() (in module mlx.nn) Sequential (class in mlx.nn) set_cache_limit() (in module mlx.core.metal) set_default_device() (in module mlx.core) set_default_stream() (in module mlx.core) set_memory_limit() (in module mlx.core.metal) SGD (class in mlx.optimizers) shape (mlx.core.array property) sigmoid() (in module mlx.core) (in module mlx.nn) sign() (in module mlx.core) SiLU (class in mlx.nn) silu() (in module mlx.nn) sin() (in module mlx.core) (mlx.core.array method) sinh() (in module mlx.core) SinusoidalPositionalEncoding (class in mlx.nn) size (mlx.core.array property) smooth_l1_loss() (in module mlx.nn.losses) softmax() (in module mlx.core) (in module mlx.nn) softplus() (in module mlx.nn) Softshrink (class in mlx.nn) softshrink() (in module mlx.nn) sort() (in module mlx.core) split() (in module mlx.core) (in module mlx.core.random) (mlx.core.array method) sqrt() (in module mlx.core) (mlx.core.array method) square() (in module mlx.core) (mlx.core.array method) squeeze() (in module mlx.core) stack() (in module mlx.core) state (mlx.nn.Module property) (mlx.optimizers.Optimizer property) Step (class in mlx.nn) step() (in module mlx.nn) step_decay() (in module mlx.optimizers) stop_gradient() (in module mlx.core) Stream (class in mlx.core) stream() (in module mlx.core) subtract() (in module mlx.core) sum() (in module mlx.core) (mlx.core.array method) swapaxes() (in module mlx.core) T T (mlx.core.array property) take() (in module mlx.core) take_along_axis() (in module mlx.core) tan() (in module mlx.core) tanh() (in module mlx.core) (in module mlx.nn) tensordot() (in module mlx.core) tile() (in module mlx.core) tolist() (mlx.core.array method) topk() (in module mlx.core) train() (mlx.nn.Module method) trainable_parameters() (mlx.nn.Module method) training (mlx.nn.Module property) Transformer (class in mlx.nn) transpose() (in module mlx.core) (mlx.core.array method) tree_flatten() (in module mlx.utils) tree_map() (in module mlx.utils) tree_unflatten() (in module mlx.utils) tri() (in module mlx.core) tril() (in module mlx.core) triplet_loss() (in module mlx.nn.losses) triu() (in module mlx.core) truncated_normal() (in module mlx.core.random) U unfreeze() (mlx.nn.Module method) uniform() (in module mlx.core.random) (in module mlx.nn.init) update() (mlx.nn.Module method) (mlx.optimizers.Optimizer method) update_modules() (mlx.nn.Module method) Upsample (class in mlx.nn) V value_and_grad() (in module mlx.core) (in module mlx.nn) var() (in module mlx.core) (mlx.core.array method) vjp() (in module mlx.core) vmap() (in module mlx.core) W where() (in module mlx.core) Z zeros() (in module mlx.core) zeros_like() (in module mlx.core)