Add support for pynative pass; add test cases
| @@ -0,0 +1,144 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "frontend/optimizer/graph_transform.h" | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include <string> | |||||
| #include "ir/graph_utils.h" | |||||
| namespace mindspore { | |||||
| /* namespace to support opt */ | |||||
| namespace opt { | |||||
| // Check the cnode's inputs and report whether any of them is a tuple. | |||||
| bool CNodeHasTupleInput(const CNodePtr &cnode) { | |||||
| auto &inputs = cnode->inputs(); | |||||
| for (size_t i = 1; i < inputs.size(); i++) { | |||||
| if (IsValueNode<FuncGraph>(inputs[i])) { | |||||
| continue; | |||||
| } | |||||
| if (IsValueNode<Primitive>(inputs[i])) { | |||||
| // unexpected high-order primitive as cnode input when transforming the graph | |||||
| MS_LOG(WARNING) << "CNodeHasTupleInput, got unexpected primitive as input " << cnode->DebugString(); | |||||
| return false; | |||||
| } | |||||
| auto abs = inputs[i]->abstract(); | |||||
| if (abs == nullptr) { | |||||
| MS_LOG(WARNING) << "CheckTupleInput, got abstract nullptr for node:" << cnode->DebugString(); | |||||
| return false; | |||||
| } | |||||
| if (abs->isa<abstract::AbstractTuple>()) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool FuncGraphHasTupleInput(const FuncGraphPtr &fg) { | |||||
| auto ¶ms = fg->parameters(); | |||||
| for (auto ¶m : params) { | |||||
| if (param->abstract()->isa<abstract::AbstractTuple>()) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
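| // Expand a tuple-typed argument into one TupleGetItem node per leaf element; nested tuples are flattened recursively. | |||||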
| std::vector<AnfNodePtr> TransformTupleArgument(const FuncGraphPtr &fg, const AnfNodePtr &node, | |||||
| const abstract::AbstractTuplePtr &abs) { | |||||
| auto &elements = abs->elements(); | |||||
| std::vector<AnfNodePtr> tuple_node_expanded; | |||||
| for (size_t i = 0; i < elements.size(); i++) { | |||||
| auto elem_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(SizeToInt(i))}); | |||||
| elem_node->set_abstract(elements[i]); | |||||
| if (elements[i]->isa<abstract::AbstractTuple>()) { | |||||
| auto nodes = TransformTupleArgument(fg, elem_node, elements[i]->cast<abstract::AbstractTuplePtr>()); | |||||
| tuple_node_expanded.insert(tuple_node_expanded.end(), nodes.begin(), nodes.end()); | |||||
| } else { | |||||
| tuple_node_expanded.push_back(elem_node); | |||||
| } | |||||
| } | |||||
| return tuple_node_expanded; | |||||
| } | |||||
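| // Rebuild a graph call {G, Xs} so that every tuple argument is replaced by its flattened elements. | |||||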
| AnfNodePtr TransformCallGraph(const FuncGraphPtr &trans_fg, const CNodePtr &cnode) { | |||||
| auto &cinputs = cnode->inputs(); | |||||
| auto fg = cnode->func_graph(); | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.push_back(NewValueNode(trans_fg)); | |||||
| for (size_t i = 1; i < cinputs.size(); i++) { | |||||
| auto abs = cinputs[i]->abstract(); | |||||
| if (abs == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "TransformCallGraph:Node abstract should not be nullptr" << cinputs[i]->DebugString(); | |||||
| } | |||||
| if (abs->isa<abstract::AbstractTuple>()) { | |||||
| auto nodes = TransformTupleArgument(fg, cinputs[i], abs->cast<abstract::AbstractTuplePtr>()); | |||||
| inputs.insert(inputs.end(), nodes.begin(), nodes.end()); | |||||
| } else { | |||||
| inputs.push_back(cinputs[i]); | |||||
| } | |||||
| } | |||||
| auto new_node = fg->NewCNode(inputs); | |||||
| new_node->set_abstract(cnode->abstract()); | |||||
| return new_node; | |||||
| } | |||||
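| // Rebuild {partial, G, Xs} so that every bound tuple argument is replaced by its flattened elements. | |||||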
| AnfNodePtr TransformPartial(const FuncGraphPtr &trans_fg, const CNodePtr &cnode) { | |||||
| auto &cinputs = cnode->inputs(); | |||||
| auto fg = cnode->func_graph(); | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.push_back(NewValueNode(prim::kPrimPartial)); | |||||
| inputs.push_back(NewValueNode(trans_fg)); | |||||
| for (size_t i = 2; i < cinputs.size(); i++) { | |||||
| auto abs = cinputs[i]->abstract(); | |||||
| if (abs == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "TransformPartial:Node abstract should not be nullptr" << cinputs[i]->DebugString(); | |||||
| } | |||||
| if (abs->isa<abstract::AbstractTuple>()) { | |||||
| auto nodes = TransformTupleArgument(fg, cinputs[i], abs->cast<abstract::AbstractTuplePtr>()); | |||||
| inputs.insert(inputs.end(), nodes.begin(), nodes.end()); | |||||
| } else { | |||||
| inputs.push_back(cinputs[i]); | |||||
| } | |||||
| } | |||||
| auto new_node = fg->NewCNode(inputs); | |||||
| new_node->set_abstract(cnode->abstract()); | |||||
| return new_node; | |||||
| } | |||||
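| // Rebuild a call through a switch node so that every tuple argument is replaced by its flattened elements. | |||||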
| AnfNodePtr TransformSwitchCall(const AnfNodePtr &switch_node, const CNodePtr &cnode) { | |||||
| auto &cinputs = cnode->inputs(); | |||||
| auto fg = cnode->func_graph(); | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.push_back(switch_node); | |||||
| for (size_t i = 1; i < cinputs.size(); i++) { | |||||
| auto abs = cinputs[i]->abstract(); | |||||
| if (abs == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "TransformSwitchCall:Node abstract should not be nullptr" << cinputs[i]->DebugString(); | |||||
| } | |||||
| if (abs->isa<abstract::AbstractTuple>()) { | |||||
| auto nodes = TransformTupleArgument(fg, cinputs[i], abs->cast<abstract::AbstractTuplePtr>()); | |||||
| inputs.insert(inputs.end(), nodes.begin(), nodes.end()); | |||||
| } else { | |||||
| inputs.push_back(cinputs[i]); | |||||
| } | |||||
| } | |||||
| auto new_node = fg->NewCNode(inputs); | |||||
| new_node->set_abstract(cnode->abstract()); | |||||
| return new_node; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
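At the Python level, the call pattern these helpers rewrite is a graph call whose argument is a tuple. A minimal sketch follows (the Block/Net names are hypothetical; it assumes the callee is compiled as a graph so its tuple parameter reaches the IR, where the transform splits it into one parameter per element and callers pass TupleGetItem(arg, i) for each element):

import numpy as np
from mindspore import nn, Tensor
from mindspore.ops import operations as P

class Block(nn.Cell):
    def __init__(self):
        super().__init__()
        self.add = P.TensorAdd()

    def construct(self, xs):
        # xs is a tuple parameter in the callee graph
        return self.add(xs[0], xs[1])

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.block = Block()

    def construct(self, a, b):
        return self.block((a, b))  # tuple argument at the call site

net = Net()
x = Tensor(np.ones([2, 2]).astype(np.float32))
out = net(x, x)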
| @@ -0,0 +1,108 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H | |||||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H | |||||
| #include <unordered_map> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include "frontend/optimizer/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| bool CNodeHasTupleInput(const CNodePtr &cnode); | |||||
| bool FuncGraphHasTupleInput(const FuncGraphPtr &fg); | |||||
| std::vector<AnfNodePtr> TransformTupleArgument(const FuncGraphPtr &fg, const AnfNodePtr &node, | |||||
| const abstract::AbstractTuplePtr &abs); | |||||
| AnfNodePtr TransformCallGraph(const FuncGraphPtr &trans_fg, const CNodePtr &cnode); | |||||
| AnfNodePtr TransformPartial(const FuncGraphPtr &trans_fg, const CNodePtr &cnode); | |||||
| AnfNodePtr TransformSwitchCall(const AnfNodePtr &switch_node, const CNodePtr &cnode); | |||||
| class GraphTupleParamTransform { | |||||
| public: | |||||
| GraphTupleParamTransform() : cache_() {} | |||||
| ~GraphTupleParamTransform() { cache_.clear(); } | |||||
| FuncGraphPtr operator()(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) { | |||||
| if (cache_.find(fg) != cache_.end()) { | |||||
| return cache_[fg]; | |||||
| } | |||||
| auto new_fg = TransformGraphParam(fg, mng); | |||||
| cache_[fg] = new_fg; | |||||
| return new_fg; | |||||
| } | |||||
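| // Build a MakeTuple node mirroring tuple_abs; each leaf element becomes a fresh parameter that is also appended to *params. | |||||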
| AnfNodePtr GenerateTupleParams(const abstract::AbstractTuplePtr &tuple_abs, const FuncGraphPtr &fg, | |||||
| std::vector<AnfNodePtr> *params) { | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||||
| auto &elements = tuple_abs->elements(); | |||||
| for (auto &item : elements) { | |||||
| if (item->isa<abstract::AbstractTuple>()) { | |||||
| inputs.push_back(GenerateTupleParams(item->cast<abstract::AbstractTuplePtr>(), fg, params)); | |||||
| } else { | |||||
| auto p = std::make_shared<Parameter>(fg); | |||||
| p->set_abstract(item); | |||||
| params->push_back(p); | |||||
| inputs.push_back(params->back()); | |||||
| } | |||||
| } | |||||
| auto node = fg->NewCNode(inputs); | |||||
| node->set_abstract(tuple_abs); | |||||
| return node; | |||||
| } | |||||
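| // Clone fg, replace each tuple parameter with its flattened leaf parameters, and reconnect the original uses through the generated MakeTuple. | |||||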
| FuncGraphPtr TransformGraphParam(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) { | |||||
| Cloner cloner({fg}, false, false, false, std::make_shared<TraceCopy>(), std::make_shared<TraceCopy>()); | |||||
| auto new_fg = cloner[fg]; | |||||
| auto ¶ms = new_fg->parameters(); | |||||
| std::vector<AnfNodePtr> new_params; | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> repl; | |||||
| for (auto ¶m : params) { | |||||
| auto abs = param->abstract(); | |||||
| if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) { | |||||
| auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>(); | |||||
| std::vector<AnfNodePtr> tuple_params; | |||||
| repl.emplace(param, GenerateTupleParams(tuple_abs, new_fg, &tuple_params)); | |||||
| new_params.insert(new_params.end(), tuple_params.begin(), tuple_params.end()); | |||||
| } else { | |||||
| new_params.push_back(param); | |||||
| } | |||||
| } | |||||
| auto tmp_mng = mindspore::Manage(new_fg, false); | |||||
| auto tr = tmp_mng->Transact(); | |||||
| for (auto &item : repl) { | |||||
| bool ret = tr.Replace(item.first, item.second); | |||||
| if (!ret) { | |||||
| MS_LOG(ERROR) << "Replace failed: " << item.first->DebugString() << " with " << item.second->DebugString(2); | |||||
| } | |||||
| } | |||||
| tr.SetParameters(new_fg, new_params); | |||||
| tr.Commit(); | |||||
| mng->AddFuncGraph(new_fg); | |||||
| return new_fg; | |||||
| } | |||||
| std::unordered_map<FuncGraphPtr, FuncGraphPtr> cache_; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H | |||||
| @@ -44,6 +44,7 @@ | |||||
| #include "frontend/optimizer/irpass/row_tensor_eliminate.h" | #include "frontend/optimizer/irpass/row_tensor_eliminate.h" | ||||
| #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h" | #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h" | ||||
| #include "frontend/optimizer/irpass/switch_layer_defer_inline.h" | #include "frontend/optimizer/irpass/switch_layer_defer_inline.h" | ||||
| #include "frontend/optimizer/irpass/call_graph_tuple_transform.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -158,6 +159,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| unused_output_eliminate_ = | unused_output_eliminate_ = | ||||
| MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel); | MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel); | ||||
| // tuple parameter graph transform | |||||
| call_graph_tuple_transform_ = | |||||
| MakeSubstitution(std::make_shared<CallGraphTupleTransform>(), "graph_param_transform", IsCNode); | |||||
| // AddN eliminate | // AddN eliminate | ||||
| addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel); | addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel); | ||||
| @@ -103,6 +103,9 @@ class OptimizeIRPassLib { | |||||
| SubstitutionPtr unused_parameter_eliminate_; | SubstitutionPtr unused_parameter_eliminate_; | ||||
| SubstitutionPtr unused_output_eliminate_; | SubstitutionPtr unused_output_eliminate_; | ||||
| // tuple parameter graph transform | |||||
| SubstitutionPtr call_graph_tuple_transform_; | |||||
| // AddN eliminate | // AddN eliminate | ||||
| SubstitutionPtr addn_eliminate_; | SubstitutionPtr addn_eliminate_; | ||||
| @@ -0,0 +1,246 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_ | |||||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_ | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <unordered_map> | |||||
| #include <unordered_set> | |||||
| #include <vector> | |||||
| #include "ir/func_graph.h" | |||||
| #include "ir/func_graph_cloner.h" | |||||
| #include "frontend/optimizer/optimizer_caller.h" | |||||
| #include "frontend/optimizer/anf_visitor.h" | |||||
| #include "frontend/operator/ops.h" | |||||
| #include "frontend/optimizer/irpass.h" | |||||
| #include "frontend/optimizer/optimizer.h" | |||||
| #include "frontend/optimizer/graph_transform.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace irpass { | |||||
| // {G, Xs} --> transform graph call tuple inputs into flat inputs. | |||||
| class GraphCallTupleTransform : public AnfVisitor { | |||||
| public: | |||||
| explicit GraphCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {} | |||||
| ~GraphCallTupleTransform() override = default; | |||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||||
| if (!node->isa<CNode>() || node->func_graph() == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto &inputs = cnode->inputs(); | |||||
| auto fg = GetValueNode<FuncGraphPtr>(inputs[0]); | |||||
| if (fg == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| if (!CNodeHasTupleInput(node->cast<CNodePtr>())) { | |||||
| return nullptr; | |||||
| } | |||||
| FuncGraphPtr transformed_fg = graph_transform_(fg, optimizer->manager()); | |||||
| auto new_node = TransformCallGraph(transformed_fg, node->cast<CNodePtr>()); | |||||
| return new_node; | |||||
| } | |||||
| private: | |||||
| GraphTupleParamTransform &graph_transform_; | |||||
| }; | |||||
| // {{switch, cond, true_branch, false_branch}, Xs} --> transform switch graph call tuple inputs into flat inputs. | |||||
| class SwitchCallTupleTransform : public AnfVisitor { | |||||
| public: | |||||
| explicit SwitchCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {} | |||||
| ~SwitchCallTupleTransform() override = default; | |||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||||
| if (!node->isa<CNode>() || node->func_graph() == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto switch_call_cnode = node->cast<CNodePtr>(); | |||||
| auto call_inputs = switch_call_cnode->inputs(); | |||||
| if (call_inputs.size() < 1) { | |||||
| return nullptr; | |||||
| } | |||||
| if (!IsPrimitiveCNode(call_inputs[0], prim::kPrimSwitch)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto switch_cnode = call_inputs[0]->cast<CNodePtr>(); | |||||
| auto switch_inputs = switch_cnode->inputs(); | |||||
| if (switch_inputs.size() != 4) { | |||||
| return nullptr; | |||||
| } | |||||
| AnfNodePtr transformed = nullptr; | |||||
| bool true_br_changed = TransformBranchNode(switch_inputs[2], optimizer->manager(), &transformed); | |||||
| if (true_br_changed) { | |||||
| switch_inputs[2] = transformed; | |||||
| } | |||||
| bool false_br_changed = TransformBranchNode(switch_inputs[3], optimizer->manager(), &transformed); | |||||
| if (false_br_changed) { | |||||
| switch_inputs[3] = transformed; | |||||
| } | |||||
| if (true_br_changed || false_br_changed) { | |||||
| call_inputs[0] = switch_cnode->func_graph()->NewCNode(switch_inputs); | |||||
| } | |||||
| if (CNodeHasTupleInput(switch_call_cnode)) { | |||||
| return TransformSwitchCall(call_inputs[0], switch_call_cnode); | |||||
| } | |||||
| if (true_br_changed || false_br_changed) { | |||||
| return switch_call_cnode->func_graph()->NewCNode(call_inputs); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
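| // Transform one switch branch (a FuncGraph value node or a Partial over one); returns true and sets *trans_node when the branch is rewritten. | |||||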
| bool TransformBranchNode(AnfNodePtr node, FuncGraphManagerPtr mng, AnfNodePtr *trans_node) { | |||||
| if (IsValueNode<FuncGraph>(node)) { | |||||
| FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node); | |||||
| if (FuncGraphHasTupleInput(fg)) { | |||||
| FuncGraphPtr transformed_fg = graph_transform_(fg, mng); | |||||
| *trans_node = NewValueNode(transformed_fg); | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| if (IsPrimitiveCNode(node, prim::kPrimPartial)) { | |||||
| auto partial_inputs = node->cast<CNodePtr>()->inputs(); | |||||
| if (IsValueNode<FuncGraph>(partial_inputs[1])) { | |||||
| FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(partial_inputs[1]); | |||||
| if (FuncGraphHasTupleInput(fg)) { | |||||
| fg = graph_transform_(fg, mng); | |||||
| } | |||||
| if (CNodeHasTupleInput(node->cast<CNodePtr>())) { | |||||
| *trans_node = TransformPartial(fg, node->cast<CNodePtr>()); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| MS_LOG(WARNING) << "Got unexpected switch branch node " << node->DebugString(); | |||||
| return false; | |||||
| } | |||||
| private: | |||||
| GraphTupleParamTransform &graph_transform_; | |||||
| }; | |||||
| // {{switch_layer, index, {make_tuple, br1, br2, ...}}, Xs} --> | |||||
| // transform switch layer graph call tuple inputs into flat inputs. | |||||
| class SwitchLayerCallTupleTransform : public AnfVisitor { | |||||
| public: | |||||
| explicit SwitchLayerCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {} | |||||
| ~SwitchLayerCallTupleTransform() override = default; | |||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||||
| if (!node->isa<CNode>() || node->func_graph() == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto switch_layer_call_cnode = node->cast<CNodePtr>(); | |||||
| auto call_inputs = switch_layer_call_cnode->inputs(); | |||||
| if (call_inputs.size() < 1) { | |||||
| return nullptr; | |||||
| } | |||||
| if (!IsPrimitiveCNode(call_inputs[0], prim::kPrimSwitchLayer)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto switch_layer_cnode = call_inputs[0]->cast<CNodePtr>(); | |||||
| auto switch_layer_inputs = switch_layer_cnode->inputs(); | |||||
| if (switch_layer_inputs.size() != 3) { | |||||
| return nullptr; | |||||
| } | |||||
| AnfNodePtr transformed = nullptr; | |||||
| bool layer_changed = TransformLayerNode(switch_layer_inputs[2], optimizer->manager(), &transformed); | |||||
| if (layer_changed) { | |||||
| switch_layer_inputs[2] = transformed; | |||||
| call_inputs[0] = switch_layer_call_cnode->func_graph()->NewCNode(switch_layer_inputs); | |||||
| } | |||||
| if (CNodeHasTupleInput(switch_layer_call_cnode)) { | |||||
| return TransformSwitchCall(call_inputs[0], switch_layer_call_cnode); | |||||
| } | |||||
| if (layer_changed) { | |||||
| return switch_layer_call_cnode->func_graph()->NewCNode(call_inputs); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
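| // Transform the {make_tuple, br1, br2, ...} branch tuple of switch_layer: branches whose graphs take tuple parameters are replaced with their transformed graphs. | |||||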
| bool TransformLayerNode(AnfNodePtr node, FuncGraphManagerPtr mng, AnfNodePtr *trans_node) { | |||||
| if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { | |||||
| MS_LOG(WARNING) << "SwitchLayer input is not MakeTuple"; | |||||
| return false; | |||||
| } | |||||
| auto tuple_inputs = node->cast<CNodePtr>()->inputs(); | |||||
| bool changed = false; | |||||
| for (size_t i = 1; i < tuple_inputs.size(); i++) { | |||||
| if (!IsValueNode<FuncGraph>(tuple_inputs[i])) { | |||||
| MS_LOG(WARNING) << "SwitchLayer input is not FuncGraph"; | |||||
| return false; | |||||
| } | |||||
| FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(tuple_inputs[i]); | |||||
| if (FuncGraphHasTupleInput(fg)) { | |||||
| FuncGraphPtr transformed_fg = graph_transform_(fg, mng); | |||||
| tuple_inputs[i] = NewValueNode(transformed_fg); | |||||
| changed = true; | |||||
| } | |||||
| } | |||||
| if (changed) { | |||||
| *trans_node = node->func_graph()->NewCNode(tuple_inputs); | |||||
| } | |||||
| return changed; | |||||
| } | |||||
| private: | |||||
| GraphTupleParamTransform &graph_transform_; | |||||
| }; | |||||
| class CallGraphTupleTransform : public OptimizerCaller { | |||||
| public: | |||||
| CallGraphTupleTransform() | |||||
| : graph_transformer_(), | |||||
| graph_call_transform_(std::make_shared<GraphCallTupleTransform>(graph_transformer_)), | |||||
| switch_call_transform_(std::make_shared<SwitchCallTupleTransform>(graph_transformer_)), | |||||
| switch_layer_call_transform_(std::make_shared<SwitchLayerCallTupleTransform>(graph_transformer_)) { | |||||
| transformers_.emplace_back(graph_call_transform_); | |||||
| transformers_.emplace_back(switch_call_transform_); | |||||
| transformers_.emplace_back(switch_layer_call_transform_); | |||||
| } | |||||
| ~CallGraphTupleTransform() = default; | |||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||||
| AnfNodePtr new_node; | |||||
| for (auto &transform : transformers_) { | |||||
| new_node = (*transform)(optimizer, node); | |||||
| if (new_node != nullptr) { | |||||
| return new_node; | |||||
| } | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| private: | |||||
| GraphTupleParamTransform graph_transformer_; | |||||
| OptimizerCallerPtr graph_call_transform_; | |||||
| OptimizerCallerPtr switch_call_transform_; | |||||
| OptimizerCallerPtr switch_layer_call_transform_; | |||||
| std::vector<OptimizerCallerPtr> transformers_{}; | |||||
| }; | |||||
| } // namespace irpass | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_ | |||||
| @@ -277,6 +277,7 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes) | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| func_graph->DumpFuncGraph(fg_name); | func_graph->DumpFuncGraph(fg_name); | ||||
| DumpIR(fg_name + ".ir", func_graph); | DumpIR(fg_name + ".ir", func_graph); | ||||
| ExportIR(fg_name + ".dat", "", func_graph); | |||||
| MS_LOG(DEBUG) << "Dump " << fg_name << " func graph."; | MS_LOG(DEBUG) << "Dump " << fg_name << " func graph."; | ||||
| } | } | ||||
| counter++; | counter++; | ||||
| @@ -33,6 +33,7 @@ | |||||
| #include "frontend/optimizer/clean.h" | #include "frontend/optimizer/clean.h" | ||||
| #include "frontend/optimizer/irpass.h" | #include "frontend/optimizer/irpass.h" | ||||
| #include "frontend/optimizer/control_depend.h" | #include "frontend/optimizer/control_depend.h" | ||||
| #include "frontend/optimizer/graph_transform.h" | |||||
| #include "frontend/parallel/step_parallel.h" | #include "frontend/parallel/step_parallel.h" | ||||
| #include "frontend/parallel/step_auto_parallel.h" | #include "frontend/parallel/step_auto_parallel.h" | ||||
| #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h" | #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h" | ||||
| @@ -166,12 +167,23 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) { | OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) { | ||||
| opt::OptPassConfig c_1 = opt::OptPassConfig({ | opt::OptPassConfig c_1 = opt::OptPassConfig({ | ||||
| // Safe inlining | |||||
| // Safe inlining, | |||||
| irpass.inline_, | irpass.inline_, | ||||
| irpass.partial_eliminate_, | irpass.partial_eliminate_, | ||||
| }); | }); | ||||
| OptPassGroupMap map_a({{"c_1", c_1}, {"renormalize", opt::OptPassConfig::Renormalize()}}); | |||||
| OptPassGroupMap map_a({{"c_1", c_1}, | |||||
| {"cse", opt::OptPassConfig(opt::CSEPass(false))}, | |||||
| {"renormalize", opt::OptPassConfig::Renormalize()}}); | |||||
| return map_a; | |||||
| } | |||||
| OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| opt::OptPassConfig d_1 = opt::OptPassConfig({// Tuple parameter graph transform | |||||
| irpass.call_graph_tuple_transform_, irpass.item_tuple_eliminate_}); | |||||
| OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}}); | |||||
| return map_a; | return map_a; | ||||
| } | } | ||||
| @@ -262,6 +274,8 @@ void InitOpt(const ResourcePtr &res) { | |||||
| g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); | g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); | ||||
| g_pass_opts["opt_after_cconv"] = | g_pass_opts["opt_after_cconv"] = | ||||
| Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true); | Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true); | ||||
| g_pass_opts["opt_trans_graph"] = | |||||
| Optimizer::MakeOptimizer("opt_trans_graph", res, GetOptPassesTransformGraph(irpass), true, true); | |||||
| g_pass_opts["opt_graph_kernel_a"] = | g_pass_opts["opt_graph_kernel_a"] = | ||||
| Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true); | Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true); | ||||
| g_pass_opts["opt_graph_kernel_b"] = | g_pass_opts["opt_graph_kernel_b"] = | ||||
| @@ -307,6 +321,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) { | |||||
| bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } | bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } | ||||
| bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } | bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } | ||||
| bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); } | bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); } | ||||
| bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); } | |||||
| bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); } | bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); } | ||||
| bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); } | bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); } | ||||
| bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } | bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } | ||||
| @@ -365,6 +380,24 @@ bool CconvPass(const ResourcePtr &res) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool TransformTopGraphPass(const ResourcePtr &res) { | |||||
| if (res->func_graph() == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Transform top graph error."; | |||||
| } | |||||
| FuncGraphPtr func_graph = res->func_graph(); | |||||
| if (opt::FuncGraphHasTupleInput(func_graph)) { | |||||
| opt::GraphTupleParamTransform graph_trans; | |||||
| func_graph = graph_trans(func_graph, res->manager()); | |||||
| res->set_func_graph(func_graph); | |||||
| AbstractBasePtrList abs_spec_list; | |||||
| auto ¶ms = func_graph->parameters(); | |||||
| std::transform(params.begin(), params.end(), std::back_inserter(abs_spec_list), | |||||
| [](AnfNodePtr node) { return node->abstract(); }); | |||||
| res->set_args_spec(abs_spec_list); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool ValidatePass(const ResourcePtr &res) { | bool ValidatePass(const ResourcePtr &res) { | ||||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | MS_EXCEPTION_IF_NULL(res->func_graph()); | ||||
| FuncGraphPtr func_graph = res->func_graph(); | FuncGraphPtr func_graph = res->func_graph(); | ||||
| @@ -388,6 +421,7 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru | |||||
| {"cconv", CconvPass}, | {"cconv", CconvPass}, | ||||
| {"opt_after_cconv", OptPassAfterCconvGroup}, | {"opt_after_cconv", OptPassAfterCconvGroup}, | ||||
| {"remove_dup_value", RemoveValueNodeDuplicationsPass}, | {"remove_dup_value", RemoveValueNodeDuplicationsPass}, | ||||
| {"tuple_transform", OptPassTransformGraphGroup}, | |||||
| {"opt_graph_kernel_a", OptPassGraphKernelGroupA}, | {"opt_graph_kernel_a", OptPassGraphKernelGroupA}, | ||||
| {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, | {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, | ||||
| {"add_control_depend", AddControlDependPass}}; | {"add_control_depend", AddControlDependPass}}; | ||||
| @@ -401,6 +435,10 @@ std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStru | |||||
| {"opt_prepare", PrepareGroup}, | {"opt_prepare", PrepareGroup}, | ||||
| {"cconv", CconvPass}}; | {"cconv", CconvPass}}; | ||||
| std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}}; | |||||
| std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, | |||||
| {"opt_b", OptPassBGroup}, | |||||
| {"cconv", CconvPass}, | |||||
| {"transform_top", TransformTopGraphPass}, | |||||
| {"transform_graph", OptPassTransformGraphGroup}}; | |||||
| } // namespace pipeline | } // namespace pipeline | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1351,9 +1351,46 @@ void PynativeExecutor::ClearRes() { | |||||
| resource_.reset(); | resource_.reset(); | ||||
| } | } | ||||
| size_t GetTupleSize(const py::tuple &args) { | |||||
| size_t count = 0; | |||||
| for (size_t i = 0; i < args.size(); i++) { | |||||
| if (py::isinstance<py::tuple>(args[i])) { | |||||
| count += GetTupleSize(args[i]); | |||||
| } else { | |||||
| count += 1; | |||||
| } | |||||
| } | |||||
| return count; | |||||
| } | |||||
| void ConvertTupleArg(py::tuple *res, size_t *index, const py::tuple &arg) { | |||||
| for (size_t i = 0; i < arg.size(); i++) { | |||||
| if (py::isinstance<py::tuple>(arg[i])) { | |||||
| ConvertTupleArg(res, index, arg[i]); | |||||
| } else { | |||||
| (*res)[(*index)++] = arg[i]; | |||||
| } | |||||
| } | |||||
| } | |||||
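| // Flatten nested tuples in args so the positional inputs match the flattened graph parameters produced by the tuple-transform pass. | |||||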
| py::tuple ConvertArgs(const py::tuple &args) { | |||||
| size_t tuple_size = GetTupleSize(args); | |||||
| py::tuple res(tuple_size); | |||||
| size_t index = 0; | |||||
| for (size_t i = 0; i < args.size(); i++) { | |||||
| if (py::isinstance<py::tuple>(args[i])) { | |||||
| ConvertTupleArg(&res, &index, args[i]); | |||||
| } else { | |||||
| res[index++] = args[i]; | |||||
| } | |||||
| } | |||||
| return res; | |||||
| } | |||||
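In Python terms, ConvertArgs and ConvertTupleArg flatten the nested tuple of positional arguments so it lines up with the flattened graph parameters. An illustrative sketch (convert_args is a hypothetical stand-in, not the executor's actual code path):

def convert_args(args):
    flat = []
    for a in args:
        if isinstance(a, tuple):
            flat.extend(convert_args(a))
        else:
            flat.append(a)
    return tuple(flat)

assert convert_args((1, (2, (3, 4)), 5)) == (1, 2, 3, 4, 5)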
| py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) { | py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) { | ||||
| VectorRef arg_list; | VectorRef arg_list; | ||||
| pipeline::ProcessVmArgInner(args, resource_, &arg_list); | |||||
| py::tuple converted_args = ConvertArgs(args); | |||||
| pipeline::ProcessVmArgInner(converted_args, resource_, &arg_list); | |||||
| if (resource_->results().find(pipeline::kOutput) == resource_->results().end() || | if (resource_->results().find(pipeline::kOutput) == resource_->results().end() || | ||||
| !resource_->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) { | !resource_->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) { | ||||
| MS_LOG(EXCEPTION) << "Can't find run graph func for "; | MS_LOG(EXCEPTION) << "Can't find run graph func for "; | ||||
| @@ -0,0 +1,201 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import pytest | |||||
| import numpy as np | |||||
| from mindspore import RowTensor | |||||
| from mindspore import context, nn, Tensor, ParameterTuple | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.common import ms_function | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import composite as C | |||||
| def setup_module(): | |||||
| context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False) | |||||
| class _Grad(nn.Cell): | |||||
| def __init__(self, grad, network, wrt_params=False, real_inputs_count=None): | |||||
| super().__init__() | |||||
| self.network = network | |||||
| self.grad = grad | |||||
| self.sens_param = self.grad.sens_param | |||||
| self.wrt_params = wrt_params | |||||
| self.real_inputs_count = real_inputs_count | |||||
| if self.wrt_params: | |||||
| self.params = ParameterTuple(self.network.trainable_params()) | |||||
| def construct(self, *inputs): | |||||
| if self.wrt_params: | |||||
| if self.real_inputs_count is None or self.sens_param is False: | |||||
| return self.grad(self.network, self.params)(*inputs) | |||||
| real_inputs = inputs[:self.real_inputs_count] | |||||
| sense_param_inputs = inputs[self.real_inputs_count:] | |||||
| return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs) | |||||
| if self.real_inputs_count is None or self.sens_param is False: | |||||
| return self.grad(self.network)(*inputs) | |||||
| real_inputs = inputs[:self.real_inputs_count] | |||||
| sense_param_inputs = inputs[self.real_inputs_count:] | |||||
| return self.grad(self.network)(*real_inputs, sense_param_inputs) | |||||
| class GradOfFirstInput(_Grad): | |||||
| """ | |||||
| get grad of first input | |||||
| """ | |||||
| def __init__(self, network, sens_param=True, real_inputs_count=None): | |||||
| super().__init__(grad=C.GradOperation(sens_param=sens_param), | |||||
| network=network, real_inputs_count=real_inputs_count) | |||||
| class GradOfAllInputs(_Grad): | |||||
| """ | |||||
| get grads of all inputs | |||||
| """ | |||||
| def __init__(self, network, sens_param=True, real_inputs_count=None): | |||||
| super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param), | |||||
| network=network, real_inputs_count=real_inputs_count) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_row_tensor_in_while(): | |||||
| class RowTensorValuesDouble(nn.Cell): | |||||
| def construct(self, x): | |||||
| indices = x.indices | |||||
| values = x.values * 2 | |||||
| dense_shape = x.dense_shape | |||||
| return RowTensor(indices, values, dense_shape) | |||||
| class RowTensorValuesAdd2(nn.Cell): | |||||
| def construct(self, x): | |||||
| indices = x.indices | |||||
| values = x.values + 2 | |||||
| dense_shape = x.dense_shape | |||||
| return RowTensor(indices, values, dense_shape) | |||||
| class RowTensorWithControlWhile(nn.Cell): | |||||
| def __init__(self, dense_shape): | |||||
| super().__init__() | |||||
| self.op1 = RowTensorValuesDouble() | |||||
| self.op2 = RowTensorValuesAdd2() | |||||
| self.dense_shape = dense_shape | |||||
| @ms_function | |||||
| def construct(self, a, b, indices, values): | |||||
| x = RowTensor(indices, values, self.dense_shape) | |||||
| x = self.op2(x) | |||||
| while a > b: | |||||
| x = self.op1(x) | |||||
| b = b + 1 | |||||
| return x.indices, x.values, x.dense_shape | |||||
| a = Tensor(np.array(3).astype(np.int32)) | |||||
| b = Tensor(np.array(0).astype(np.int32)) | |||||
| indices = Tensor(np.array([0, 2]).astype(np.int32)) | |||||
| values = Tensor(np.ones([2, 2]).astype(np.float32)) | |||||
| dense_shape = (5, 2) | |||||
| net = RowTensorWithControlWhile(dense_shape) | |||||
| out = net(a, b, indices, values) | |||||
| assert np.allclose(indices.asnumpy(), out[0].asnumpy(), .0, .0) | |||||
| assert np.allclose(values.asnumpy()*24, out[1].asnumpy(), .0, .0) | |||||
| assert dense_shape == out[2] | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_parser_switch_layer_inputs_tuple(): | |||||
| class Add(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.op = P.TensorAdd() | |||||
| def construct(self, x): | |||||
| y = self.op(x[0], x[1]) | |||||
| return self.op(x[0], y) | |||||
| class Mul(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.op = P.Mul() | |||||
| def construct(self, x): | |||||
| y = self.op(x[0], x[1]) | |||||
| return self.op(x[0], y) | |||||
| class MulTwoInput(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.op = P.Mul() | |||||
| @ms_function | |||||
| def construct(self, x, y): | |||||
| y = self.op(x, y) | |||||
| return self.op(x, y) | |||||
| class TwoInputTupleFinalNet(nn.Cell): | |||||
| def __init__(self, funcs): | |||||
| super().__init__() | |||||
| self.funcs = funcs | |||||
| @ms_function | |||||
| def construct(self, i, inputa, inputb): | |||||
| inputs = (inputa, inputb) | |||||
| x = self.funcs[i](inputs) | |||||
| return x | |||||
| func1 = Add() | |||||
| func2 = Mul() | |||||
| funcs = (func1, func2) | |||||
| net = TwoInputTupleFinalNet(funcs) | |||||
| input_data = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32)) | |||||
| input2 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32)) | |||||
| i = Tensor(1, mstype.int32) | |||||
| netout = net(i, input_data, input2) | |||||
| net_good = MulTwoInput() | |||||
| goodout = net_good(input_data, input2) | |||||
| assert np.allclose(goodout.asnumpy(), netout.asnumpy(), 0, 0) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_imagenet(): | |||||
| class ImageGradients(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.imagegradients = nn.ImageGradients() | |||||
| def construct(self, inputs): | |||||
| return self.imagegradients(inputs) | |||||
| net = ImageGradients() | |||||
| net_me = GradOfFirstInput(net, real_inputs_count=1) | |||||
| net_me.set_train() | |||||
| input_data = Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32) | |||||
| output_grad = (Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32), | |||||
| Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32)) | |||||
| net_me(input_data, *output_grad) | |||||
| @@ -0,0 +1,136 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| from mindspore import RowTensor | |||||
| from mindspore import context, nn, Tensor, ParameterTuple | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.common import ms_function | |||||
| from mindspore.ops import composite as C | |||||
| def setup_module(): | |||||
| context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False) | |||||
| class _Grad(nn.Cell): | |||||
| def __init__(self, grad, network, wrt_params=False, real_inputs_count=None): | |||||
| super().__init__() | |||||
| self.network = network | |||||
| self.grad = grad | |||||
| self.sens_param = self.grad.sens_param | |||||
| self.wrt_params = wrt_params | |||||
| self.real_inputs_count = real_inputs_count | |||||
| if self.wrt_params: | |||||
| self.params = ParameterTuple(self.network.trainable_params()) | |||||
| def construct(self, *inputs): | |||||
| if self.wrt_params: | |||||
| if self.real_inputs_count is None or self.sens_param is False: | |||||
| return self.grad(self.network, self.params)(*inputs) | |||||
| real_inputs = inputs[:self.real_inputs_count] | |||||
| sense_param_inputs = inputs[self.real_inputs_count:] | |||||
| return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs) | |||||
| if self.real_inputs_count is None or self.sens_param is False: | |||||
| return self.grad(self.network)(*inputs) | |||||
| real_inputs = inputs[:self.real_inputs_count] | |||||
| sense_param_inputs = inputs[self.real_inputs_count:] | |||||
| return self.grad(self.network)(*real_inputs, sense_param_inputs) | |||||
| class GradOfFirstInput(_Grad): | |||||
| """ | |||||
| get grad of first input | |||||
| """ | |||||
| def __init__(self, network, sens_param=True, real_inputs_count=None): | |||||
| super().__init__(grad=C.GradOperation(sens_param=sens_param), | |||||
| network=network, real_inputs_count=real_inputs_count) | |||||
| class GradOfAllInputs(_Grad): | |||||
| """ | |||||
| get grads of all inputs | |||||
| """ | |||||
| def __init__(self, network, sens_param=True, real_inputs_count=None): | |||||
| super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param), | |||||
| network=network, real_inputs_count=real_inputs_count) | |||||
| def test_row_tensor_in_while(): | |||||
| class RowTensorValuesDouble(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| def construct(self, x): | |||||
| indices = x.indices | |||||
| values = x.values * 2 | |||||
| dense_shape = x.dense_shape | |||||
| return RowTensor(indices, values, dense_shape) | |||||
| class RowTensorValuesAdd2(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| def construct(self, x): | |||||
| indices = x.indices | |||||
| values = x.values + 2 | |||||
| dense_shape = x.dense_shape | |||||
| return RowTensor(indices, values, dense_shape) | |||||
| class RowTensorWithControlWhile(nn.Cell): | |||||
| def __init__(self, dense_shape): | |||||
| super().__init__() | |||||
| self.op1 = RowTensorValuesDouble() | |||||
| self.op2 = RowTensorValuesAdd2() | |||||
| self.dense_shape = dense_shape | |||||
| @ms_function | |||||
| def construct(self, a, b, indices, values): | |||||
| x = RowTensor(indices, values, self.dense_shape) | |||||
| x = self.op2(x) | |||||
| while a > b: | |||||
| x = self.op1(x) | |||||
| b = b + 1 | |||||
| return x.indices, x.values, x.dense_shape | |||||
| a = Tensor(np.array(3).astype(np.int32)) | |||||
| b = Tensor(np.array(0).astype(np.int32)) | |||||
| indices = Tensor(np.array([0, 2]).astype(np.int32)) | |||||
| values = Tensor(np.ones([2, 2]).astype(np.float32)) | |||||
| dense_shape = (5, 2) | |||||
| net = RowTensorWithControlWhile(dense_shape) | |||||
| net(a, b, indices, values) | |||||
| def test_multi_out_sens(): | |||||
| class ImageGradients(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| def construct(self, x, y, z): | |||||
| resa = x * y | |||||
| resb = y * z | |||||
| resc = x * z | |||||
| return resa, (resb, resc) | |||||
| net = ImageGradients() | |||||
| net_me = GradOfAllInputs(net, real_inputs_count=3) | |||||
| net_me.set_train() | |||||
| input_data = Tensor(np.ones([32]), dtype=mstype.float32) | |||||
| output_grad = (Tensor(np.ones([32]), dtype=mstype.float32), | |||||
| (Tensor(np.ones([32]), dtype=mstype.float32), Tensor(np.ones([32]), dtype=mstype.float32))) | |||||
| net_me(input_data, input_data, input_data, *output_grad) | |||||