|
|
|
@@ -27,15 +27,15 @@ namespace mindspore { |
|
|
|
namespace ad { |
|
|
|
class PynativeDFunctor { |
|
|
|
public: |
|
|
|
ValueNodePtr GenNewTensor(const CNodePtr &forward_node); |
|
|
|
tensor::TensorPtr GenNewTensorInner(const TypePtr &type_elem, const BaseShapePtr &shape_elem); |
|
|
|
void GetForwardOutNodeAndBpropGraph(const CNodePtr &k_app, CNodePtr *forward_node, FuncGraphPtr *bprop_graph, |
|
|
|
FuncGraphPtr *fprop_graph); |
|
|
|
std::vector<AnfNodePtr> RunOutputReplace(const CNodePtr &forward_node, const FuncGraphPtr &bprop_graph, |
|
|
|
const FuncGraphPtr &fprop_graph, const CNodePtr &cnode_morph); |
|
|
|
std::vector<AnfNodePtr> RunInputReplace(const FuncGraphPtr &bprop_graph, const FuncGraphPtr &fprop_graph, |
|
|
|
const CNodePtr &cnode_morph); |
|
|
|
void ReplaceEquivdout(const CNodePtr &k_app, const CNodePtr &cnode_morph); |
|
|
|
static ValueNodePtr GenNewTensor(const CNodePtr &forward_node); |
|
|
|
static tensor::TensorPtr GenNewTensorInner(const TypePtr &type_elem, const BaseShapePtr &shape_elem); |
|
|
|
static void GetForwardOutNodeAndBpropGraph(const CNodePtr &k_app, CNodePtr *forward_node, FuncGraphPtr *bprop_graph, |
|
|
|
FuncGraphPtr *fprop_graph); |
|
|
|
static std::vector<AnfNodePtr> RunOutputReplace(const CNodePtr &forward_node, const FuncGraphPtr &bprop_graph, |
|
|
|
const FuncGraphPtr &fprop_graph, const CNodePtr &cnode_morph); |
|
|
|
static std::vector<AnfNodePtr> RunInputReplace(const FuncGraphPtr &bprop_graph, const FuncGraphPtr &fprop_graph, |
|
|
|
const CNodePtr &cnode_morph); |
|
|
|
static void ReplaceEquivdout(const CNodePtr &k_app, const CNodePtr &cnode_morph); |
|
|
|
}; |
|
|
|
} // namespace ad |
|
|
|
} // namespace mindspore |
|
|
|
|