Separate fast ops and primitives (#699)

This commit is contained in:
Awni Hannun 2024-02-16 19:16:39 -08:00 committed by GitHub
parent bf7cd29970
commit c3965fc5ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 74 additions and 70 deletions

View File

@ -1,7 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/fast.h"
#include "mlx/primitives.h"
#include "mlx/fast_primitives.h"
namespace mlx::core::fast {

View File

@ -1,8 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/utils.h"
#include "mlx/fast.h"
#include "mlx/primitives.h"
#include "mlx/fast_primitives.h"
namespace mlx::core::fast {

View File

@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/primitives.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \

View File

@ -1,6 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include "mlx/ops.h"
#include "mlx/transforms.h"
namespace mlx::core::fast {

View File

@ -2,40 +2,10 @@
#pragma once
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core::fast {
// Custom primitive accepts a fallback function which it uses for
// transformations. Transformations are virtual so that derived classes may to
// override the default behavior
class Custom : public Primitive {
public:
explicit Custom(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback)
: Primitive(stream), fallback_(fallback){};
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
virtual std::vector<array> jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
virtual std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
};
array rope(
const array& x,
int dims,
@ -45,38 +15,4 @@ array rope(
int offset,
StreamOrDevice s /* = {} */);
class RoPE : public Custom {
public:
RoPE(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
int dims,
bool traditional,
float base,
float scale,
int offset)
: Custom(stream, fallback),
dims_(dims),
traditional_(traditional),
base_(base),
scale_(scale),
offset_(offset){};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(RoPE)
bool is_equivalent(const Primitive& other) const override;
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
int dims_;
bool traditional_;
float base_;
float scale_;
int offset_;
};
} // namespace mlx::core::fast

68
mlx/fast_primitives.h Normal file
View File

@ -0,0 +1,68 @@
#include "mlx/primitives.h"
namespace mlx::core::fast {
// Custom primitive accepts a fallback function which it uses for
// transformations. Transformations are virtual so that derived classes may to
// override the default behavior
class Custom : public Primitive {
public:
explicit Custom(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback)
: Primitive(stream), fallback_(fallback){};
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
virtual std::vector<array> jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
virtual std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
};
class RoPE : public Custom {
public:
RoPE(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
int dims,
bool traditional,
float base,
float scale,
int offset)
: Custom(stream, fallback),
dims_(dims),
traditional_(traditional),
base_(base),
scale_(scale),
offset_(offset){};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(RoPE)
bool is_equivalent(const Primitive& other) const override;
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
int dims_;
bool traditional_;
float base_;
float scale_;
int offset_;
};
} // namespace mlx::core::fast