diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index 811252c55b9b224ffe09c08eabf95ef81d958e9f..c34cc8a482530d3d36fde00373a490206cd0499e 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -71,6 +71,11 @@ void DFunctor::Init(bool is_top) { } } +void DFunctor::Finish() { + CallDoutHoleOnTape(); + EliminatePrimalGraph(); +} + void DFunctor::Clear() { func_graph_to_functor_.clear(); anfnode_to_adjoin_definition_.clear(); @@ -728,10 +733,7 @@ void DFunctor::CallDoutHoleOnTape() { } } } -FuncGraphPtr DFunctor::k_graph() { - CallDoutHoleOnTape(); - return k_graph_; -} +FuncGraphPtr DFunctor::k_graph() { return k_graph_; } void DFunctor::BroadCastStopFlag() { // As stop set expanding, all directly or indirectly stopped CNode will be cut off @@ -768,5 +770,28 @@ bool DFunctor::AllReferencesStopped(const CNodePtr &node) { } return true; } + +// To replace the primal graph with k graph +void DFunctor::EliminatePrimalGraph() { + auto k_vnode = NewValueNode(k_graph_); + auto idx0 = NewValueNode(SizeToInt(0)); + auto imm0 = std::make_shared(0); + idx0->set_abstract(std::make_shared(imm0)); + auto manager = primal_graph_->manager(); + auto users = primal_graph_->func_graph_cnodes_index(); + for (auto &it : users) { + auto cnode = it.first->first->cast(); + auto index = it.first->second; + auto vnode = cnode->inputs()[index]; + if (index != 0) { + MS_LOG(INFO) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}"; + continue; + } + cnode->set_input(0, k_vnode); // Replace primal graph with k graph + auto construct_wrapper = cnode->func_graph(); + auto getitem0 = construct_wrapper->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx0}); + manager->Replace(cnode, getitem0); + } +} } // namespace ad } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h index 70be856a293d06c40fd1c0bbf99c94e873e668d6..100da3a29c84af341469d213c4a4f7c9385df600 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h @@ -64,6 +64,7 @@ class DFunctor : public std::enable_shared_from_this { FuncGraphPtr KUserDefined(const FuncGraphPtr &primal); // Register functor objects to form a global view. void Init(bool is_top = false); + void Finish(); bool IsInScope(const AnfNodePtr &node); // Clear resources. @@ -97,6 +98,8 @@ class DFunctor : public std::enable_shared_from_this { void UpdateAdjoint(const AdjointPtr &adjoint_definition); void CallDoutHoleOnTape(); void ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph); + // Replace the primal graph with k graph + void EliminatePrimalGraph(); std::unordered_map anfnode_to_adjoin_; // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer. diff --git a/mindspore/ccsrc/frontend/optimizer/ad/grad.cc b/mindspore/ccsrc/frontend/optimizer/ad/grad.cc index c41efe4930c22433c383ca4e515bdf6259f5bc5b..12654232ef6a41be4e9b408bcbe32fb3f88685fc 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/grad.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/grad.cc @@ -53,6 +53,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt f->Init(is_top); f->MapObject(); f->MapMorphism(); + f->Finish(); auto ret = f->k_graph(); if (is_top) { DFunctor::Clear();