mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00

Compare commits

2 Commits: eff0e31f00 ... 193cdcd81a

| Author | SHA1 | Date |
|---|---|---|
| | 193cdcd81a | |
| | d8ceae7b77 | |
```diff
@@ -332,9 +332,17 @@ bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
   for (const auto& node : nodes) {
     cudaGraphNodeType type;
     CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
-    if (type != cudaGraphNodeTypeKernel) {
+    if (type == cudaGraphNodeTypeGraph) {
+      // Try to be updatable for a structure like graph -> graph -> kernel
+      if (num_nodes > 1) {
+        return false;
+      }
+      cudaGraph_t child;
+      CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
+      return is_graph_updatable(child, cluster_dim_x);
+    } else if (type != cudaGraphNodeTypeKernel) {
       return false;
-    }
+    } else {
       cudaLaunchAttributeValue cluster_dim;
       CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
           node, cudaLaunchAttributeClusterDimension, &cluster_dim));
@@ -348,6 +356,7 @@ bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
       }
       cluster_dim_x = cluster_dim.clusterDim.x;
     }
+    }
   }
   return true;
 }
```
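The new branch recurses into child graph nodes: a graph whose only node is itself a graph (the `graph -> graph -> kernel` shape named in the comment) can still be updatable, while a multi-node graph that contains a child graph is conservatively rejected, as is any node that is neither a graph nor a kernel.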
```diff
@@ -362,7 +371,7 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
   }
   cudaGraphNode_t node;
   int cluster_dim_x = 0;
-  is_graph_updatable_ = is_graph_updatable(child, cluster_dim_x);
+  is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x);
   CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
   insert_graph_dependencies(
       GraphNode{node, "G" + std::to_string(cluster_dim_x)});
```
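Switching `=` to `&=` makes the encoder's flag sticky: once any child graph added so far has been found non-updatable, a later updatable child can no longer flip `is_graph_updatable_` back to `true`.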
```diff
@@ -3846,6 +3846,62 @@ std::vector<array> Reduce::vjp(
   }
 }
 
+std::vector<array> Reduce::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  auto in = primals[0];
+  auto s = stream();
+
+  auto grad_op = [&s, reduce_type = reduce_type_](
+                     const array& x, const array& tan, int axis) {
+    if (reduce_type == Reduce::Min) {
+      auto idx = argmin(x, axis, true, s);
+      return take_along_axis(tan, idx, axis, s);
+    } else if (reduce_type == Reduce::Max) {
+      auto idx = argmax(x, axis, true, s);
+      return take_along_axis(tan, idx, axis, s);
+    } else {
+      auto p1 = cumprod(x, axis, /*reverse=*/false, /*inclusive=*/false, s);
+      auto p2 = cumprod(x, axis, /*reverse=*/true, /*inclusive=*/false, s);
+      auto out = multiply(multiply(p1, p2, s), tan, s);
+      return sum(out, axis, true, s);
+    }
+  };
+
+  auto tan = tangents[0];
+  if (reduce_type_ == Reduce::Sum) {
+    return {sum(tan, axes_, true, s)};
+  } else {
+    if (axes_.size() > 1) {
+      std::vector<int> transpose_to;
+      {
+        // Find the transpose needed to move axes_ to the back.
+        int j = 0;
+        for (int i = 0; i < in.ndim(); i++) {
+          if (j < axes_.size() && axes_[j] == i) {
+            j++;
+          } else {
+            transpose_to.push_back(i);
+          }
+        }
+        for (auto ax : axes_) {
+          transpose_to.push_back(ax);
+        }
+      }
+
+      int start_ax = in.ndim() - axes_.size();
+      in = flatten(transpose(in, transpose_to, s), start_ax, -1, s);
+      tan = flatten(transpose(tan, transpose_to, s), start_ax, -1, s);
+
+      auto grad = squeeze(grad_op(in, tan, -1), -1, s);
+      return {expand_dims(grad, axes_, s)};
+    } else {
+      return {grad_op(in, tan, axes_[0])};
+    }
+  }
+}
+
 std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
```
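The branches follow the usual directional-derivative rules: `sum` is linear, so its JVP is just the reduced tangent; `min` and `max` route the tangent through the arg-extremum with `take_along_axis`; and for `prod`, the two exclusive cumulative products give, at each position `i`, the product of everything before `i` times everything after `i`, i.e. the partial derivative of the product with respect to `x_i`. Computing it this way, rather than as `prod(x) / x_i`, avoids a division and so stays correct when the input contains zeros. Multi-axis reductions are first transposed and flattened so the reduced axes become a single trailing axis, then the reduced dimensions are restored with `expand_dims`.

A minimal sketch checking the `prod` identity against the `jvp` transform, assuming the public mlx C++ API from `mlx/mlx.h` and mirroring the values used in the new Python test below:

```cpp
// Sketch only: verifies sum_i (prod_{j != i} a_j) * b_i, computed via
// exclusive forward/reverse cumprods, against mlx::core::jvp.
#include <cassert>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto a = array({0.0f, 1.0f, 2.0f, 3.0f});  // primal
  auto b = array({3.0f, 2.0f, 1.0f, 0.0f});  // tangent

  // Product of everything before i (exclusive) times everything after i.
  auto p1 = cumprod(a, 0, /*reverse=*/false, /*inclusive=*/false);
  auto p2 = cumprod(a, 0, /*reverse=*/true, /*inclusive=*/false);
  auto manual = sum(multiply(multiply(p1, p2), b), 0);

  auto [outs, jvps] = jvp(
      [](const std::vector<array>& xs) {
        return std::vector<array>{prod(xs[0], 0)};
      },
      {a},
      {b});

  // Only the i = 0 term survives (a contains a zero): (1 * 2 * 3) * 3 = 18.
  assert(manual.item<float>() == 18.0f);
  assert(jvps[0].item<float>() == 18.0f);
  return 0;
}
```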
```diff
@@ -1751,12 +1751,7 @@ class Reduce : public UnaryPrimitive {
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
 
   DEFINE_VMAP()
-  std::vector<array> vjp(
-      const std::vector<array>& primals,
-      const std::vector<array>& cotangents,
-      const std::vector<int>& argnums,
-      const std::vector<array>& outputs) override;
+  DEFINE_GRADS();
 
   std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
```
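`DEFINE_GRADS();` declares both gradient directions, so the hand-written `vjp` declaration can go. The following is only a presumed expansion, inferred from the removed `vjp` signature above and the `Reduce::jvp` definition added earlier; see the actual macro in `mlx/primitives.h`:

```cpp
// Presumed expansion (sketch), not copied from the source:
#define DEFINE_GRADS()                           \
  std::vector<array> jvp(                        \
      const std::vector<array>& primals,         \
      const std::vector<array>& tangents,        \
      const std::vector<int>& argnums) override; \
                                                 \
  std::vector<array> vjp(                        \
      const std::vector<array>& primals,         \
      const std::vector<array>& cotangents,      \
      const std::vector<int>& argnums,           \
      const std::vector<array>& outputs) override;
```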
```diff
@@ -798,6 +798,22 @@ class TestAutograd(mlx_tests.MLXTestCase):
         grad_fn(model)
         self.assertEqual(model[1].item(), 2.0)
 
+    def test_reduce_jvp(self):
+        a = mx.arange(4)
+        b = mx.array([3, 2, 1, 0])
+
+        out, jout = mx.jvp(mx.sum, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 6)
+
+        out, jout = mx.jvp(mx.prod, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 18)
+
+        out, jout = mx.jvp(mx.min, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 3)
+
+        out, jout = mx.jvp(mx.max, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 0)
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()
```
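For the chosen inputs `a = [0, 1, 2, 3]` and tangent `b = [3, 2, 1, 0]`, the expected values follow directly from the rules above: `sum` is linear, so the JVP is `3 + 2 + 1 + 0 = 6`; for `prod`, only the `i = 0` term has a nonzero partial derivative (every other term keeps the factor `a[0] = 0`), giving `(1 * 2 * 3) * 3 = 18`; `min` picks the tangent at the argmin (`b[0] = 3`) and `max` the tangent at the argmax (`b[3] = 0`).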