mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
Separate fast ops and primitives (#699)
This commit is contained in:
parent
bf7cd29970
commit
c3965fc5ee
@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/fast.h"
|
#include "mlx/fast_primitives.h"
|
||||||
#include "mlx/primitives.h"
|
|
||||||
|
|
||||||
namespace mlx::core::fast {
|
namespace mlx::core::fast {
|
||||||
|
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
#include "mlx/fast.h"
|
#include "mlx/fast_primitives.h"
|
||||||
#include "mlx/primitives.h"
|
|
||||||
|
|
||||||
namespace mlx::core::fast {
|
namespace mlx::core::fast {
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/fast.h"
|
#include "mlx/fast_primitives.h"
|
||||||
|
|
||||||
#define NO_GPU_MULTI(func) \
|
#define NO_GPU_MULTI(func) \
|
||||||
void func::eval_gpu( \
|
void func::eval_gpu( \
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/fast.h"
|
#include "mlx/fast.h"
|
||||||
|
#include "mlx/fast_primitives.h"
|
||||||
|
#include "mlx/ops.h"
|
||||||
#include "mlx/transforms.h"
|
#include "mlx/transforms.h"
|
||||||
|
|
||||||
namespace mlx::core::fast {
|
namespace mlx::core::fast {
|
||||||
|
66
mlx/fast.h
66
mlx/fast.h
@ -2,40 +2,10 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/ops.h"
|
#include "mlx/utils.h"
|
||||||
#include "mlx/primitives.h"
|
|
||||||
|
|
||||||
namespace mlx::core::fast {
|
namespace mlx::core::fast {
|
||||||
|
|
||||||
// Custom primitive accepts a fallback function which it uses for
|
|
||||||
// transformations. Transformations are virtual so that derived classes may to
|
|
||||||
// override the default behavior
|
|
||||||
class Custom : public Primitive {
|
|
||||||
public:
|
|
||||||
explicit Custom(
|
|
||||||
Stream stream,
|
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback)
|
|
||||||
: Primitive(stream), fallback_(fallback){};
|
|
||||||
|
|
||||||
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
const std::vector<int>& axes) override;
|
|
||||||
|
|
||||||
virtual std::vector<array> jvp(
|
|
||||||
const std::vector<array>& primals,
|
|
||||||
const std::vector<array>& tangents,
|
|
||||||
const std::vector<int>& argnums) override;
|
|
||||||
|
|
||||||
virtual std::vector<array> vjp(
|
|
||||||
const std::vector<array>& primals,
|
|
||||||
const std::vector<array>& cotangents,
|
|
||||||
const std::vector<int>& argnums,
|
|
||||||
const std::vector<array>& outputs) override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
|
||||||
};
|
|
||||||
|
|
||||||
array rope(
|
array rope(
|
||||||
const array& x,
|
const array& x,
|
||||||
int dims,
|
int dims,
|
||||||
@ -45,38 +15,4 @@ array rope(
|
|||||||
int offset,
|
int offset,
|
||||||
StreamOrDevice s /* = {} */);
|
StreamOrDevice s /* = {} */);
|
||||||
|
|
||||||
class RoPE : public Custom {
|
|
||||||
public:
|
|
||||||
RoPE(
|
|
||||||
Stream stream,
|
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
|
||||||
int dims,
|
|
||||||
bool traditional,
|
|
||||||
float base,
|
|
||||||
float scale,
|
|
||||||
int offset)
|
|
||||||
: Custom(stream, fallback),
|
|
||||||
dims_(dims),
|
|
||||||
traditional_(traditional),
|
|
||||||
base_(base),
|
|
||||||
scale_(scale),
|
|
||||||
offset_(offset){};
|
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
|
||||||
override;
|
|
||||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
|
||||||
override;
|
|
||||||
|
|
||||||
DEFINE_PRINT(RoPE)
|
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
|
||||||
int dims_;
|
|
||||||
bool traditional_;
|
|
||||||
float base_;
|
|
||||||
float scale_;
|
|
||||||
int offset_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace mlx::core::fast
|
} // namespace mlx::core::fast
|
||||||
|
68
mlx/fast_primitives.h
Normal file
68
mlx/fast_primitives.h
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core::fast {
|
||||||
|
|
||||||
|
// Custom primitive accepts a fallback function which it uses for
|
||||||
|
// transformations. Transformations are virtual so that derived classes may to
|
||||||
|
// override the default behavior
|
||||||
|
class Custom : public Primitive {
|
||||||
|
public:
|
||||||
|
explicit Custom(
|
||||||
|
Stream stream,
|
||||||
|
std::function<std::vector<array>(std::vector<array>)> fallback)
|
||||||
|
: Primitive(stream), fallback_(fallback){};
|
||||||
|
|
||||||
|
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
|
virtual std::vector<array> jvp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& tangents,
|
||||||
|
const std::vector<int>& argnums) override;
|
||||||
|
|
||||||
|
virtual std::vector<array> vjp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& cotangents,
|
||||||
|
const std::vector<int>& argnums,
|
||||||
|
const std::vector<array>& outputs) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class RoPE : public Custom {
|
||||||
|
public:
|
||||||
|
RoPE(
|
||||||
|
Stream stream,
|
||||||
|
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||||
|
int dims,
|
||||||
|
bool traditional,
|
||||||
|
float base,
|
||||||
|
float scale,
|
||||||
|
int offset)
|
||||||
|
: Custom(stream, fallback),
|
||||||
|
dims_(dims),
|
||||||
|
traditional_(traditional),
|
||||||
|
base_(base),
|
||||||
|
scale_(scale),
|
||||||
|
offset_(offset){};
|
||||||
|
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
|
override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
|
override;
|
||||||
|
|
||||||
|
DEFINE_PRINT(RoPE)
|
||||||
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
|
int dims_;
|
||||||
|
bool traditional_;
|
||||||
|
float base_;
|
||||||
|
float scale_;
|
||||||
|
int offset_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core::fast
|
Loading…
Reference in New Issue
Block a user