mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
* add sdpa with sinks * fix 2 pass * fix matrix sdpa * fix perf regression * add to cuda (#2580)
103 lines
2.7 KiB
C++
103 lines
2.7 KiB
C++
// Copyright © 2023-2024 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include <functional>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>

#include "mlx/utils.h"
|
|
|
|
namespace mlx::core::fast {
|
|
|
|
/**
 * Declaration of the fast `rms_norm` operation; the definition is not part of
 * this header.
 *
 * NOTE(review): by its name this is root-mean-square normalization of `x`
 * with `eps` as the numerical-stability term -- confirm against the
 * implementation.
 *
 * @param x      Input array.
 * @param weight Optional array; pass `std::nullopt` to omit it.
 * @param eps    Floating point epsilon forwarded to the operation.
 * @param s      Stream or device to run on; defaults to `{}`.
 * @return The resulting array.
 */
array rms_norm(
    const array& x,
    const std::optional<array>& weight,
    float eps,
    StreamOrDevice s = {});
|
|
|
|
/**
 * Declaration of the fast `layer_norm` operation; the definition is not part
 * of this header.
 *
 * NOTE(review): by its name this is layer normalization with an optional
 * affine `weight` / `bias` -- confirm against the implementation.
 *
 * @param x      Input array.
 * @param weight Optional array; pass `std::nullopt` to omit it.
 * @param bias   Optional array; pass `std::nullopt` to omit it.
 * @param eps    Floating point epsilon forwarded to the operation.
 * @param s      Stream or device to run on; defaults to `{}`.
 * @return The resulting array.
 */
array layer_norm(
    const array& x,
    const std::optional<array>& weight,
    const std::optional<array>& bias,
    float eps,
    StreamOrDevice s = {});
|
|
|
|
/**
 * Declaration of the fast `rope` operation, overload taking the offset as a
 * plain `int`. A second overload in this header takes the offset as an
 * `array` instead.
 *
 * NOTE(review): by its name this applies rotary position embeddings to `x`;
 * the exact interplay of `base` and `freqs` is not visible here -- confirm
 * against the implementation.
 *
 * @param x           Input array.
 * @param dims        Integer dimension count forwarded to the operation.
 * @param traditional Boolean variant selector forwarded to the operation.
 * @param base        Optional float; may be empty.
 * @param scale       Float scale forwarded to the operation.
 * @param offset      Integer offset.
 * @param freqs       Optional array; defaults to `std::nullopt`.
 * @param s           Stream or device to run on; defaults to `{}`.
 * @return The resulting array.
 */
array rope(
    const array& x,
    int dims,
    bool traditional,
    std::optional<float> base,
    float scale,
    int offset,
    const std::optional<array>& freqs = std::nullopt,
    StreamOrDevice s = {});
|
|
|
|
/**
 * Declaration of the fast `rope` operation, overload taking the offset as an
 * `array`. All other parameters match the overload that takes an `int`
 * offset.
 *
 * NOTE(review): the expected shape / dtype of `offset` is not visible in
 * this header -- confirm against the implementation.
 *
 * @param x           Input array.
 * @param dims        Integer dimension count forwarded to the operation.
 * @param traditional Boolean variant selector forwarded to the operation.
 * @param base        Optional float; may be empty.
 * @param scale       Float scale forwarded to the operation.
 * @param offset      Offset supplied as an array.
 * @param freqs       Optional array; defaults to `std::nullopt`.
 * @param s           Stream or device to run on; defaults to `{}`.
 * @return The resulting array.
 */
array rope(
    const array& x,
    int dims,
    bool traditional,
    std::optional<float> base,
    float scale,
    const array& offset,
    const std::optional<array>& freqs = std::nullopt,
    StreamOrDevice s = {});
|
|
|
|
/** Computes: O = softmax(Q @ K.T) @ V **/
|
|
array scaled_dot_product_attention(
|
|
const array& queries,
|
|
const array& keys,
|
|
const array& values,
|
|
const float scale,
|
|
const std::string& mask_mode = "",
|
|
const std::vector<array>& mask_arrs = {},
|
|
const std::optional<array>& sinks = {},
|
|
StreamOrDevice s = {});
|
|
|
|
/**
 * Value paired with a string name in the template-argument list accepted by
 * `CustomKernelFunction`. Holds an `int`, a `bool` or a `Dtype`.
 * The order of alternatives is part of the type; do not reorder.
 */
using TemplateArg = std::variant<int, bool, Dtype>;

/**
 * Scalar value passed through the `scalars` list of
 * `precompiled_cuda_kernel`. Holds a `bool`, an `int` or a `float`.
 * The order of alternatives is part of the type; do not reorder.
 */
using ScalarArg = std::variant<bool, int, float>;
|
|
|
|
using CustomKernelFunction = std::function<std::vector<array>(
|
|
const std::vector<array>&,
|
|
const std::vector<Shape>&,
|
|
const std::vector<Dtype>&,
|
|
std::tuple<int, int, int>,
|
|
std::tuple<int, int, int>,
|
|
std::vector<std::pair<std::string, TemplateArg>>,
|
|
std::optional<float>,
|
|
bool,
|
|
StreamOrDevice)>;
|
|
|
|
/**
 * Builds a `CustomKernelFunction` from kernel source text.
 *
 * NOTE(review): by its name this targets the Metal backend; how `source` and
 * `header` are combined and compiled is not visible in this header --
 * confirm against the implementation.
 *
 * @param name                  Kernel name.
 * @param input_names           Names of the kernel inputs.
 * @param output_names          Names of the kernel outputs.
 * @param source                Kernel source text.
 * @param header                Additional header text; defaults to empty.
 * @param ensure_row_contiguous Defaults to `true`; presumably makes inputs
 *                              row-contiguous before launch -- confirm.
 * @param atomic_outputs        Defaults to `false`; presumably declares
 *                              outputs as atomic in the kernel -- confirm.
 * @return A callable that runs the kernel.
 */
CustomKernelFunction metal_kernel(
    const std::string& name,
    const std::vector<std::string>& input_names,
    const std::vector<std::string>& output_names,
    const std::string& source,
    const std::string& header = "",
    bool ensure_row_contiguous = true,
    bool atomic_outputs = false);
|
|
|
|
/**
 * Builds a `CustomKernelFunction` from kernel source text.
 *
 * NOTE(review): by its name this targets the CUDA backend; how `source` and
 * `header` are combined and compiled is not visible in this header --
 * confirm against the implementation.
 *
 * @param name                  Kernel name.
 * @param input_names           Names of the kernel inputs.
 * @param output_names          Names of the kernel outputs.
 * @param source                Kernel source text.
 * @param header                Additional header text; defaults to empty.
 * @param ensure_row_contiguous Defaults to `true`; presumably makes inputs
 *                              row-contiguous before launch -- confirm.
 * @param shared_memory         Defaults to `0`; presumably the amount of
 *                              shared memory for the launch (units not
 *                              visible here) -- confirm.
 * @return A callable that runs the kernel.
 */
CustomKernelFunction cuda_kernel(
    const std::string& name,
    const std::vector<std::string>& input_names,
    const std::vector<std::string>& output_names,
    const std::string& source,
    const std::string& header = "",
    bool ensure_row_contiguous = true,
    int shared_memory = 0);
|
|
|
|
std::vector<array> precompiled_cuda_kernel(
|
|
const std::string& name,
|
|
const std::string& compiled_source,
|
|
const std::vector<array>& inputs,
|
|
const std::vector<Shape>& output_shapes,
|
|
const std::vector<Dtype>& output_dtypes,
|
|
const std::vector<ScalarArg>& scalars,
|
|
std::tuple<int, int, int> grid,
|
|
std::tuple<int, int, int> threadgroup,
|
|
int shared_memory = 0,
|
|
std::optional<float> init_value = std::nullopt,
|
|
bool ensure_row_contiguous = false,
|
|
StreamOrDevice s = {});
|
|
|
|
} // namespace mlx::core::fast
|