Merge pull request !62 from amongo/SupportGradOnVarArgstags/v0.2.0-alpha
| @@ -1199,51 +1199,6 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec | |||
| return ret_graph; | |||
| } | |||
| FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||
| // slice a tensor | |||
| // args: tensor, slice or slice tuple | |||
| const std::string op_name = std::string("UnpackCall"); | |||
| size_t arg_length = args_spec_list.size(); | |||
| if (arg_length < 2) { | |||
| MS_LOG(EXCEPTION) << "" << op_name << " requires at least two args, but got " << arg_length << "."; | |||
| } | |||
| (void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0); | |||
| FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); | |||
| ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| AnfNodePtr fnNode = ret_graph->add_parameter(); | |||
| std::vector<AnfNodePtr> elems; | |||
| elems.push_back(fnNode); | |||
| for (size_t index = 1; index < arg_length; index++) { | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[index]); | |||
| if (args_spec_list[index]->isa<AbstractTuple>()) { | |||
| AbstractTuplePtr arg_tuple = dyn_cast<AbstractTuple>(args_spec_list[index]); | |||
| AnfNodePtr para_tuple = ret_graph->add_parameter(); | |||
| for (size_t i = 0; i < arg_tuple->size(); ++i) { | |||
| elems.push_back( | |||
| ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToInt(i))})); | |||
| } | |||
| } else if (args_spec_list[index]->isa<AbstractDictionary>()) { | |||
| AbstractDictionaryPtr arg_dict = dyn_cast<AbstractDictionary>(args_spec_list[index]); | |||
| AnfNodePtr para_dict = ret_graph->add_parameter(); | |||
| auto dict_elems = arg_dict->elements(); | |||
| (void)std::transform( | |||
| dict_elems.begin(), dict_elems.end(), std::back_inserter(elems), | |||
| [ret_graph, para_dict](const AbstractAttribute& item) { | |||
| return ret_graph->NewCNode( | |||
| {NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first), | |||
| ret_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)})}); | |||
| }); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "" << op_name << " require args should be tuple or dict, but got " | |||
| << args_spec_list[index]->ToString(); | |||
| } | |||
| } | |||
| ret_graph->set_output(ret_graph->NewCNode(elems)); | |||
| return ret_graph; | |||
| } | |||
| REGISTER_PYBIND_DEFINE( | |||
| TupleAdd_, ([](const py::module* m) { | |||
| (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_").def(py::init<std::string&>()); | |||
| @@ -1258,10 +1213,5 @@ REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module* m) { | |||
| (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_") | |||
| .def(py::init<std::string&>()); | |||
| })); | |||
| REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) { | |||
| (void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_") | |||
| .def(py::init<std::string&>()); | |||
| })); | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| @@ -29,6 +29,7 @@ | |||
| #include "operator/composite/zip_operation.h" | |||
| #include "operator/composite/list_append_operation.h" | |||
| #include "operator/composite/do_signature.h" | |||
| #include "operator/composite/unpack_call.h" | |||
| #include "pipeline/static_analysis/static_analysis.h" | |||
| #include "utils/misc.h" | |||
| #include "utils/any.h" | |||
| @@ -154,7 +155,7 @@ class GradOperation : public MetaFuncGraph { | |||
| FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr& weights, const std::vector<AnfNodePtr>& ptrParams, | |||
| bool applyJ = false); | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||
| bool sens_param() const { return sens_param_; } | |||
| bool get_all_; | |||
| bool get_by_list_; | |||
| bool sens_param_; | |||
| @@ -208,17 +209,6 @@ class TensorSlice : public MetaFuncGraph { | |||
| }; | |||
| using TensorSlicePtr = std::shared_ptr<TensorSlice>; | |||
| // Expand the tuple and dict parameters generated when parsing the function call, | |||
| // and generate positional parameters and key-value pairs for function. | |||
| class UnpackCall : public MetaFuncGraph { | |||
| public: | |||
| explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {} | |||
| ~UnpackCall() override = default; | |||
| MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph) | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||
| friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; } | |||
| }; | |||
| using UnpackCallPtr = std::shared_ptr<UnpackCall>; | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,94 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "operator/composite/unpack_call.h" | |||
| #include <algorithm> | |||
| #include <utility> | |||
| #include "./common.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| #include "pipeline/static_analysis/dshape.h" | |||
| #include "pipeline/static_analysis/param_validator.h" | |||
| #include "operator/cc_implementations.h" | |||
| #include "ir/anf.h" | |||
| #include "optimizer/opt.h" | |||
| #include "utils/symbolic.h" | |||
| #include "pybind_api/api_register.h" | |||
| namespace mindspore { | |||
| // namespace to support composite operators definition | |||
| namespace prim { | |||
| using mindspore::abstract::AbstractAttribute; | |||
| using mindspore::abstract::AbstractBase; | |||
| using mindspore::abstract::AbstractDictionary; | |||
| using mindspore::abstract::AbstractDictionaryPtr; | |||
| using mindspore::abstract::AbstractFunction; | |||
| using mindspore::abstract::AbstractKeywordArg; | |||
| using mindspore::abstract::AbstractTuple; | |||
| using mindspore::abstract::AbstractTuplePtr; | |||
| FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||
| // slice a tensor | |||
| // args: tensor, slice or slice tuple | |||
| const std::string op_name = std::string("UnpackCall"); | |||
| size_t arg_length = args_spec_list.size(); | |||
| if (arg_length < 2) { | |||
| MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << "."; | |||
| } | |||
| (void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0); | |||
| auto ret_graph = std::make_shared<FuncGraph>(); | |||
| ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| AnfNodePtr fnNode = ret_graph->add_parameter(); | |||
| std::vector<AnfNodePtr> elems; | |||
| elems.push_back(fnNode); | |||
| for (size_t index = 1; index < arg_length; index++) { | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[index]); | |||
| if (args_spec_list[index]->isa<AbstractTuple>()) { | |||
| auto arg_tuple = args_spec_list[index]->cast<AbstractTuplePtr>(); | |||
| AnfNodePtr para_tuple = ret_graph->add_parameter(); | |||
| for (size_t i = 0; i < arg_tuple->size(); ++i) { | |||
| elems.push_back( | |||
| ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToInt(i))})); | |||
| } | |||
| } else if (args_spec_list[index]->isa<AbstractDictionary>()) { | |||
| AbstractDictionaryPtr arg_dict = args_spec_list[index]->cast<AbstractDictionaryPtr>(); | |||
| AnfNodePtr para_dict = ret_graph->add_parameter(); | |||
| auto dict_elems = arg_dict->elements(); | |||
| (void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems), | |||
| [ret_graph, para_dict](const AbstractAttribute& item) { | |||
| auto dict_get_item = ret_graph->NewCNode( | |||
| {NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)}); | |||
| return ret_graph->NewCNode( | |||
| {NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first), dict_get_item}); | |||
| }); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << op_name << " require args should be tuple or dict, but got " | |||
| << args_spec_list[index]->ToString(); | |||
| } | |||
| } | |||
| ret_graph->set_output(ret_graph->NewCNode(elems)); | |||
| return ret_graph; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) { | |||
| (void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_") | |||
| .def(py::init<std::string&>()); | |||
| })); | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ | |||
| #define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <map> | |||
| #include <set> | |||
| #include <memory> | |||
| #include "pipeline/static_analysis/static_analysis.h" | |||
| #include "utils/misc.h" | |||
| #include "utils/any.h" | |||
| #include "ir/dtype.h" | |||
| #include "ir/meta_func_graph.h" | |||
| #include "common/utils.h" | |||
| namespace mindspore { | |||
| // namespace to support composite operators definition | |||
| namespace prim { | |||
| // Expand the tuple and dict parameters generated when parsing the function call, | |||
| // and generate positional parameters and key-value pairs for function. | |||
| class UnpackCall : public MetaFuncGraph { | |||
| public: | |||
| explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {} | |||
| ~UnpackCall() override = default; | |||
| MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph) | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||
| friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; } | |||
| }; | |||
| using UnpackCallPtr = std::shared_ptr<UnpackCall>; | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ | |||
| @@ -246,6 +246,21 @@ class DoSignaturePrimitive : public Primitive { | |||
| ValuePtr function_; | |||
| }; | |||
| using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>; | |||
| class UnpackGraphPrimitive : public Primitive { | |||
| public: | |||
| explicit UnpackGraphPrimitive(const std::string& name, const bool& with_sens, const bool& need_unpack_args) | |||
| : Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {} | |||
| ~UnpackGraphPrimitive() override = default; | |||
| MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive) | |||
| bool with_sens_in_args() const { return with_sens_in_args_; } | |||
| bool need_unpack_args() const { return need_unpack_args_; } | |||
| private: | |||
| bool with_sens_in_args_; | |||
| bool need_unpack_args_; | |||
| }; | |||
| using UnpackGraphPrimitivePtr = std::shared_ptr<UnpackGraphPrimitive>; | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| @@ -39,6 +39,7 @@ | |||
| #include "optimizer/irpass/specialize_transform.h" | |||
| #include "optimizer/irpass/incorporate_getitem.h" | |||
| #include "optimizer/irpass/incorporate_call.h" | |||
| #include "optimizer/irpass/grad_var_prepare.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -123,6 +124,11 @@ ResolveIRPassLib::ResolveIRPassLib() { | |||
| resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve); | |||
| resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr); | |||
| } | |||
| InferenceOptPrepareLib::InferenceOptPrepareLib() { | |||
| grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode); | |||
| } | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -102,6 +102,13 @@ class ResolveIRPassLib { | |||
| SubstitutionPtr resolver_getattr_; | |||
| }; | |||
| class InferenceOptPrepareLib { | |||
| public: | |||
| InferenceOptPrepareLib(); | |||
| ~InferenceOptPrepareLib() = default; | |||
| SubstitutionPtr grad_var_prepare_; | |||
| }; | |||
| // predicate functions | |||
| inline bool IsNode(const AnfNodePtr &) { return true; } | |||
| @@ -151,6 +158,7 @@ inline bool IsCNodeDup(const AnfNodePtr &node) { | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,144 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "optimizer/irpass/grad_var_prepare.h" | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include "operator/composite/composite.h" | |||
| #include "operator/ops.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "ir/visitor.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, FuncGraphPtr func_graph, | |||
| AnfNodePtr func_node, bool is_unpack, bool sens_param) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(func_node); | |||
| std::vector<AnfNodePtr> nodes; | |||
| AnfNodePtr unpack_graph_node = nullptr; | |||
| if (is_unpack) { | |||
| auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, true); | |||
| nodes.push_back(NewValueNode(unpack_graph)); | |||
| nodes.push_back(func_node); | |||
| // {unpackcall, {GradOperation, ...}, args...} | |||
| std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes), | |||
| [](const AnfNodePtr& node) { return node; }); | |||
| unpack_graph_node = func_graph->NewCNode(nodes); | |||
| } else { | |||
| auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, false); | |||
| nodes.push_back(NewValueNode(unpack_graph)); | |||
| nodes.push_back(func_node); | |||
| // {{GradOperation, ...}, args...} | |||
| std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes), | |||
| [](const AnfNodePtr& node) { return node; }); | |||
| unpack_graph_node = func_graph->NewCNode(nodes); | |||
| } | |||
| return unpack_graph_node; | |||
| } | |||
| // get metagraph of value node | |||
| MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) { | |||
| ValuePtr value; | |||
| if (IsValueNode<prim::DoSignaturePrimitive>(node)) { | |||
| value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function(); | |||
| } else { | |||
| value = GetValueNode(node); | |||
| } | |||
| if (value == nullptr) { | |||
| return nullptr; | |||
| } | |||
| return value->cast<MetaFuncGraphPtr>(); | |||
| } | |||
| // check if node is a specific metafuncgraph op | |||
| bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_graph) { | |||
| if (node != nullptr) { | |||
| auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node); | |||
| if (meta_func_graph_ptr == nullptr) { | |||
| return false; | |||
| } | |||
| if (meta_func_graph_ptr->type_name() == meta_func_graph->type_name()) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| // {{GradOperation, g, w}, Ys} | |||
| // {UnPackCall, {GradOperation, g, w}, Ys} | |||
| AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr&, const AnfNodePtr& node) { | |||
| if (!node->isa<CNode>() || node->func_graph() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| // {{...}, Ys} | |||
| auto inputs_y = node->cast<CNodePtr>()->inputs(); | |||
| std::vector<AnfNodePtr> inputs_x; | |||
| if (IsCNode(inputs_y[0])) { | |||
| inputs_x = inputs_y[0]->cast<CNodePtr>()->inputs(); | |||
| } else if (IsMetaFuncGraph(inputs_y[0], unpack_op_) && IsCNode(inputs_y[1])) { | |||
| inputs_x = inputs_y[1]->cast<CNodePtr>()->inputs(); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| // {{...}, Xs} | |||
| if (inputs_x.size() < 2) { | |||
| return nullptr; | |||
| } | |||
| // {GradOperation, g, w} or {GradOperation, g} | |||
| if (!IsMetaFuncGraph(inputs_x[0], grad_op_)) { | |||
| return nullptr; | |||
| } | |||
| auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]); | |||
| if (meta_func == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto grad_op_ptr = meta_func->cast<prim::GradOperationPtr>(); | |||
| auto func_node = inputs_x[1]; | |||
| if (!IsValueNode<FuncGraph>(func_node)) { | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr unpack_graph_node = | |||
| GenerateUnpackGraphNode(inputs_y, node->cast<CNodePtr>()->func_graph(), func_node, | |||
| IsMetaFuncGraph(inputs_y[0], unpack_op_), grad_op_ptr->sens_param()); | |||
| // constuct new grad_opration | |||
| inputs_x[1] = unpack_graph_node; | |||
| auto grad_op_cnode = node->func_graph()->NewCNode(inputs_x); | |||
| if (IsMetaFuncGraph(inputs_y[0], unpack_op_)) { | |||
| inputs_y[1] = grad_op_cnode; | |||
| } else { | |||
| inputs_y[0] = grad_op_cnode; | |||
| } | |||
| auto cnode = node->func_graph()->NewCNode(inputs_y); | |||
| return cnode; | |||
| } | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,55 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include "operator/composite/composite.h" | |||
| #include "operator/ops.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "ir/visitor.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| // {{GradOperation, g, w}, Ys} | |||
| // {UnPackCall, {GradOperation, g, w}, Ys} | |||
| class GradVarPrepare : public AnfVisitor { | |||
| public: | |||
| GradVarPrepare() | |||
| : grad_op_(std::make_shared<prim::GradOperation>("grad")), | |||
| unpack_op_(std::make_shared<prim::UnpackCall>("unpack_call")) {} | |||
| ~GradVarPrepare() override = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | |||
| private: | |||
| MetaFuncGraphPtr grad_op_; | |||
| MetaFuncGraphPtr unpack_op_; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ | |||
| @@ -175,10 +175,10 @@ bool CombineLikeGraphs(const ResourcePtr&) { | |||
| bool SymbolResolveAction(const ResourcePtr& res) { | |||
| if (res->manager() == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Resolve error."; | |||
| MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null"; | |||
| } | |||
| if (res->func_graph() == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Resolve error"; | |||
| MS_LOG(EXCEPTION) << "SymbolResolve error, graph is null"; | |||
| } | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| auto succ = parse::ResolveFuncGraph(func_graph, res); | |||
| @@ -194,6 +194,16 @@ bool SymbolResolveAction(const ResourcePtr& res) { | |||
| return succ; | |||
| } | |||
| bool InferenceOptPrepareAction(const ResourcePtr& res) { | |||
| if (res->manager() == nullptr) { | |||
| MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null."; | |||
| } | |||
| if (res->func_graph() == nullptr) { | |||
| MS_LOG(EXCEPTION) << "InferenceOptPrepare error, graph is null."; | |||
| } | |||
| return InferenceOptPreparePass(res); | |||
| } | |||
| bool AbstractSpecializeAction(const ResourcePtr& res) { | |||
| if (res->func_graph() == nullptr) { | |||
| MS_LOG(EXCEPTION) << "AbstractSpecialize error"; | |||
| @@ -331,7 +341,7 @@ static std::vector<ActionItem> CommonPipeline() { | |||
| // Resolve the python func | |||
| actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction)); | |||
| actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); | |||
| actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); | |||
| // Evaluate type and shape, and specialize | |||
| actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); | |||
| @@ -160,6 +160,13 @@ OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib& irpass) { | |||
| return map; | |||
| } | |||
| OptPassGroupMap GetInferenceOptPreparePhases() { | |||
| opt::irpass::InferenceOptPrepareLib irpass; | |||
| auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_}); | |||
| opt::OptPassGroupMap prepare_map({{"inference_opt_prep", grad_var_prepare}}); | |||
| return prepare_map; | |||
| } | |||
| OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib& irpass) { | |||
| opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_}); | |||
| OptPassGroupMap map({{"prepare_group", prepare_group}}); | |||
| @@ -239,6 +246,16 @@ bool ValidatePass(const ResourcePtr& res) { | |||
| return true; | |||
| } | |||
| bool InferenceOptPreparePass(const ResourcePtr& res) { | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| abstract::AbstractBasePtrList args_spec = res->args_spec(); | |||
| auto prepare_map = GetInferenceOptPreparePhases(); | |||
| auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map); | |||
| (void)infer_opt_prepare->step(func_graph, args_spec, false); | |||
| return true; | |||
| } | |||
| std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, | |||
| {"opt_a", OptPassAGroup}, | |||
| {"opt_b", OptPassBGroup}, | |||
| @@ -34,7 +34,7 @@ bool CconvPass(const ResourcePtr& res); | |||
| bool ValidatePass(const ResourcePtr& res); | |||
| bool ConvertPrepareAdapt(const ResourcePtr& res); | |||
| bool AddControlDependPass(const ResourcePtr& res); | |||
| bool InferenceOptPreparePass(const ResourcePtr& res); | |||
| void ReclaimOptimizer(); | |||
| } // namespace pipeline | |||
| } // namespace mindspore | |||
| @@ -133,6 +133,7 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom { | |||
| FuncGraphPtr func_graph_; | |||
| AnalysisContextPtr context_; | |||
| }; | |||
| using FuncGraphAbstractClosurePtr = std::shared_ptr<FuncGraphAbstractClosure>; | |||
| class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { | |||
| public: | |||
| @@ -41,7 +41,7 @@ AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func | |||
| } else { | |||
| oss << "nullptr"; | |||
| } | |||
| MS_LOG(EXCEPTION) << "" << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | |||
| MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | |||
| } | |||
| return NewContext(parent_context, func_graph, args_spec_list); | |||
| } | |||
| @@ -180,6 +180,85 @@ AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const Config | |||
| return engine->ForwardConfig(out_conf, fn_conf); | |||
| } | |||
| static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) { | |||
| // arg[0] is the func graph to unpack, ignore it | |||
| AbstractBasePtrList sepcialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end()); | |||
| AbstractBasePtrList graph_sepcialize_args; | |||
| if (need_unpack) { | |||
| for (size_t index = 0; index < sepcialize_args_before_unpack.size(); index++) { | |||
| MS_EXCEPTION_IF_NULL(sepcialize_args_before_unpack[index]); | |||
| if (sepcialize_args_before_unpack[index]->isa<AbstractTuple>()) { | |||
| AbstractTuplePtr arg_tuple = sepcialize_args_before_unpack[index]->cast<AbstractTuplePtr>(); | |||
| std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(), | |||
| std::back_inserter(graph_sepcialize_args), [](AbstractBasePtr abs) { return abs; }); | |||
| } else if (sepcialize_args_before_unpack[index]->isa<AbstractDictionary>()) { | |||
| AbstractDictionaryPtr arg_dict = sepcialize_args_before_unpack[index]->cast<AbstractDictionaryPtr>(); | |||
| auto dict_elems = arg_dict->elements(); | |||
| (void)std::transform( | |||
| dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_sepcialize_args), | |||
| [](const AbstractAttribute &item) { return std::make_shared<AbstractKeywordArg>(item.first, item.second); }); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got " | |||
| << sepcialize_args_before_unpack[index]->ToString(); | |||
| } | |||
| } | |||
| } else { | |||
| graph_sepcialize_args = sepcialize_args_before_unpack; | |||
| } | |||
| return graph_sepcialize_args; | |||
| } | |||
| AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) { | |||
| if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; | |||
| } | |||
| if (!prim_->isa<prim::UnpackGraphPrimitive>()) { | |||
| MS_LOG(EXCEPTION) << "Primitive should be UnpackGraphPrimitive, but got " << prim_->ToString(); | |||
| } | |||
| auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>(); | |||
| auto out_node = out_conf->node()->cast<CNodePtr>(); | |||
| const auto &out_node_inputs = out_node->inputs(); | |||
| if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { | |||
| MS_LOG(EXCEPTION) << "UnpackGraphPrimitive" | |||
| << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() | |||
| << ", inputs size " << out_node_inputs.size(); | |||
| } | |||
| AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; | |||
| AbstractBasePtrList args_spec_list; | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); }); | |||
| // get the forward graph | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||
| AbstractFunctionPtr fn = args_spec_list[0]->cast<AbstractFunctionPtr>(); | |||
| if (fn == nullptr) { | |||
| MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString(); | |||
| } | |||
| auto real_fn = fn->cast<FuncGraphAbstractClosurePtr>(); | |||
| MS_EXCEPTION_IF_NULL(real_fn); | |||
| FuncGraphPtr forward_graph = real_fn->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(forward_graph); | |||
| AbstractBasePtrList graph_sepcialize_args = | |||
| GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args()); | |||
| AbstractBasePtrList graph_sepcialize_args_without_sens; | |||
| (void)std::transform(graph_sepcialize_args.begin(), | |||
| graph_sepcialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0), | |||
| std::back_inserter(graph_sepcialize_args_without_sens), [](AbstractBasePtr abs) { return abs; }); | |||
| auto new_graph = forward_graph->GenerateGraph(graph_sepcialize_args_without_sens); | |||
| engine->func_graph_manager()->AddFuncGraph(new_graph); | |||
| ScopePtr scope = kDefaultScope; | |||
| if (out_conf != nullptr) { | |||
| scope = out_conf->node()->scope(); | |||
| } | |||
| ScopeGuard scope_guard(scope); | |||
| AnfNodePtr new_vnode = NewValueNode(new_graph); | |||
| AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context()); | |||
| return engine->ForwardConfig(out_conf, fn_conf); | |||
| } | |||
| namespace { | |||
| py::object BuildValue(const ValuePtr &value_ptr) { | |||
| if (value_ptr == nullptr) { | |||
| @@ -87,6 +87,21 @@ class DoSignatureEvaluator : public Evaluator { | |||
| PrimitivePtr prim_; | |||
| }; | |||
| class UnpackGraphEvaluator : public Evaluator { | |||
| public: | |||
| explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {} | |||
| ~UnpackGraphEvaluator() override = default; | |||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, | |||
| AnfNodeConfigPtr out_config = nullptr) override; | |||
| AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||
| MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called"; | |||
| } | |||
| private: | |||
| PrimitivePtr prim_; | |||
| }; | |||
| bool IsInWhiteList(PrimitivePtr primitive); | |||
| StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive); | |||
| @@ -289,6 +289,10 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr | |||
| evaluator = std::make_shared<DoSignatureEvaluator>(prim); | |||
| return evaluator; | |||
| } | |||
| if (prim->isa<prim::UnpackGraphPrimitive>()) { | |||
| evaluator = std::make_shared<UnpackGraphEvaluator>(prim); | |||
| return evaluator; | |||
| } | |||
| if (prim->HasPyEvaluator()) { | |||
| auto prim_py = dyn_cast<PrimitivePy>(prim); | |||
| if (prim_py != nullptr) { | |||
| @@ -19,6 +19,8 @@ from mindspore.nn import Cell | |||
| from mindspore.ops import operations as P | |||
| import mindspore.ops.composite as C | |||
| from mindspore.common.api import _executor | |||
| from mindspore.common.parameter import ParameterTuple | |||
| from mindspore.common import dtype as mstype | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| @@ -34,3 +36,152 @@ def test_net_vargs_expand(): | |||
| sens = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32)) | |||
| net = AddNet() | |||
| out = C.grad_all_with_sens(net, net.trainable_params())(x, y, sens) | |||
class VarNet(Cell):
    """Cell that forwards its variadic inputs to `net`, then scales by `w` and shifts by `b`.

    Both `w` and `b` are trainable parameters initialised to ones of shape [3, 4, 5].
    """

    def __init__(self, net):
        super(VarNet, self).__init__()
        # Registration order (b, then w, then the wrapped net) is kept as in the original.
        self.b = Parameter(
            Tensor(np.ones([3, 4, 5]), dtype=mstype.float32),
            "b",
            requires_grad=True,
        )
        self.w = Parameter(
            Tensor(np.ones([3, 4, 5]), dtype=mstype.float32),
            "w",
            requires_grad=True,
        )
        self.net = net

    def construct(self, *args):
        # All positional inputs are passed through unchanged to the wrapped cell.
        return self.net(*args) * self.w + self.b
class SecondNet(Cell):
    """Cell that adds its first two variadic inputs and a trainable bias `b2`.

    `b2` is initialised to ones of shape [3, 4, 5].
    """

    def __init__(self):
        super(SecondNet, self).__init__()
        self.b2 = Parameter(
            Tensor(np.ones([3, 4, 5]), dtype=mstype.float32),
            "b2",
            requires_grad=True,
        )

    def construct(self, *args):
        # Only args[0] and args[1] are read; any further inputs are ignored.
        res = args[0] + args[1]
        return res + self.b2
def test_all_var_args_grad_with_sens():
    """Run grad_by_list_with_sens when all inputs, sens included, are passed as *inputs."""

    class GradNet(Cell):
        """Applies C.grad_by_list_with_sens to `net` over its trainable parameters."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net

        def construct(self, *inputs):
            return C.grad_by_list_with_sens(self.net, self.weights)(*inputs)

    # Two data inputs of ones plus a scalar sensitivity value.
    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    sens = Tensor(1.0, dtype=mstype.float32)

    wrapped = VarNet(SecondNet())
    grad_net = GradNet(wrapped)
    # The test only checks that the call completes; the result is not inspected.
    result = grad_net(x, y, sens)
def test_grad_list_var_args():
    """Run grad_by_list on a var-args network with the inputs passed as *inputs."""

    class GradNet(Cell):
        """Applies C.grad_by_list to `net` over its trainable parameters."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net

        def construct(self, *inputs):
            return C.grad_by_list(self.net, self.weights)(*inputs)

    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)

    wrapped = VarNet(SecondNet())
    grad_net = GradNet(wrapped)
    # The test only checks that the call completes; the result is not inspected.
    result = grad_net(x, y)
def test_grad_all_var_args():
    """Run grad_all on a var-args network with the inputs passed as *inputs."""

    class GradNet(Cell):
        """Applies C.grad_all to `net`; `weights` is stored but not used in construct."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net

        def construct(self, *inputs):
            return C.grad_all(self.net)(*inputs)

    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)

    wrapped = VarNet(SecondNet())
    grad_net = GradNet(wrapped)
    # The test only checks that the call completes; the result is not inspected.
    result = grad_net(x, y)
def test_grad_all_var_args_with_sens():
    """Run grad_all_with_sens when all inputs, sens included, are passed as *inputs."""

    class GradNet(Cell):
        """Applies C.grad_all_with_sens to `net`; `weights` is stored but not used in construct."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net

        def construct(self, *inputs):
            return C.grad_all_with_sens(self.net)(*inputs)

    # Two data inputs of ones plus a scalar sensitivity value.
    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    sens = Tensor(1.0, dtype=mstype.float32)

    wrapped = VarNet(SecondNet())
    grad_net = GradNet(wrapped)
    # The test only checks that the call completes; the result is not inspected.
    result = grad_net(x, y, sens)
def test_grad_var_args_with_sens():
    """Run grad_with_sens when all inputs, sens included, are passed as *inputs."""

    class GradNet(Cell):
        """Applies C.grad_with_sens to `net`; `weights` is stored but not used in construct."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net

        def construct(self, *inputs):
            return C.grad_with_sens(self.net)(*inputs)

    # Two data inputs of ones plus a scalar sensitivity value.
    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    sens = Tensor(1.0, dtype=mstype.float32)

    wrapped = VarNet(SecondNet())
    grad_net = GradNet(wrapped)
    # The test only checks that the call completes; the result is not inspected.
    result = grad_net(x, y, sens)
def test_var_args_grad():
    """Run grad_by_list_with_sens with explicit inputs over networks that take *args."""

    class VarNet(Cell):
        """Local var-args wrapper: forwards *args to `net` and adds the bias `b`."""

        def __init__(self, net):
            super(VarNet, self).__init__()
            self.b = Parameter(
                Tensor(np.ones([3, 4, 5]), dtype=mstype.float32),
                "b",
                requires_grad=True,
            )
            self.net = net

        def construct(self, *args):
            return self.net(*args) + self.b

    class SecondNet(Cell):
        """Local inner network: adds its first two inputs and the bias `b2`."""

        def __init__(self):
            super(SecondNet, self).__init__()
            self.b2 = Parameter(
                Tensor(np.ones([3, 4, 5]), dtype=mstype.float32),
                "b2",
                requires_grad=True,
            )

        def construct(self, *args):
            res = args[0] + args[1]
            return res + self.b2

    class GradNet(Cell):
        """Applies C.grad_by_list_with_sens to `net` using fixed positional inputs."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, x, y, sens):
            return C.grad_by_list_with_sens(self.net, self.weights)(x, y, sens)

    # Two data inputs of ones plus a scalar sensitivity value.
    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    sens = Tensor(1.0, dtype=mstype.float32)

    # These names resolve to the classes defined locally above.
    wrapped = VarNet(SecondNet())
    grad_net = GradNet(wrapped)
    # The test only checks that the call completes; the result is not inspected.
    result = grad_net(x, y, sens)
def test_var_args_positional():
    """Run grad_all where only the inner network takes *args and the outer one is positional."""

    class VarNet(Cell):
        """Outer network with fixed inputs: multiplies the inner result by `x`."""

        def __init__(self, net):
            super(VarNet, self).__init__()
            self.net = net

        def construct(self, x, y):
            return self.net(x, y) * x

    class SecondNet(Cell):
        """Inner var-args network: returns the sum of its first two inputs."""

        def __init__(self):
            super(SecondNet, self).__init__()

        def construct(self, *args):
            return args[0] + args[1]

    class GradNet(Cell):
        """Applies C.grad_all to `net`; `weights` is stored but not used in construct."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, x, y):
            return C.grad_all(self.net)(x, y)

    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)

    # These names resolve to the classes defined locally above.
    wrapped = VarNet(SecondNet())
    grad_net = GradNet(wrapped)
    # The test only checks that the call completes; the result is not inspected.
    result = grad_net(x, y)