mlx/mlx/fast.h
Awni Hannun bb1b76d9dc
RoPE with frequencies as optional input (#1337)
* start rope with freq input

* rope with frequencies

* nits

* fix bug

* fix bug + test

* cleanup

* optional base
2024-08-19 18:30:50 -07:00

67 lines
1.4 KiB
C++

// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <optional>
#include <tuple>

#include "mlx/utils.h"
namespace mlx::core::fast {
/**
 * Declares the fast `rms_norm` operation; the definition lives outside this
 * header.
 *
 * NOTE(review): judging by the name, this presumably applies
 * root-mean-square normalization to `x` -- confirm against the definition.
 *
 * @param x Input array.
 * @param weight Weight array. It is required here, whereas `layer_norm`
 *     takes its weight as an optional.
 * @param eps Floating-point epsilon; presumably a numerical-stability term
 *     -- TODO confirm.
 * @param s Stream or device argument; defaults to `{}`.
 * @return The result as an `array`, returned by value.
 */
array rms_norm(
const array& x,
const array& weight,
float eps,
StreamOrDevice s = {});
/**
 * Declares the fast `layer_norm` operation; the definition lives outside
 * this header.
 *
 * NOTE(review): judging by the name, this presumably applies layer
 * normalization to `x` -- confirm against the definition.
 *
 * @param x Input array.
 * @param weight Optional weight array. There is no default, so callers must
 *     pass it explicitly (`std::nullopt` to omit).
 * @param bias Optional bias array. There is no default, so callers must
 *     pass it explicitly (`std::nullopt` to omit).
 * @param eps Floating-point epsilon; presumably a numerical-stability term
 *     -- TODO confirm.
 * @param s Stream or device argument; defaults to `{}`.
 * @return The result as an `array`, returned by value.
 */
array layer_norm(
const array& x,
const std::optional<array>& weight,
const std::optional<array>& bias,
float eps,
StreamOrDevice s = {});
/**
 * Declares the fast `rope` (RoPE) operation; the definition lives outside
 * this header.
 *
 * Both `base` and `freqs` are optionals. NOTE(review): presumably exactly
 * one of them is expected to be provided, with `freqs` supplying the
 * frequencies directly in place of ones derived from `base` -- confirm
 * against the definition.
 *
 * @param x Input array.
 * @param dims Integer dimension count; presumably the number of feature
 *     dimensions that are rotated -- TODO confirm.
 * @param traditional Boolean flag selecting a "traditional" variant; its
 *     exact meaning is defined by the implementation.
 * @param base Optional floating-point base. There is no default, so callers
 *     must pass it explicitly (`std::nullopt` to omit).
 * @param scale Floating-point scale; presumably applied to positions --
 *     TODO confirm.
 * @param offset Integer offset; presumably a starting position -- TODO
 *     confirm.
 * @param freqs Optional array of frequencies; defaults to `std::nullopt`.
 * @param s Stream or device argument; defaults to `{}`.
 * @return The result as an `array`, returned by value.
 */
array rope(
const array& x,
int dims,
bool traditional,
std::optional<float> base,
float scale,
int offset,
const std::optional<array>& freqs = std::nullopt,
StreamOrDevice s = {});
/**
 * Computes: O = softmax(Q @ K.T) @ V
 *
 * The signature also accepts a `scale` and an optional `mask`, which the
 * formula above does not show. NOTE(review): presumably the full
 * computation is `softmax(scale * Q @ K.T + mask) @ V` -- confirm against
 * the definition.
 *
 * @param queries Query array (Q).
 * @param keys Key array (K).
 * @param values Value array (V).
 * @param scale Floating-point scale factor. Top-level `const` on a by-value
 *     parameter has no effect in a declaration, so it is omitted here to
 *     match the other declarations in this header.
 * @param mask Optional mask array; defaults to `std::nullopt`.
 * @param memory_efficient_threshold Optional integer threshold; defaults to
 *     `std::nullopt`. Presumably controls when a memory-efficient code path
 *     is chosen -- TODO confirm.
 * @param s Stream or device argument; defaults to `{}`.
 * @return The attention output as an `array`, returned by value.
 */
array scaled_dot_product_attention(
const array& queries,
const array& keys,
const array& values,
float scale,
const std::optional<array>& mask = std::nullopt,
const std::optional<int>& memory_efficient_threshold = std::nullopt,
StreamOrDevice s = {});
/**
 * Declares the `affine_quantize` overload that takes only `w` and returns
 * three arrays; the definition lives outside this header.
 *
 * NOTE(review): given the sibling overload that accepts `scales` and
 * `biases`, the returned tuple is presumably (quantized `w`, scales, biases)
 * -- confirm the element order against the definition.
 *
 * @param w Input array to quantize.
 * @param group_size Integer group size; defaults to 64.
 * @param bits Integer bit width; defaults to 4.
 * @param s Stream or device argument; defaults to `{}`.
 * @return A tuple of three `array` values.
 */
std::tuple<array, array, array> affine_quantize(
const array& w,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {});
/**
 * Declares the `affine_quantize` overload that takes caller-supplied
 * `scales` and `biases` and returns a single array; the definition lives
 * outside this header.
 *
 * NOTE(review): presumably quantizes `w` using the given `scales` and
 * `biases` instead of computing them -- confirm against the definition.
 *
 * @param w Input array to quantize.
 * @param scales Array of scales.
 * @param biases Array of biases.
 * @param group_size Integer group size; defaults to 64.
 * @param bits Integer bit width; defaults to 4.
 * @param s Stream or device argument; defaults to `{}`.
 * @return The result as an `array`, returned by value.
 */
array affine_quantize(
const array& w,
const array& scales,
const array& biases,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {});
/**
 * Declares the `affine_dequantize` operation; the definition lives outside
 * this header. Its parameter list mirrors the array-returning
 * `affine_quantize` overload.
 *
 * NOTE(review): presumably reconstructs an array from quantized `w` using
 * `scales` and `biases` -- confirm against the definition.
 *
 * @param w Input array; presumably the quantized values -- TODO confirm.
 * @param scales Array of scales.
 * @param biases Array of biases.
 * @param group_size Integer group size; defaults to 64.
 * @param bits Integer bit width; defaults to 4.
 * @param s Stream or device argument; defaults to `{}`.
 * @return The result as an `array`, returned by value.
 */
array affine_dequantize(
const array& w,
const array& scales,
const array& biases,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {});
} // namespace mlx::core::fast